diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index 07ef33be2..fbb686289 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -23,6 +23,7 @@ def get_default_global_variable(input_field_list: List): if item.get('default_value', None) is not None } + def get_global_variable(node): body = node.workflow_manage.get_body() history_chat_record = node.flow_params_serializer.data.get('history_chat_record', []) @@ -74,6 +75,7 @@ class BaseStartStepNode(IStarNode): 'other': self.workflow_manage.other_list, } + self.workflow_manage.chat_context = self.workflow_manage.get_chat_info().get_chat_variable() return NodeResult(node_variable, workflow_variable) def get_details(self, index: int, **kwargs): diff --git a/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py b/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py index ce2906e62..3dbe90fca 100644 --- a/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py +++ b/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py @@ -2,8 +2,11 @@ import json from typing import List +from django.db.models import QuerySet + from application.flow.i_step_node import NodeResult from application.flow.step_node.variable_assign_node.i_variable_assign_node import IVariableAssignNode +from application.models import Chat class BaseVariableAssignNode(IVariableAssignNode): @@ -11,40 +14,56 @@ class BaseVariableAssignNode(IVariableAssignNode): self.context['variable_list'] = details.get('variable_list') self.context['result_list'] = details.get('result_list') + def global_evaluation(self, variable, value): + self.workflow_manage.context[variable['fields'][1]] = value + + def chat_evaluation(self, variable, value): + self.workflow_manage.chat_context[variable['fields'][1]] = value + + def handle(self, variable, evaluation): + result = { + 'name': variable['name'], + 'input_value': self.get_reference_content(variable['fields']), + } + if variable['source'] == 'custom': + if variable['type'] == 'json': + if isinstance(variable['value'], dict) or isinstance(variable['value'], list): + val = variable['value'] + else: + val = json.loads(variable['value']) + evaluation(variable, val) + result['output_value'] = variable['value'] = val + elif variable['type'] == 'string': + # 变量解析 例如:{{global.xxx}} + val = self.workflow_manage.generate_prompt(variable['value']) + evaluation(variable, val) + result['output_value'] = val + else: + val = variable['value'] + evaluation(variable, val) + result['output_value'] = val + else: + reference = self.get_reference_content(variable['reference']) + evaluation(variable, reference) + result['output_value'] = reference + return result + def execute(self, variable_list, stream, **kwargs) -> NodeResult: # result_list = [] + is_chat = False for variable in variable_list: if 'fields' not in variable: continue if 'global' == variable['fields'][0]: - result = { - 'name': variable['name'], - 'input_value': self.get_reference_content(variable['fields']), - } - if variable['source'] == 'custom': - if variable['type'] == 'json': - if isinstance(variable['value'], dict) or isinstance(variable['value'], list): - val = variable['value'] - else: - val = json.loads(variable['value']) - self.workflow_manage.context[variable['fields'][1]] = val - result['output_value'] = variable['value'] = val - elif variable['type'] == 'string': - # 变量解析 例如:{{global.xxx}} - val = self.workflow_manage.generate_prompt(variable['value']) - self.workflow_manage.context[variable['fields'][1]] = val - result['output_value'] = val - else: - val = variable['value'] - self.workflow_manage.context[variable['fields'][1]] = val - result['output_value'] = val - else: - reference = self.get_reference_content(variable['reference']) - self.workflow_manage.context[variable['fields'][1]] = reference - result['output_value'] = reference + result = self.handle(variable, self.global_evaluation) result_list.append(result) - + if 'chat' == variable['fields'][0]: + result = self.handle(variable, self.chat_evaluation) + result_list.append(result) + is_chat = True + if is_chat: + self.workflow_manage.get_chat_info().set_chat_variable(self.workflow_manage.chat_context) return NodeResult({'variable_list': variable_list, 'result_list': result_list}, {}) def get_reference_content(self, fields: List[str]): diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 2f4c1586d..30a7be040 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -117,6 +117,7 @@ class WorkflowManage: self.params = params self.flow = flow self.context = {} + self.chat_context = {} self.node_chunk_manage = NodeChunkManage(self) self.work_flow_post_handler = work_flow_post_handler self.current_node = None @@ -131,6 +132,7 @@ class WorkflowManage: self.lock = threading.Lock() self.field_list = [] self.global_field_list = [] + self.chat_field_list = [] self.init_fields() if start_node_id is not None: self.load_node(chat_record, start_node_id, start_node_data) @@ -140,6 +142,7 @@ class WorkflowManage: def init_fields(self): field_list = [] global_field_list = [] + chat_field_list = [] for node in self.flow.nodes: properties = node.properties node_name = properties.get('stepName') @@ -154,10 +157,16 @@ class WorkflowManage: if global_fields is not None: for global_field in global_fields: global_field_list.append({**global_field, 'node_id': node_id, 'node_name': node_name}) + chat_fields = node_config.get('chatFields') + if chat_fields is not None: + for chat_field in chat_fields: + chat_field_list.append({**chat_field, 'node_id': node_id, 'node_name': node_name}) field_list.sort(key=lambda f: len(f.get('node_name') + f.get('value')), reverse=True) global_field_list.sort(key=lambda f: len(f.get('node_name') + f.get('value')), reverse=True) + chat_field_list.sort(key=lambda f: len(f.get('node_name') + f.get('value')), reverse=True) self.field_list = field_list self.global_field_list = global_field_list + self.chat_field_list = chat_field_list def append_answer(self, content): self.answer += content @@ -445,6 +454,9 @@ class WorkflowManage: return current_node.node_params.get('is_result', not self._has_next_node( current_node, current_node_result)) if current_node.node_params is not None else False + def get_chat_info(self): + return self.work_flow_post_handler.chat_info + def get_chunk_content(self, chunk, is_end=False): return 'data: ' + json.dumps( {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True, @@ -587,12 +599,15 @@ class WorkflowManage: """ if node_id == 'global': return INode.get_field(self.context, fields) + elif node_id == 'chat': + return INode.get_field(self.chat_context, fields) else: return self.get_node_by_id(node_id).get_reference_field(fields) def get_workflow_content(self): context = { 'global': self.context, + 'chat': self.chat_context } for node in self.node_context: @@ -610,6 +625,10 @@ class WorkflowManage: globeLabelNew = f"global.{field.get('value')}" globeValue = f"context.get('global').get('{field.get('value', '')}','')" prompt = prompt.replace(globeLabel, globeValue).replace(globeLabelNew, globeValue) + for field in self.chat_field_list: + chatLabel = f"chat.{field.get('value')}" + chatValue = f"context.get('chat').get('{field.get('value', '')}','')" + prompt = prompt.replace(chatLabel, chatValue) return prompt def generate_prompt(self, prompt: str): diff --git a/apps/application/serializers/common.py b/apps/application/serializers/common.py index f9169a29a..55c9a063a 100644 --- a/apps/application/serializers/common.py +++ b/apps/application/serializers/common.py @@ -166,6 +166,34 @@ class ChatInfo: 'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'chat_user_id': chat_user_id, 'chat_user_type': chat_user_type, 'form_data': form_data} + def set_chat(self, question): + if not self.debug: + if not QuerySet(Chat).filter(id=self.chat_id).exists(): + Chat(id=self.chat_id, application_id=self.application_id, abstract=question[0:1024], + chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type, + asker=self.get_chat_user()).save() + + def set_chat_variable(self, chat_context): + if not self.debug: + chat = QuerySet(Chat).filter(id=self.chat_id).first() + if chat: + chat.meta = {**(chat.meta if isinstance(chat.meta, dict) else {}), **chat_context} + chat.save() + else: + cache.set(Cache_Version.CHAT_VARIABLE.get_key(key=self.chat_id), chat_context, + version=Cache_Version.CHAT_VARIABLE.get_version(), + timeout=60 * 30) + + def get_chat_variable(self): + if not self.debug: + chat = QuerySet(Chat).filter(id=self.chat_id).first() + if chat: + return chat.meta + return {} + else: + return cache.get(Cache_Version.CHAT_VARIABLE.get_key(key=self.chat_id), + version=Cache_Version.CHAT_VARIABLE.get_version()) or {} + def append_chat_record(self, chat_record: ChatRecord): chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else "" chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.problem_text is not None else "" diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py index c4c2027c6..abbe6e527 100644 --- a/apps/chat/serializers/chat.py +++ b/apps/chat/serializers/chat.py @@ -253,6 +253,7 @@ class ChatSerializers(serializers.Serializer): # 构建运行参数 params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list, chat_user_id, chat_user_type, stream, form_data) + chat_info.set_chat(message) # 运行流水线作业 pipeline_message.run(params) return pipeline_message.context['chat_result'] @@ -307,6 +308,7 @@ class ChatSerializers(serializers.Serializer): other_list, instance.get('runtime_node_id'), instance.get('node_data'), chat_record, instance.get('child_node')) + chat_info.set_chat(message) r = work_flow_manage.run() return r diff --git a/apps/common/constants/cache_version.py b/apps/common/constants/cache_version.py index daa0208c4..2cf889c17 100644 --- a/apps/common/constants/cache_version.py +++ b/apps/common/constants/cache_version.py @@ -29,6 +29,9 @@ class Cache_Version(Enum): # 对话 CHAT = "CHAT", lambda key: key + + CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key + # 应用API KEY APPLICATION_API_KEY = "APPLICATION_API_KEY", lambda secret_key, use_get_data: secret_key diff --git a/ui/src/locales/lang/zh-CN/views/application-workflow.ts b/ui/src/locales/lang/zh-CN/views/application-workflow.ts index 29c2b9cc0..8c4575711 100644 --- a/ui/src/locales/lang/zh-CN/views/application-workflow.ts +++ b/ui/src/locales/lang/zh-CN/views/application-workflow.ts @@ -52,6 +52,7 @@ export default { variable: { label: '变量', global: '全局变量', + chat: '会话变量', Referencing: '引用变量', ReferencingRequired: '引用变量必填', ReferencingError: '引用变量错误', diff --git a/ui/src/workflow/common/NodeCascader.vue b/ui/src/workflow/common/NodeCascader.vue index dae4ad078..2843cb042 100644 --- a/ui/src/workflow/common/NodeCascader.vue +++ b/ui/src/workflow/common/NodeCascader.vue @@ -51,7 +51,9 @@ const wheel = (e: any) => { function visibleChange(bool: boolean) { if (bool) { options.value = props.global - ? props.nodeModel.get_up_node_field_list(false, true).filter((v: any) => v.value === 'global') + ? props.nodeModel + .get_up_node_field_list(false, true) + .filter((v: any) => ['global', 'chat'].includes(v.value)) : props.nodeModel.get_up_node_field_list(false, true) } } diff --git a/ui/src/workflow/common/app-node.ts b/ui/src/workflow/common/app-node.ts index f6d3f5e54..e3cdb3c4f 100644 --- a/ui/src/workflow/common/app-node.ts +++ b/ui/src/workflow/common/app-node.ts @@ -34,7 +34,7 @@ class AppNode extends HtmlResize.view { } else { const filterNodes = props.graphModel.nodes.filter((v: any) => v.type === props.model.type) const filterNameSameNodes = filterNodes.filter( - (v: any) => v.properties.stepName === props.model.properties.stepName + (v: any) => v.properties.stepName === props.model.properties.stepName, ) if (filterNameSameNodes.length - 1 > 0) { getNodesName(filterNameSameNodes.length - 1) @@ -61,14 +61,20 @@ class AppNode extends HtmlResize.view { value: 'global', label: t('views.applicationWorkflow.variable.global'), type: 'global', - children: this.props.model.properties?.config?.globalFields || [] + children: this.props.model.properties?.config?.globalFields || [], + }) + result.push({ + value: 'chat', + label: t('views.applicationWorkflow.variable.chat'), + type: 'chat', + children: this.props.model.properties?.config?.chatFields || [], }) } result.push({ value: this.props.model.id, label: this.props.model.properties.stepName, type: this.props.model.type, - children: this.props.model.properties?.config?.fields || [] + children: this.props.model.properties?.config?.fields || [], }) return result } @@ -83,7 +89,7 @@ class AppNode extends HtmlResize.view { if (contain_self) { return { ...this.up_node_field_dict, - [this.props.model.id]: this.get_node_field_list() + [this.props.model.id]: this.get_node_field_list(), } } return this.up_node_field_dict ? this.up_node_field_dict : {} @@ -92,7 +98,7 @@ class AppNode extends HtmlResize.view { get_up_node_field_list(contain_self: boolean, use_cache: boolean) { const result = Object.values(this.get_up_node_field_dict(contain_self, use_cache)).reduce( (pre, next) => [...pre, ...next], - [] + [], ) const start_node_field_list = this.props.graphModel .getNodeModelById('start-node') @@ -126,7 +132,7 @@ class AppNode extends HtmlResize.view { x: x - 10, y: y - 12, width: 30, - height: 30 + height: 30, }, [ lh('div', { @@ -174,10 +180,10 @@ class AppNode extends HtmlResize.view { - ` - } - }) - ] + `, + }, + }), + ], ) } @@ -214,7 +220,7 @@ class AppNode extends HtmlResize.view { } else { this.r = h(this.component, { properties: this.props.model.properties, - nodeModel: this.props.model + nodeModel: this.props.model, }) this.app = createApp({ render() { @@ -223,13 +229,13 @@ class AppNode extends HtmlResize.view { provide() { return { getNode: () => model, - getGraph: () => graphModel + getGraph: () => graphModel, } - } + }, }) this.app.use(ElementPlus, { - locale: zhCn + locale: zhCn, }) this.app.use(Components) this.app.use(directives) @@ -295,7 +301,7 @@ class AppNodeModel extends HtmlResize.model { } getNodeStyle() { return { - overflow: 'visible' + overflow: 'visible', } } getOutlineStyle() { @@ -361,13 +367,13 @@ class AppNodeModel extends HtmlResize.model { message: t('views.applicationWorkflow.tip.onlyRight'), validate: (sourceNode: any, targetNode: any, sourceAnchor: any) => { return sourceAnchor.type === 'right' - } + }, } this.sourceRules.push({ message: t('views.applicationWorkflow.tip.notRecyclable'), validate: (sourceNode: any, targetNode: any, sourceAnchor: any, targetAnchor: any) => { return !isLoop(sourceNode.id, targetNode.id) - } + }, }) this.sourceRules.push(circleOnlyAsTarget) @@ -375,7 +381,7 @@ class AppNodeModel extends HtmlResize.model { message: t('views.applicationWorkflow.tip.onlyLeft'), validate: (sourceNode: any, targetNode: any, sourceAnchor: any, targetAnchor: any) => { return targetAnchor.type === 'left' - } + }, }) } getDefaultAnchor() { @@ -390,14 +396,14 @@ class AppNodeModel extends HtmlResize.model { y: showNode ? y : y - 15, id: `${id}_left`, edgeAddable: false, - type: 'left' + type: 'left', }) } anchors.push({ x: x + width / 2 - 10, y: showNode ? y : y - 15, id: `${id}_right`, - type: 'right' + type: 'right', }) } diff --git a/ui/src/workflow/icons/chat-icon.vue b/ui/src/workflow/icons/chat-icon.vue new file mode 100644 index 000000000..0c8ac7432 --- /dev/null +++ b/ui/src/workflow/icons/chat-icon.vue @@ -0,0 +1,4 @@ + + diff --git a/ui/src/workflow/nodes/base-node/component/ChatFieldDialog.vue b/ui/src/workflow/nodes/base-node/component/ChatFieldDialog.vue new file mode 100644 index 000000000..1c537b0ce --- /dev/null +++ b/ui/src/workflow/nodes/base-node/component/ChatFieldDialog.vue @@ -0,0 +1,118 @@ + + + diff --git a/ui/src/workflow/nodes/base-node/component/ChatFieldTable.vue b/ui/src/workflow/nodes/base-node/component/ChatFieldTable.vue new file mode 100644 index 000000000..c5f7c0bcc --- /dev/null +++ b/ui/src/workflow/nodes/base-node/component/ChatFieldTable.vue @@ -0,0 +1,110 @@ + + + + + diff --git a/ui/src/workflow/nodes/base-node/index.vue b/ui/src/workflow/nodes/base-node/index.vue index da81c3745..4f97aaafd 100644 --- a/ui/src/workflow/nodes/base-node/index.vue +++ b/ui/src/workflow/nodes/base-node/index.vue @@ -83,6 +83,7 @@ +