diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index 4cad1796e..8d59deaa7 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -65,9 +65,12 @@ def event_content(response,
     try:
         for chunk in response:
             all_text += chunk.content
-            yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), chunk.content,
+            yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
+                                                                         [], chunk.content,
                                                                          False,
-                                                                         0, 0)
+                                                                         0, 0, {'node_is_end': False,
+                                                                                'view_type': 'many_view',
+                                                                                'node_type': 'ai-chat-node'})
         # fetch token usage
         if is_ai_chat:
             try:
@@ -82,8 +85,11 @@ def event_content(response,
         write_context(step, manage, request_token, response_token, all_text)
         post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                       all_text, manage, step, padding_problem_text, client_id)
-        yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), '', True,
-                                                                     request_token, response_token)
+        yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
+                                                                     [], '', True,
+                                                                     request_token, response_token,
+                                                                     {'node_is_end': True, 'view_type': 'many_view',
+                                                                      'node_type': 'ai-chat-node'})
         add_access_num(client_id, client_type)
     except Exception as e:
         logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
@@ -92,7 +98,11 @@ def event_content(response,
         post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                       all_text, manage, step, padding_problem_text, client_id)
         add_access_num(client_id, client_type)
-        yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), all_text, True, 0, 0)
+        yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
+                                                                     [], all_text, True, 0, 0,
+                                                                     {'node_is_end': True, 'view_type': 'many_view',
+                                                                      'node_type': 'ai-chat-node'})


 class BaseChatStep(IChatStep):
diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py
index 360c46bb5..d3c6207aa 100644
--- a/apps/application/flow/i_step_node.py
+++ b/apps/application/flow/i_step_node.py
@@ -9,6 +9,7 @@
 import time
 import uuid
 from abc import abstractmethod
+from hashlib import sha1
 from typing import Type, Dict, List

 from django.core import cache
@@ -131,6 +132,7 @@ class FlowParamsSerializer(serializers.Serializer):


 class INode:
+    view_type = 'many_view'

     @abstractmethod
     def save_context(self, details, workflow_manage):
@@ -139,7 +141,7 @@ class INode:
     def get_answer_text(self):
         return self.answer_text

-    def __init__(self, node, workflow_params, workflow_manage, runtime_node_id=None):
+    def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None):
         # context of the current step, used to store information about this step
         self.status = 200
         self.err_message = ''
@@ -152,10 +154,13 @@ class INode:
         self.context = {}
         self.answer_text = None
         self.id = node.id
-        if runtime_node_id is None:
-            self.runtime_node_id = str(uuid.uuid1())
-        else:
-            self.runtime_node_id = runtime_node_id
+        if up_node_id_list is None:
+            up_node_id_list = []
+        self.up_node_id_list = up_node_id_list
+        self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
+                                                                                    "".join([*sorted(up_node_id_list),
+                                                                                             node.id]))),
+                                                                     "utf-8")).hexdigest()

     def valid_args(self, node_params, flow_params):
         flow_params_serializer_class = self.get_flow_params_serializer_class()
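With this change `runtime_node_id` is no longer a random `uuid1()`; it is derived deterministically from the sorted upstream node ids plus the node's own id, so resuming the same position in a flow (e.g. continuing at a form node) reproduces the same id. A minimal standalone sketch of the derivation — the `node_id` and `up_node_id_list` values are illustrative:

```python
import uuid
from hashlib import sha1

def runtime_node_id(node_id: str, up_node_id_list: list) -> str:
    # uuid5 yields a stable name-based UUID for the sorted upstream ids + own id;
    # sha1 over NAMESPACE_DNS.bytes + that UUID string mirrors INode.__init__.
    name = "".join([*sorted(up_node_id_list), node_id])
    derived = uuid.uuid5(uuid.NAMESPACE_DNS, name)
    return sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(derived), "utf-8")).hexdigest()

# same inputs -> same id, regardless of upstream ordering
assert runtime_node_id('node-b', ['n1', 'n2']) == runtime_node_id('node-b', ['n2', 'n1'])
```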
diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py
index ba16286fd..30954a2b6 100644
--- a/apps/application/flow/step_node/application_node/impl/base_application_node.py
+++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py
@@ -18,6 +18,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0)
     node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
     node.context['answer'] = answer
+    node.context['result'] = answer
     node.context['question'] = node_variable['question']
     node.context['run_time'] = time.time() - node.context['start_time']
     if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
diff --git a/apps/application/flow/step_node/form_node/i_form_node.py b/apps/application/flow/step_node/form_node/i_form_node.py
index 18ff91fda..b793a5b78 100644
--- a/apps/application/flow/step_node/form_node/i_form_node.py
+++ b/apps/application/flow/step_node/form_node/i_form_node.py
@@ -21,6 +21,7 @@ class FormNodeParamsSerializer(serializers.Serializer):

 class IFormNode(INode):
     type = 'form-node'
+    view_type = 'single_view'

     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return FormNodeParamsSerializer
diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py
index 1d5170577..ae86b11d6 100644
--- a/apps/application/flow/step_node/form_node/impl/base_form_node.py
+++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py
@@ -29,12 +29,18 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):

 class BaseFormNode(IFormNode):
     def save_context(self, details, workflow_manage):
+        form_data = details.get('form_data', None)
         self.context['result'] = details.get('result')
         self.context['form_content_format'] = details.get('form_content_format')
         self.context['form_field_list'] = details.get('form_field_list')
         self.context['run_time'] = details.get('run_time')
         self.context['start_time'] = details.get('start_time')
+        self.context['form_data'] = form_data
+        self.context['is_submit'] = details.get('is_submit')
         self.answer_text = details.get('result')
+        if form_data is not None:
+            for key in form_data:
+                self.context[key] = form_data[key]

     def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
         form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
@@ -77,6 +83,7 @@ class BaseFormNode(IFormNode):
             "form_field_list": self.context.get('form_field_list'),
             'form_data': self.context.get('form_data'),
             'start_time': self.context.get('start_time'),
             'is_submit': self.context.get('is_submit'),
             'run_time': self.context.get('run_time'),
             'type': self.node.type,
             'status': self.status,
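Besides storing `form_data` as a whole, `save_context` now flattens each submitted field into the node context so downstream nodes can reference a field directly. A hedged illustration with hypothetical field names (`city`, `days`):

```python
# Hypothetical details payload restored for a form node after the user submits.
details = {
    'result': 'Form submitted',
    'is_submit': True,
    'form_data': {'city': 'Berlin', 'days': 3},
}

context = {}
form_data = details.get('form_data', None)
context['form_data'] = form_data
context['is_submit'] = details.get('is_submit')
if form_data is not None:
    for key in form_data:
        # each field becomes directly addressable from later nodes
        context[key] = form_data[key]

assert context['city'] == 'Berlin' and context['is_submit'] is True
```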
diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
index 0ddd7e489..1817ba56a 100644
--- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
+++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
@@ -6,7 +6,7 @@
 from functools import reduce
 from typing import List, Dict

 from django.db.models import QuerySet
-from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage

 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
@@ -96,11 +96,19 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
     def get_history_message_for_details(self, history_chat_record, dialogue_number):
         start_index = len(history_chat_record) - dialogue_number
         history_message = reduce(lambda x, y: [*x, *y], [
-            [self.generate_history_human_message_for_details(history_chat_record[index]), history_chat_record[index].get_ai_message()]
+            [self.generate_history_human_message_for_details(history_chat_record[index]), self.generate_history_ai_message(history_chat_record[index])]
             for index in range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
         return history_message

+    def generate_history_ai_message(self, chat_record):
+        for val in chat_record.details.values():
+            if self.node.id == val['node_id'] and 'image_list' in val:
+                if val['dialogue_type'] == 'WORKFLOW':
+                    return chat_record.get_ai_message()
+                return AIMessage(content=val['answer'])
+        return chat_record.get_ai_message()
+
     def generate_history_human_message_for_details(self, chat_record):
         for data in chat_record.details.values():
             if self.node.id == data['node_id'] and 'image_list' in data:
@@ -113,7 +121,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
     def get_history_message(self, history_chat_record, dialogue_number):
         start_index = len(history_chat_record) - dialogue_number
         history_message = reduce(lambda x, y: [*x, *y], [
-            [self.generate_history_human_message(history_chat_record[index]), history_chat_record[index].get_ai_message()]
+            [self.generate_history_human_message(history_chat_record[index]), self.generate_history_ai_message(history_chat_record[index])]
             for index in range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
         return history_message

diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py
index 4f95656d2..59f875fcc 100644
--- a/apps/application/flow/step_node/start_node/impl/base_start_node.py
+++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py
@@ -37,6 +37,8 @@ class BaseStartStepNode(IStarNode):
         workflow_variable = {**default_global_variable, **get_global_variable(self)}
         self.context['question'] = details.get('question')
         self.context['run_time'] = details.get('run_time')
+        self.context['document'] = details.get('document_list')
+        self.context['image'] = details.get('image_list')
         self.status = details.get('status')
         self.err_message = details.get('err_message')
         for key, value in workflow_variable.items():
diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py
index b05bdea83..c62612c0a 100644
--- a/apps/application/flow/workflow_manage.py
+++ b/apps/application/flow/workflow_manage.py
@@ -52,7 +52,8 @@ class Node:
         self.__setattr__(keyword, kwargs.get(keyword))


-end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node', 'image-understand-node']
+end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
+             'image-understand-node']


 class Flow:
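The image-understand change above rebuilds AI-side history per node: when this node's recorded details came from a non-workflow dialogue, the stored per-node `answer` is wrapped in an `AIMessage` instead of reusing the whole record's answer. A rough sketch of the selection rule — the `details` layout and the `'NODE'` dialogue type value are illustrative assumptions:

```python
from langchain_core.messages import AIMessage

# Hypothetical per-record details, keyed by runtime node id.
details = {
    'abc123': {'node_id': 'image-node-1', 'image_list': [], 'dialogue_type': 'NODE',
               'answer': 'The picture shows a cat.'},
}

def history_ai_message(node_id, record_details, whole_record_answer):
    for val in record_details.values():
        if node_id == val['node_id'] and 'image_list' in val:
            if val['dialogue_type'] == 'WORKFLOW':
                # workflow-scoped dialogue keeps the full record answer
                return AIMessage(content=whole_record_answer)
            return AIMessage(content=val['answer'])
    return AIMessage(content=whole_record_answer)

print(history_ai_message('image-node-1', details, 'full answer').content)
# -> 'The picture shows a cat.'
```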
@@ -229,7 +230,9 @@ class NodeChunk:
     def add_chunk(self, chunk):
         self.chunk_list.append(chunk)

-    def end(self):
+    def end(self, chunk=None):
+        if chunk is not None:
+            self.add_chunk(chunk)
         self.status = 200

     def is_end(self):
@@ -266,6 +269,7 @@ class WorkflowManage:
         self.status = 0
         self.base_to_response = base_to_response
         self.chat_record = chat_record
+        self.await_future_map = {}
         if start_node_id is not None:
             self.load_node(chat_record, start_node_id, start_node_data)
         else:
@@ -286,14 +290,16 @@ class WorkflowManage:
         for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')):
             node_id = node_details.get('node_id')
             if node_details.get('runtime_node_id') == start_node_id:
-                self.start_node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id'))
+                self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
                 self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params)
                 self.start_node.save_context(node_details, self)
                 node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {})
                 self.start_node_result_future = NodeResultFuture(node_result, None)
-                return
+                self.node_context.append(self.start_node)
+                continue

             node_id = node_details.get('node_id')
-            node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id'))
+            node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
             node.valid_args(node.node_params, node.workflow_params)
             node.save_context(node_details, self)
             self.node_context.append(node)
@@ -345,17 +351,22 @@ class WorkflowManage:
                     if chunk is None:
                         break
                     yield chunk
-            yield self.get_chunk_content('', True)
         finally:
             self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
                                                 self.answer,
                                                 self)
-            yield self.get_chunk_content('', True)

     def run_chain_async(self, current_node, node_result_future):
         future = executor.submit(self.run_chain, current_node, node_result_future)
         return future

+    def set_await_map(self, node_run_list):
+        sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
+        for index in range(len(sorted_node_run_list)):
+            self.await_future_map[sorted_node_run_list[index].get('node').runtime_node_id] = [
+                sorted_node_run_list[i].get('future')
+                for i in range(index)]
+
     def run_chain(self, current_node, node_result_future=None):
         if current_node is None:
             start_node = self.get_start_node()
@@ -365,6 +376,9 @@ class WorkflowManage:
         try:
             is_stream = self.params.get('stream', True)
             # handle the node response
+            await_future_list = self.await_future_map.get(current_node.runtime_node_id, None)
+            if await_future_list is not None:
+                [f.result() for f in await_future_list]
             result = self.hand_event_node_result(current_node,
                                                  node_result_future) if is_stream else self.hand_node_result(
                 current_node, node_result_future)
@@ -373,11 +387,9 @@ class WorkflowManage:
                 return
             node_list = self.get_next_node_list(current_node, result)  # the executable child nodes
-            result_list = []
-            for node in node_list:
-                result = self.run_chain_async(node, None)
-                result_list.append(result)
-            [r.result() for r in result_list]
+            result_list = [{'node': node, 'future': self.run_chain_async(node, None)} for node in node_list]
+            self.set_await_map(result_list)
+            [r.get('future').result() for r in result_list]
             if self.status == 0:
                 self.status = 200
         except Exception as e:
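`set_await_map` serializes sibling branches: children spawned from the same node are sorted by their canvas `y` coordinate, and each child's entry lists the futures of every sibling above it, so `run_chain` blocks on those before emitting its own output. A simplified, runnable sketch of the ordering logic, using plain dicts and placeholder futures in place of node objects:

```python
from concurrent.futures import Future

def set_await_map(node_run_list, await_future_map):
    # sort siblings top-to-bottom by canvas position
    sorted_list = sorted(node_run_list, key=lambda n: n['node']['y'])
    for index in range(len(sorted_list)):
        # each node waits on every sibling that sits above it
        await_future_map[sorted_list[index]['node']['id']] = [
            sorted_list[i]['future'] for i in range(index)]

await_future_map = {}
runs = [{'node': {'id': 'b', 'y': 200}, 'future': Future()},
        {'node': {'id': 'a', 'y': 100}, 'future': Future()}]
set_await_map(runs, await_future_map)
assert await_future_map['a'] == []      # topmost branch waits on nothing
assert len(await_future_map['b']) == 1  # lower branch waits on the one above it
```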
@@ -401,6 +413,14 @@ class WorkflowManage:
             current_node.get_write_error_context(e)
             self.answer += str(e)

+    def append_node(self, current_node):
+        for index in range(len(self.node_context)):
+            n = self.node_context[index]
+            if current_node.id == n.node.id and current_node.runtime_node_id == n.runtime_node_id:
+                self.node_context[index] = current_node
+                return
+        self.node_context.append(current_node)
+
     def hand_event_node_result(self, current_node, node_result_future):
         node_chunk = NodeChunk()
         try:
@@ -412,26 +432,40 @@ class WorkflowManage:
                     for r in result:
                         chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
                                                                                self.params['chat_record_id'],
-                                                                               r, False, 0, 0)
+                                                                               current_node.id,
+                                                                               current_node.up_node_id_list,
+                                                                               r, False, 0, 0,
+                                                                               {'node_type': current_node.type,
+                                                                                'view_type': current_node.view_type})
                         node_chunk.add_chunk(chunk)
-                    node_chunk.end()
+                    chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
+                                                                           self.params['chat_record_id'],
+                                                                           current_node.id,
+                                                                           current_node.up_node_id_list,
+                                                                           '', False, 0, 0, {'node_is_end': True,
+                                                                                             'node_type': current_node.type,
+                                                                                             'view_type': current_node.view_type})
+                    node_chunk.end(chunk)
                 else:
                     list(result)
             # append the node
-            self.node_context.append(current_node)
+            self.append_node(current_node)
             return current_result
         except Exception as e:
             # append the node
-            self.node_context.append(current_node)
+            self.append_node(current_node)
             traceback.print_exc()
             self.answer += str(e)
             chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
                                                                    self.params['chat_record_id'],
-                                                                   str(e), False, 0, 0)
+                                                                   current_node.id,
+                                                                   current_node.up_node_id_list,
+                                                                   str(e), False, 0, 0,
+                                                                   {'node_is_end': True, 'node_type': current_node.type,
+                                                                    'view_type': current_node.view_type})
             if not self.node_chunk_manage.contains(node_chunk):
                 self.node_chunk_manage.add_node_chunk(node_chunk)
-            node_chunk.add_chunk(chunk)
-            node_chunk.end()
+            node_chunk.end(chunk)
             current_node.get_write_error_context(e)
             self.status = 500

@@ -492,32 +526,41 @@ class WorkflowManage:
                 continue
             details = node.get_details(index)
             details['node_id'] = node.id
+            details['up_node_id_list'] = node.up_node_id_list
             details['runtime_node_id'] = node.runtime_node_id
             details_result[node.runtime_node_id] = details
         return details_result

     def get_answer_text_list(self):
-        answer_text_list = []
+        result = []
+        next_node_id_list = []
+        if self.start_node is not None:
+            next_node_id_list = [edge.targetNodeId for edge in self.flow.edges if
+                                 edge.sourceNodeId == self.start_node.id]
         for index in range(len(self.node_context)):
             node = self.node_context[index]
+            up_node = None
+            if index > 0:
+                up_node = self.node_context[index - 1]
             answer_text = node.get_answer_text()
             if answer_text is not None:
-                if self.chat_record is not None and self.chat_record.details is not None:
-                    details = self.chat_record.details.get(node.runtime_node_id)
-                    if details is not None and self.start_node.runtime_node_id != node.runtime_node_id:
-                        continue
-                answer_text_list.append(
-                    {'content': answer_text, 'type': 'form' if node.type == 'form-node' else 'md'})
-        result = []
-        for index in range(len(answer_text_list)):
-            answer = answer_text_list[index]
-            if index == 0:
-                result.append(answer.get('content'))
-                continue
-            if answer.get('type') != answer_text_list[index - 1].get('type'):
-                result.append(answer.get('content'))
-            else:
-                result[-1] += answer.get('content')
+                if up_node is None or node.view_type == 'single_view' or (
+                        node.view_type == 'many_view' and up_node.view_type == 'single_view'):
+                    result.append(node.get_answer_text())
+                elif self.chat_record is not None and next_node_id_list.__contains__(
+                        node.id) and up_node is not None and not next_node_id_list.__contains__(
+                        up_node.id):
+                    result.append(node.get_answer_text())
+                else:
+                    if len(result) > 0:
+                        exec_index = len(result) - 1
+                        content = result[exec_index]
+                        result[exec_index] += answer_text if len(
+                            content) == 0 else ('\n\n' + answer_text)
+                    else:
+                        answer_text = node.get_answer_text()
+                        result.insert(0, answer_text)
+
         return result

     def get_next_node(self):
@@ -540,6 +583,15 @@ class WorkflowManage:

         return None

+    @staticmethod
+    def dependent_node(up_node_id, node):
+        if node.id == up_node_id:
+            if node.type == 'form-node':
+                if node.context.get('form_data', None) is not None:
+                    return True
+                return False
+            return True
+
     def dependent_node_been_executed(self, node_id):
         """
         Check whether every dependent (upstream) node has been executed
@@ -547,7 +599,12 @@ class WorkflowManage:
         @return:
         """
         up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
-        return all([any([node.id == up_node_id for node in self.node_context]) for up_node_id in up_node_id_list])
+        return all([any([self.dependent_node(up_node_id, node) for node in self.node_context]) for up_node_id in
+                    up_node_id_list])
+
+    def get_up_node_id_list(self, node_id):
+        up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
+        return up_node_id_list

     def get_next_node_list(self, current_node, current_node_result):
         """
         Get the list of executable child nodes
@@ -556,6 +613,7 @@ class WorkflowManage:
         @param current_node_result: result of the current executable node
         @return: list of executable nodes
         """
+        if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable:
+            return []
         node_list = []
@@ -564,11 +622,13 @@ class WorkflowManage:
             if (edge.sourceNodeId == current_node.id and
                     f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
                 if self.dependent_node_been_executed(edge.targetNodeId):
-                    node_list.append(self.get_node_cls_by_id(edge.targetNodeId))
+                    node_list.append(
+                        self.get_node_cls_by_id(edge.targetNodeId, self.get_up_node_id_list(edge.targetNodeId)))
         else:
             for edge in self.flow.edges:
                 if edge.sourceNodeId == current_node.id and self.dependent_node_been_executed(edge.targetNodeId):
-                    node_list.append(self.get_node_cls_by_id(edge.targetNodeId))
+                    node_list.append(
+                        self.get_node_cls_by_id(edge.targetNodeId, self.get_up_node_id_list(edge.targetNodeId)))
         return node_list

     def get_reference_field(self, node_id: str, fields: List[str]):
@@ -629,11 +689,11 @@ class WorkflowManage:
         base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
         return base_node_list[0]

-    def get_node_cls_by_id(self, node_id, runtime_node_id=None):
+    def get_node_cls_by_id(self, node_id, up_node_id_list=None):
         for node in self.flow.nodes:
             if node.id == node_id:
                 node_instance = get_node(node.type)(node,
-                                                    self.params, self, runtime_node_id)
+                                                    self.params, self, up_node_id_list)
                 return node_instance
         return None
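`get_answer_text_list` now groups node answers by `view_type`: a `single_view` node (the form node) always starts a new answer card, while consecutive `many_view` answers are folded into the previous card with a blank line between them. A condensed sketch of the merge rule — it deliberately omits the re-chat (`chat_record`/`next_node_id_list`) branch and uses plain dicts in place of node objects:

```python
def merge_answers(nodes):
    # nodes: [{'answer': str, 'view_type': 'single_view' | 'many_view'}, ...]
    result = []
    up_node = None
    for node in nodes:
        if up_node is None or node['view_type'] == 'single_view' or (
                node['view_type'] == 'many_view' and up_node['view_type'] == 'single_view'):
            result.append(node['answer'])          # start a new answer card
        else:
            result[-1] += '\n\n' + node['answer']  # fold into the previous card
        up_node = node
    return result

print(merge_answers([{'answer': 'A', 'view_type': 'many_view'},
                     {'answer': 'B', 'view_type': 'many_view'},
                     {'answer': 'Form', 'view_type': 'single_view'}]))
# -> ['A\n\nB', 'Form']
```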
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index 22fc18bae..5f38cfc7d 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -120,8 +120,15 @@ class ChatInfo:
     def append_chat_record(self, chat_record: ChatRecord, client_id=None):
         chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else ""
         chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.problem_text is not None else ""
+        is_save = True
         # store in the cache
-        self.chat_record_list.append(chat_record)
+        for index in range(len(self.chat_record_list)):
+            record = self.chat_record_list[index]
+            if record.id == chat_record.id:
+                self.chat_record_list[index] = chat_record
+                is_save = False
+        if is_save:
+            self.chat_record_list.append(chat_record)
         if self.application.id is not None:
             # insert into the database
             if not QuerySet(Chat).filter(id=self.chat_id).exists():
@@ -224,8 +231,13 @@ class ChatMessageSerializer(serializers.Serializer):
     re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
     chat_record_id = serializers.UUIDField(required=False, allow_null=True,
                                            error_messages=ErrMessage.uuid("对话记录id"))
+
+    node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
+                                    error_messages=ErrMessage.char("节点id"))
+
     runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
-                                            error_messages=ErrMessage.char("节点id"))
+                                            error_messages=ErrMessage.char("运行时节点id"))
+
     node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数"))
     application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
     client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
@@ -339,7 +351,8 @@ class ChatMessageSerializer(serializers.Serializer):
                                            'client_id': client_id, 'client_type': client_type, 'user_id': user_id},
                                           WorkFlowPostHandler(chat_info, client_id, client_type),
-                                          base_to_response, form_data, image_list, document_list, self.data.get('runtime_node_id'),
+                                          base_to_response, form_data, image_list, document_list,
+                                          self.data.get('runtime_node_id'), self.data.get('node_data'),
                                           chat_record)
         r = work_flow_manage.run()
         return r
diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py
index 74b6fc361..a3ef50766 100644
--- a/apps/application/views/chat_views.py
+++ b/apps/application/views/chat_views.py
@@ -135,6 +135,7 @@ class ChatView(APIView):
                                'document_list': request.data.get(
                                    'document_list') if 'document_list' in request.data else [],
                                'client_type': request.auth.client_type,
+                               'node_id': request.data.get('node_id', None),
                                'runtime_node_id': request.data.get('runtime_node_id', None),
                                'node_data': request.data.get('node_data', {}),
                                'chat_record_id': request.data.get('chat_record_id')}
diff --git a/apps/common/handle/base_to_response.py b/apps/common/handle/base_to_response.py
index 05af57cb9..376d1a9dd 100644
--- a/apps/common/handle/base_to_response.py
+++ b/apps/common/handle/base_to_response.py
@@ -14,12 +14,15 @@ from rest_framework import status

 class BaseToResponse(ABC):
     @abstractmethod
-    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
+    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens,
+                          prompt_tokens, other_params: dict = None,
                           _status=status.HTTP_200_OK):
         pass

     @abstractmethod
-    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
+    def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
+                                 completion_tokens,
+                                 prompt_tokens, other_params: dict = None):
         pass

     @staticmethod
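With the widened `to_stream_chunk_response` signature, every streamed chunk now carries node identity and view metadata alongside the content. A representative payload, mirroring how `SystemToResponse` (below) serializes it — the field values here are illustrative:

```python
import json

# Illustrative chunk as emitted by SystemToResponse.to_stream_chunk_response.
chunk = json.dumps({
    'chat_id': 'c1', 'id': 'r1', 'operate': True,
    'content': 'Hello', 'node_id': 'ai-chat-node-1',
    'up_node_id_list': ['start-node'], 'is_end': False,
    'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0},
    # merged in from other_params:
    'node_is_end': False, 'view_type': 'many_view', 'node_type': 'ai-chat-node',
})
print(chunk)
```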
diff --git a/apps/common/handle/impl/response/openai_to_response.py b/apps/common/handle/impl/response/openai_to_response.py
index b2508078f..0ce605514 100644
--- a/apps/common/handle/impl/response/openai_to_response.py
+++ b/apps/common/handle/impl/response/openai_to_response.py
@@ -20,6 +20,7 @@ from common.handle.base_to_response import BaseToResponse

 class OpenaiToResponse(BaseToResponse):
     def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
+                          other_params: dict = None,
                           _status=status.HTTP_200_OK):
         data = ChatCompletion(id=chat_record_id, choices=[
             BlockChoice(finish_reason='stop', index=0, chat_id=chat_id,
@@ -31,7 +32,8 @@ class OpenaiToResponse(BaseToResponse):
                               ).dict()
         return JsonResponse(data=data, status=_status)

-    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
+    def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
+                                 completion_tokens, prompt_tokens, other_params: dict = None):
         chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk',
                                     created=datetime.datetime.now().second, choices=[
                 Choice(delta=ChoiceDelta(content=content, chat_id=chat_id),
                        finish_reason='stop' if is_end else None,
diff --git a/apps/common/handle/impl/response/system_to_response.py b/apps/common/handle/impl/response/system_to_response.py
index 1ec980633..7c374ef01 100644
--- a/apps/common/handle/impl/response/system_to_response.py
+++ b/apps/common/handle/impl/response/system_to_response.py
@@ -15,12 +15,23 @@ from common.response import result

 class SystemToResponse(BaseToResponse):
-    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
+    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens,
+                          prompt_tokens, other_params: dict = None,
                           _status=status.HTTP_200_OK):
+        if other_params is None:
+            other_params = {}
         return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
-                               'content': content, 'is_end': is_end}, response_status=_status, code=_status)
+                               'content': content, 'is_end': is_end, **other_params}, response_status=_status,
+                              code=_status)

-    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
+    def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
+                                 completion_tokens, prompt_tokens, other_params: dict = None):
+        if other_params is None:
+            other_params = {}
         chunk = json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
-                            'content': content, 'is_end': is_end})
+                            'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list,
+                            'is_end': is_end,
+                            'usage': {'completion_tokens': completion_tokens,
+                                      'prompt_tokens': prompt_tokens,
+                                      'total_tokens': completion_tokens + prompt_tokens},
+                            **other_params})
         return super().format_stream_chunk(chunk)
diff --git a/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py b/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py
index e47bfd60c..7c5d63755 100644
--- a/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py
+++ b/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py
@@ -9,16 +9,16 @@ from dataset.models import State, TaskType

 sql = """
 UPDATE "document"
-SET status ="replace"(status, '1', '3')
+SET status ="replace"("replace"("replace"(status, '2', '3'),'0','3'),'1','2')
+"""
+sql_paragraph = """
+UPDATE "paragraph"
+SET status ="replace"("replace"("replace"(status, '2', '3'),'0','3'),'1','2')
 """


 def updateDocumentStatus(apps, schema_editor):
-    ParagraphModel = apps.get_model('dataset', 'Paragraph')
     DocumentModel = apps.get_model('dataset', 'Document')
-    success_list = QuerySet(DocumentModel).filter(status='2')
-    ListenerManagement.update_status(QuerySet(ParagraphModel).filter(document_id__in=[d.id for d in success_list]),
-                                     TaskType.EMBEDDING, State.SUCCESS)
     ListenerManagement.get_aggregation_document_status_by_query_set(QuerySet(DocumentModel))()
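The nested `replace` calls remap every digit of the packed status string in a single pass, and the order matters: original `'2'`s must be rewritten to `'3'` before `'1'`s are promoted to `'2'`, otherwise the freshly written `'2'`s would be remapped again. A hedged Python reading of the SQL (the numeric codes belong to `dataset.models.State`; their exact meanings are not restated here):

```python
def remap(status: str) -> str:
    # mirror of the SQL: '2' -> '3', then '0' -> '3', then '1' -> '2'
    return status.replace('2', '3').replace('0', '3').replace('1', '2')

# net per-digit mapping: 0 -> 3, 1 -> 2, 2 -> 3
assert remap('012') == '323'
```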
@@ -48,6 +48,7 @@ class Migration(migrations.Migration):
             name='status',
             field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'),
         ),
+        migrations.RunSQL(sql_paragraph),
         migrations.RunSQL(sql),
         migrations.RunPython(updateDocumentStatus)
     ]
diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py
index 85e73ee32..a0c171559 100644
--- a/apps/dataset/serializers/dataset_serializers.py
+++ b/apps/dataset/serializers/dataset_serializers.py
@@ -139,14 +139,14 @@ class DataSetSerializers(serializers.ModelSerializer):
             query_set = QuerySet(model=get_dynamics_model(
                 {'temp.name': models.CharField(), 'temp.desc': models.CharField(),
                  "document_temp.char_length": models.IntegerField(), 'temp.create_time': models.DateTimeField(),
-                 'temp.user_id': models.CharField(), }))
+                 'temp.user_id': models.CharField(), 'temp.id': models.CharField()}))
             if "desc" in self.data and self.data.get('desc') is not None:
                 query_set = query_set.filter(**{'temp.desc__icontains': self.data.get("desc")})
             if "name" in self.data and self.data.get('name') is not None:
                 query_set = query_set.filter(**{'temp.name__icontains': self.data.get("name")})
             if "select_user_id" in self.data and self.data.get('select_user_id') is not None:
                 query_set = query_set.filter(**{'temp.user_id__exact': self.data.get("select_user_id")})
-            query_set = query_set.order_by("-temp.create_time")
+            query_set = query_set.order_by("-temp.create_time", "temp.id")
             query_set_dict['default_sql'] = query_set

             query_set_dict['dataset_custom_sql'] = QuerySet(model=get_dynamics_model(
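Both list queries gain an explicit tie-breaker column (the same change recurs for documents below). `-create_time` alone leaves rows with identical timestamps in unspecified order in PostgreSQL, so page boundaries can shift between requests; adding the id column pins them. The idiom in plain Django ORM terms, assuming the project's `Document` model:

```python
from dataset.models import Document

# Rows sharing a create_time have no stable relative order without a tie-breaker;
# adding 'id' makes pagination deterministic across requests.
page_1 = Document.objects.order_by('-create_time', 'id')[:20]
page_2 = Document.objects.order_by('-create_time', 'id')[20:40]
```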
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index 45057d9bc..ac2006a52 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -374,7 +374,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 query_set = query_set.filter(**{'is_active': self.data.get('is_active')})
             if 'status' in self.data and self.data.get('status') is not None:
                 query_set = query_set.filter(**{'status': self.data.get('status')})
-            query_set = query_set.order_by('-create_time')
+            query_set = query_set.order_by('-create_time', 'id')
             return query_set

         def list(self, with_valid=False):
@@ -432,6 +432,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                              TaskType.SYNC,
                                              State.PENDING)
+            ListenerManagement.get_aggregation_document_status(document_id)()
             source_url = document.meta.get('source_url')
             selector_list = document.meta.get('selector').split(
                 " ") if 'selector' in document.meta and document.meta.get('selector') is not None else []
@@ -444,10 +445,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                     # delete the vector store entries
                     delete_embedding_by_document(document_id)
                     paragraphs = get_split_model('web.md').parse(result.content)
-                    document.char_length = reduce(lambda x, y: x + y,
-                                                  [len(p.get('content')) for p in paragraphs],
-                                                  0)
-                    document.save()
+                    char_length = reduce(lambda x, y: x + y,
+                                         [len(p.get('content')) for p in paragraphs],
+                                         0)
+                    QuerySet(Document).filter(id=document_id).update(char_length=char_length)
                     document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)

                     paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
@@ -464,6 +465,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                     # vectorize
                     if with_embedding:
                         embedding_model_id = get_embedding_model_id_by_dataset_id(document.dataset_id)
+                        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
+                                                         TaskType.EMBEDDING,
+                                                         State.PENDING)
+                        ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
+                                                         TaskType.EMBEDDING,
+                                                         State.PENDING)
+                        ListenerManagement.get_aggregation_document_status(document_id)()
                         embedding_by_document.delay(document_id, embedding_model_id)
             else:
@@ -477,6 +485,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
                                                  TaskType.SYNC,
                                                  state)
+                ListenerManagement.get_aggregation_document_status(document_id)()
                 return True

     class Operate(ApiMixin, serializers.Serializer):
diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py
index 9a6eaff6b..0aadef6c1 100644
--- a/apps/embedding/vector/base_vector.py
+++ b/apps/embedding/vector/base_vector.py
@@ -84,7 +84,7 @@ class BaseVectorStore(ABC):
         chunk_list = chunk_data(data)
         result = sub_array(chunk_list)
         for child_array in result:
-            self._batch_save(child_array, embedding, lambda: True)
+            self._batch_save(child_array, embedding, lambda: False)

     def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
         """
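In `BaseVectorStore.save`, the third argument to `_batch_save` is an `is_the_task_interrupted` callback. The old `lambda: True` told every batch the task had already been interrupted; `lambda: False` restores the intended "never interrupted" default for this path. A stripped-down model of that contract, under the assumption that the batch loop checks the flag before each chunk:

```python
from typing import Callable, List

def batch_save(chunks: List[List[dict]], is_the_task_interrupted: Callable[[], bool]) -> int:
    saved = 0
    for chunk in chunks:
        if is_the_task_interrupted():  # True means: stop, the task was cancelled
            break
        saved += len(chunk)
    return saved

assert batch_save([[{}], [{}]], lambda: True) == 0    # old behavior: nothing saved
assert batch_save([[{}], [{}]], lambda: False) == 2   # fixed: all batches persist
```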
diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts
index be8eeae90..b45a6783f 100644
--- a/ui/src/api/type/application.ts
+++ b/ui/src/api/type/application.ts
@@ -22,6 +22,17 @@ interface ApplicationFormType {
   tts_model_enable?: boolean
   tts_type?: string
 }
+interface Chunk {
+  chat_id: string
+  id: string
+  content: string
+  node_id: string
+  up_node_id: string
+  is_end: boolean
+  node_is_end: boolean
+  node_type: string
+  view_type: string
+}
 interface chatType {
   id: string
   problem_text: string
@@ -47,6 +58,21 @@ interface chatType {
   }
 }

+interface Node {
+  buffer: Array<string>
+  node_id: string
+  up_node_id: string
+  node_type: string
+  view_type: string
+  index: number
+  is_end: boolean
+}
+interface WriteNodeInfo {
+  current_node: any
+  answer_text_list_index: number
+  current_up_node?: any
+  divider_content?: Array<string>
+}
 export class ChatRecordManage {
   id?: any
   ms: number
@@ -55,6 +81,8 @@ export class ChatRecordManage {
   write_ed?: boolean
   is_stop?: boolean
   loading?: Ref<boolean>
+  node_list: Array<Node>
+  write_node_info?: WriteNodeInfo
   constructor(chat: chatType, ms?: number, loading?: Ref<boolean>) {
     this.ms = ms ? ms : 10
     this.chat = chat
@@ -62,12 +90,82 @@ export class ChatRecordManage {
     this.is_stop = false
     this.is_close = false
     this.write_ed = false
+    this.node_list = []
   }
-  append_answer(chunk_answer: String) {
-    this.chat.answer_text_list[this.chat.answer_text_list.length - 1] =
-      this.chat.answer_text_list[this.chat.answer_text_list.length - 1] + chunk_answer
+  append_answer(chunk_answer: string, index?: number) {
+    this.chat.answer_text_list[index != undefined ? index : this.chat.answer_text_list.length - 1] =
+      this.chat.answer_text_list[
+        index !== undefined ? index : this.chat.answer_text_list.length - 1
+      ]
+        ? this.chat.answer_text_list[
+            index !== undefined ? index : this.chat.answer_text_list.length - 1
+          ] + chunk_answer
+        : chunk_answer
     this.chat.answer_text = this.chat.answer_text + chunk_answer
   }
+
+  get_run_node() {
+    if (
+      this.write_node_info &&
+      (this.write_node_info.current_node.buffer.length > 0 ||
+        !this.write_node_info.current_node.is_end)
+    ) {
+      return this.write_node_info
+    }
+    const run_node = this.node_list.filter((item) => item.buffer.length > 0 || !item.is_end).at(0)
+
+    if (run_node) {
+      const index = this.node_list.indexOf(run_node)
+      let current_up_node = undefined
+      if (index > 0) {
+        current_up_node = this.node_list[index - 1]
+      }
+      let answer_text_list_index = 0
+
+      if (
+        current_up_node == undefined ||
+        run_node.view_type == 'single_view' ||
+        (run_node.view_type == 'many_view' && current_up_node.view_type == 'single_view')
+      ) {
+        const none_index = this.chat.answer_text_list.indexOf('')
+        if (none_index > -1) {
+          answer_text_list_index = none_index
+        } else {
+          answer_text_list_index = this.chat.answer_text_list.length
+        }
+      } else {
+        const none_index = this.chat.answer_text_list.indexOf('')
+        if (none_index > -1) {
+          answer_text_list_index = none_index
+        } else {
+          answer_text_list_index = this.chat.answer_text_list.length - 1
+        }
+      }
+
+      this.write_node_info = {
+        current_node: run_node,
+        divider_content: ['\n\n'],
+        current_up_node: current_up_node,
+        answer_text_list_index: answer_text_list_index
+      }
+      return this.write_node_info
+    }
+    return undefined
+  }
+  closeInterval() {
+    this.chat.write_ed = true
+    this.write_ed = true
+    if (this.loading) {
+      this.loading.value = false
+    }
+    if (this.id) {
+      clearInterval(this.id)
+    }
+    const last_index = this.chat.answer_text_list.lastIndexOf('')
+    if (last_index > 0) {
+      this.chat.answer_text_list.splice(last_index, 1)
+    }
+  }
   write() {
     this.chat.is_stop = false
     this.is_stop = false
@@ -78,22 +176,45 @@ export class ChatRecordManage {
       this.loading.value = true
     }
     this.id = setInterval(() => {
-      if (this.chat.buffer.length > 20) {
-        this.append_answer(this.chat.buffer.splice(0, this.chat.buffer.length - 20).join(''))
+      const node_info = this.get_run_node()
+      if (node_info == undefined) {
+        if (this.is_close) {
+          this.closeInterval()
+        }
+        return
+      }
+      const { current_node, answer_text_list_index, divider_content } = node_info
+      if (current_node.buffer.length > 20) {
+        const context = current_node.is_end
+          ? current_node.buffer.splice(0)
+          : current_node.buffer.splice(
+              0,
+              current_node.is_end ? undefined : current_node.buffer.length - 20
+            )
+        this.append_answer(
+          (divider_content ? divider_content.splice(0).join('') : '') + context.join(''),
+          answer_text_list_index
+        )
       } else if (this.is_close) {
-        this.append_answer(this.chat.buffer.splice(0).join(''))
-        this.chat.write_ed = true
-        this.write_ed = true
-        if (this.loading) {
-          this.loading.value = false
-        }
-        if (this.id) {
-          clearInterval(this.id)
+        while (true) {
+          const node_info = this.get_run_node()
+          if (node_info == undefined) {
+            break
+          }
+          this.append_answer(
+            (node_info.divider_content ? node_info.divider_content.splice(0).join('') : '') +
+              node_info.current_node.buffer.splice(0).join(''),
+            node_info.answer_text_list_index
+          )
         }
+        this.closeInterval()
       } else {
-        const s = this.chat.buffer.shift()
+        const s = current_node.buffer.shift()
         if (s !== undefined) {
-          this.append_answer(s)
+          this.append_answer(
+            (divider_content ? divider_content.splice(0).join('') : '') + s,
+            answer_text_list_index
+          )
         }
       }
     }, this.ms)
@@ -113,10 +234,32 @@ export class ChatRecordManage {
     this.is_close = false
     this.is_stop = false
   }
-  append(answer_text_block: string) {
-    for (let index = 0; index < answer_text_block.length; index++) {
-      this.chat.buffer.push(answer_text_block[index])
+  appendChunk(chunk: Chunk) {
+    let n = this.node_list.find(
+      (item) => item.node_id == chunk.node_id && item.up_node_id === chunk.up_node_id
+    )
+    if (n) {
+      n.buffer.push(...chunk.content)
+    } else {
+      n = {
+        buffer: [...chunk.content],
+        node_id: chunk.node_id,
+        up_node_id: chunk.up_node_id,
+        node_type: chunk.node_type,
+        index: this.node_list.length,
+        view_type: chunk.view_type,
+        is_end: false
+      }
+      this.node_list.push(n)
     }
+    if (chunk.node_is_end) {
+      n['is_end'] = true
+    }
+  }
+  append(answer_text_block: string) {
+    const index = this.chat.answer_text_list.indexOf('')
+    this.chat.answer_text_list[index] = answer_text_block
   }
 }

@@ -126,6 +269,12 @@ export class ChatManagement {
   static addChatRecord(chat: chatType, ms: number, loading?: Ref<boolean>) {
     this.chatMessageContainer[chat.id] = new ChatRecordManage(chat, ms, loading)
   }
+  static appendChunk(chatRecordId: string, chunk: Chunk) {
+    const chatRecord = this.chatMessageContainer[chatRecordId]
+    if (chatRecord) {
+      chatRecord.appendChunk(chunk)
+    }
+  }
   static append(chatRecordId: string, content: string) {
     const chatRecord = this.chatMessageContainer[chatRecordId]
     if (chatRecord) {
diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue
index c39e95c60..fc89c08bd 100644
--- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue
+++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue
@@ -341,7 +341,7 @@
                 {{ f.label.label }}:
-                {{ item.form_data[f.field] }}
+                {{ (item.form_data ? item.form_data : {})[f.field] }}
@@ -467,7 +467,6 @@ watch(dialogVisible, (bool) => {

 const open = (data: any) => {
   detail.value = cloneDeep(data)
-  console.log(detail.value)
   dialogVisible.value = true
 }
 onBeforeUnmount(() => {
diff --git a/ui/src/components/ai-chat/component/answer-content/index.vue b/ui/src/components/ai-chat/component/answer-content/index.vue
index 1a75737bb..a0f1ee0d4 100644
--- a/ui/src/components/ai-chat/component/answer-content/index.vue
+++ b/ui/src/components/ai-chat/component/answer-content/index.vue
@@ -14,6 +14,7 @@
         source=" 抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。"
       >
@@ ... @@
   if (type === 'old') {
-    props.chatRecord.answer_text_list.push('')
+    add_answer_text_list(props.chatRecord.answer_text_list)
     props.sendMessage(question, other_params_data, props.chatRecord)
     props.chatManagement.write(props.chatRecord.id)
   } else {
     props.sendMessage(question, other_params_data)
   }
 }
+const add_answer_text_list = (answer_text_list: Array<string>) => {
+  answer_text_list.push('')
+}

 function showSource(row: any) {
   if (props.type === 'log') {
@@ -78,8 +82,8 @@
   }
   return false
 }
-const regenerationChart = (question: string) => {
-  props.sendMessage(question, { rechat: true })
+const regenerationChart = (chat: chatType) => {
+  props.sendMessage(chat.problem_text, { rechat: true })
 }
 const stopChat = (chat: chatType) => {
   props.chatManagement.stop(chat.id)
diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue
index cc216e4f5..e04ab70e4 100644
--- a/ui/src/components/ai-chat/index.vue
+++ b/ui/src/components/ai-chat/index.vue
@@ -223,10 +223,8 @@ const getWrite = (chat: any, reader: any, stream: boolean) => {
           const chunk = JSON?.parse(split[index].replace('data:', ''))
           chat.chat_id = chunk.chat_id
           chat.record_id = chunk.id
-          const content = chunk?.content
-          if (content) {
-            ChatManagement.append(chat.id, content)
-          }
+          ChatManagement.appendChunk(chat.id, chunk)
+
           if (chunk.is_end) {
             // stream finished successfully; return the success callback
             return Promise.resolve()
@@ -306,6 +304,10 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_params?: any) {
       scrollDiv.value.setScrollTop(getMaxHeight())
     })
   }
+  if (chat.run_time) {
+    ChatManagement.addChatRecord(chat, 50, loading)
+    ChatManagement.write(chat.id)
+  }
   if (!chartOpenId.value) {
     getChartOpenId(chat).catch(() => {
       errorWrite(chat)
diff --git a/ui/src/components/dynamics-form/constructor/items/JsonInputConstructor.vue b/ui/src/components/dynamics-form/constructor/items/JsonInputConstructor.vue
index 10eb95613..5121bd020 100644
--- a/ui/src/components/dynamics-form/constructor/items/JsonInputConstructor.vue
+++ b/ui/src/components/dynamics-form/constructor/items/JsonInputConstructor.vue
@@ -14,7 +14,6 @@
 import JsonInput from '@/components/dynamics-form/items/JsonInput.vue'
 const props = defineProps<{ modelValue: any }>()

-const formField = ref({})
 const emit = defineEmits(['update:modelValue'])
 const formValue = computed({
   set: (item) => {
diff --git a/ui/src/components/dynamics-form/constructor/items/MultiSelectConstructor.vue b/ui/src/components/dynamics-form/constructor/items/MultiSelectConstructor.vue
index 253ab160a..4a5fa18ef 100644
--- a/ui/src/components/dynamics-form/constructor/items/MultiSelectConstructor.vue
+++ b/ui/src/components/dynamics-form/constructor/items/MultiSelectConstructor.vue
@@ -63,7 +63,7 @@
@@ -102,8 +102,8 @@ const getData = () => {
       input_type: 'MultiSelect',
       attrs: {},
       default_value: formValue.value.default_value,
-      textField: 'label',
-      valueField: 'value',
+      text_field: 'label',
+      value_field: 'value',
       option_list: formValue.value.option_list
     }
   }
@@ -116,6 +116,8 @@
 defineExpose({ getData, rander })
 onMounted(() => {
   formValue.value.option_list = []
   formValue.value.default_value = ''
+
+  addOption()
 })