diff --git a/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py index d030e0078..ceda39444 100644 --- a/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py +++ b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py @@ -9,12 +9,7 @@ from common.util.field_message import ErrMessage class DocumentExtractNodeSerializer(serializers.Serializer): - # 需要查询的数据集id列表 - file_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), - error_messages=ErrMessage.list("数据集id列表")) - - def is_valid(self, *, raise_exception=False): - super().is_valid(raise_exception=True) + document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档")) class IDocumentExtractNode(INode): @@ -24,7 +19,9 @@ class IDocumentExtractNode(INode): return DocumentExtractNodeSerializer def _run(self): - return self.execute(**self.flow_params_serializer.data) + res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('document_list')[0], + self.node_params_serializer.data.get('document_list')[1:]) + return self.execute(document=res, **self.flow_params_serializer.data) - def execute(self, file_list, **kwargs) -> NodeResult: + def execute(self, document, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py index bf900ffe2..35fc6edff 100644 --- a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py +++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py @@ -1,11 +1,26 @@ # coding=utf-8 - +from application.flow.i_step_node import NodeResult from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode class BaseDocumentExtractNode(IDocumentExtractNode): - def execute(self, file_list, **kwargs): - pass + def execute(self, document, **kwargs): + self.context['document_list'] = document + content = '' + if len(document) > 0: + for doc in document: + content += doc['name'] + content += '\n-----------------------------------\n' + return NodeResult({'content': content}, {}) def get_details(self, index: int, **kwargs): - pass + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'content': self.context.get('content'), + 'status': self.status, + 'err_message': self.err_message, + 'document_list': self.context.get('document_list') + } diff --git a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py index 4c15ad8cd..26fb431d0 100644 --- a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py @@ -18,7 +18,7 @@ class ImageUnderstandNodeSerializer(serializers.Serializer): is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) - image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张")) + image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片")) class IImageUnderstandNode(INode): diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py index e6a88a9f1..046d6f783 100644 --- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -25,7 +25,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo node.context['question'] = node_variable['question'] node.context['run_time'] = time.time() - node.context['start_time'] if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): - workflow.answer += answer + node.answer_text = answer def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index 9ac7e2aef..6388e4dfd 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -52,8 +52,12 @@ class BaseStartStepNode(IStarNode): """ 开始节点 初始化全局变量 """ - return NodeResult({'question': question, 'image': self.workflow_manage.image_list}, - workflow_variable) + node_variable = { + 'question': question, + 'image': self.workflow_manage.image_list, + 'document': self.workflow_manage.document_list + } + return NodeResult(node_variable, workflow_variable) def get_details(self, index: int, **kwargs): global_fields = [] diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 39c42aa24..edd21391d 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -240,16 +240,20 @@ class NodeChunk: class WorkflowManage: def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler, base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, + document_list=None, start_node_id=None, start_node_data=None, chat_record=None): if form_data is None: form_data = {} if image_list is None: image_list = [] + if document_list is None: + document_list = [] self.start_node = None self.start_node_result_future = None self.form_data = form_data self.image_list = image_list + self.document_list = document_list self.params = params self.flow = flow self.lock = threading.Lock() diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 455ef6a67..919cb71cf 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -230,7 +230,8 @@ class ChatMessageSerializer(serializers.Serializer): client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量")) - image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张")) + image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片")) + document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档")) def is_valid_application_workflow(self, *, raise_exception=False): self.is_valid_intraday_access_num() @@ -322,6 +323,7 @@ class ChatMessageSerializer(serializers.Serializer): client_type = self.data.get('client_type') form_data = self.data.get('form_data') image_list = self.data.get('image_list') + document_list = self.data.get('document_list') user_id = chat_info.application.user_id chat_record_id = self.data.get('chat_record_id') chat_record = None @@ -336,7 +338,7 @@ class ChatMessageSerializer(serializers.Serializer): 'client_id': client_id, 'client_type': client_type, 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), - base_to_response, form_data, image_list, self.data.get('runtime_node_id'), + base_to_response, form_data, image_list, document_list, self.data.get('runtime_node_id'), self.data.get('node_data'), chat_record) r = work_flow_manage.run() return r diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 586787b20..790277860 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -132,6 +132,8 @@ class ChatView(APIView): 'image_list': request.data.get( 'image_list') if 'image_list' in request.data else [], + 'document_list': request.data.get( + 'document_list') if 'document_list' in request.data else [], 'client_type': request.auth.client_type, 'runtime_node_id': request.data.get('runtime_node_id', None), 'node_data': request.data.get('node_data', {}), diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index 5d4971f7f..11b630f13 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -39,7 +39,8 @@ interface chatType { record_id: string chat_id: string vote_status: string - status?: number + status?: number, + execution_details: any[] } export class ChatRecordManage { diff --git a/ui/src/components/ai-chat/component/chat-input-operate/index.vue b/ui/src/components/ai-chat/component/chat-input-operate/index.vue index 8ae880d5c..b3bee751a 100644 --- a/ui/src/components/ai-chat/component/chat-input-operate/index.vue +++ b/ui/src/components/ai-chat/component/chat-input-operate/index.vue @@ -20,10 +20,12 @@