diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index d3c6207aa..caf25836f 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -40,6 +40,10 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): node.context['run_time'] = time.time() - node.context['start_time'] +def is_interrupt(node, step_variable: Dict, global_variable: Dict): + return node.type == 'form-node' and not node.context.get('is_submit', False) + + class WorkFlowPostHandler: def __init__(self, chat_info, client_id, client_type): self.chat_info = chat_info @@ -57,7 +61,7 @@ class WorkFlowPostHandler: answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row and row.get('answer_tokens') is not None]) answer_text_list = workflow.get_answer_text_list() - answer_text = '\n\n'.join(answer_text_list) + answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list) if workflow.chat_record is not None: chat_record = workflow.chat_record chat_record.answer_text = answer_text @@ -91,10 +95,11 @@ class WorkFlowPostHandler: class NodeResult: def __init__(self, node_variable: Dict, workflow_variable: Dict, - _write_context=write_context): + _write_context=write_context, _is_interrupt=is_interrupt): self._write_context = _write_context self.node_variable = node_variable self.workflow_variable = workflow_variable + self._is_interrupt = _is_interrupt def write_context(self, node, workflow): return self._write_context(self.node_variable, self.workflow_variable, node, workflow) @@ -102,6 +107,14 @@ class NodeResult: def is_assertion_result(self): return 'branch_id' in self.node_variable + def is_interrupt_exec(self, current_node): + """ + 是否中断执行 + @param current_node: + @return: + """ + return self._is_interrupt(current_node, self.node_variable, self.workflow_variable) + class ReferenceAddressSerializer(serializers.Serializer): node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id")) @@ -139,14 +152,18 @@ class INode: pass def get_answer_text(self): - return self.answer_text + if self.answer_text is None: + return None + return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id, + 'chat_record_id': self.workflow_params['chat_record_id']} - def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None): + def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None, + get_node_params=lambda node: node.properties.get('node_data')): # 当前步骤上下文,用于存储当前步骤信息 self.status = 200 self.err_message = '' self.node = node - self.node_params = node.properties.get('node_data') + self.node_params = get_node_params(node) self.workflow_params = workflow_params self.workflow_manage = workflow_manage self.node_params_serializer = None diff --git a/apps/application/flow/step_node/application_node/i_application_node.py b/apps/application/flow/step_node/application_node/i_application_node.py index d0184d4a3..8c4675ea7 100644 --- a/apps/application/flow/step_node/application_node/i_application_node.py +++ b/apps/application/flow/step_node/application_node/i_application_node.py @@ -14,6 +14,8 @@ class ApplicationNodeSerializer(serializers.Serializer): user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段")) image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片")) document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档")) + child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点")) + node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据")) class IApplicationNode(INode): @@ -55,5 +57,5 @@ class IApplicationNode(INode): message=str(question), **kwargs) def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, - app_document_list=None, app_image_list=None, **kwargs) -> NodeResult: + app_document_list=None, app_image_list=None, child_node=None, node_data=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py index 30954a2b6..532a81e52 100644 --- a/apps/application/flow/step_node/application_node/impl/base_application_node.py +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -2,19 +2,25 @@ import json import time import uuid -from typing import List, Dict +from typing import Dict + from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.application_node.i_application_node import IApplicationNode from application.models import Chat -from common.handle.impl.response.openai_to_response import OpenaiToResponse def string_to_uuid(input_str): return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str)) +def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict): + return node_variable.get('is_interrupt_exec', False) + + def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): result = node_variable.get('result') + node.context['child_node'] = node_variable['child_node'] + node.context['is_interrupt_exec'] = node_variable['is_interrupt_exec'] node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0) node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0) node.context['answer'] = answer @@ -36,17 +42,34 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo response = node_variable.get('result') answer = '' usage = {} + node_child_node = {} + is_interrupt_exec = False for chunk in response: # 先把流转成字符串 response_content = chunk.decode('utf-8')[6:] response_content = json.loads(response_content) - choices = response_content.get('choices') - if choices and isinstance(choices, list) and len(choices) > 0: - content = choices[0].get('delta', {}).get('content', '') - answer += content - yield content + content = response_content.get('content', '') + runtime_node_id = response_content.get('runtime_node_id', '') + chat_record_id = response_content.get('chat_record_id', '') + child_node = response_content.get('child_node') + node_type = response_content.get('node_type') + real_node_id = response_content.get('real_node_id') + node_is_end = response_content.get('node_is_end', False) + if node_type == 'form-node': + is_interrupt_exec = True + answer += content + node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, + 'child_node': child_node} + yield {'content': content, + 'node_type': node_type, + 'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, + 'child_node': child_node, + 'real_node_id': real_node_id, + 'node_is_end': node_is_end} usage = response_content.get('usage', {}) node_variable['result'] = {'usage': usage} + node_variable['is_interrupt_exec'] = is_interrupt_exec + node_variable['child_node'] = node_child_node _write_context(node_variable, workflow_variable, node, workflow, answer) @@ -64,6 +87,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor class BaseApplicationNode(IApplicationNode): + def get_answer_text(self): + if self.answer_text is None: + return None + return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id, + 'chat_record_id': self.workflow_params['chat_record_id'], 'child_node': self.context.get('child_node')} def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') @@ -72,7 +100,7 @@ class BaseApplicationNode(IApplicationNode): self.answer_text = details.get('answer') def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, - app_document_list=None, app_image_list=None, + app_document_list=None, app_image_list=None, child_node=None, node_data=None, **kwargs) -> NodeResult: from application.serializers.chat_message_serializers import ChatMessageSerializer # 生成嵌入应用的chat_id @@ -85,6 +113,14 @@ class BaseApplicationNode(IApplicationNode): app_document_list = [] if app_image_list is None: app_image_list = [] + runtime_node_id = None + record_id = None + child_node_value = None + if child_node is not None: + runtime_node_id = child_node.get('runtime_node_id') + record_id = child_node.get('chat_record_id') + child_node_value = child_node.get('child_node') + response = ChatMessageSerializer( data={'chat_id': current_chat_id, 'message': message, 're_chat': re_chat, @@ -94,16 +130,20 @@ class BaseApplicationNode(IApplicationNode): 'client_type': client_type, 'document_list': app_document_list, 'image_list': app_image_list, - 'form_data': kwargs}).chat(base_to_response=OpenaiToResponse()) + 'runtime_node_id': runtime_node_id, + 'chat_record_id': record_id, + 'child_node': child_node_value, + 'node_data': node_data, + 'form_data': kwargs}).chat() if response.status_code == 200: if stream: content_generator = response.streaming_content return NodeResult({'result': content_generator, 'question': message}, {}, - _write_context=write_context_stream) + _write_context=write_context_stream, _is_interrupt=_is_interrupt_exec) else: data = json.loads(response.content) return NodeResult({'result': data, 'question': message}, {}, - _write_context=write_context) + _write_context=write_context, _is_interrupt=_is_interrupt_exec) def get_details(self, index: int, **kwargs): global_fields = [] diff --git a/apps/application/flow/step_node/form_node/i_form_node.py b/apps/application/flow/step_node/form_node/i_form_node.py index b793a5b78..cfd178ded 100644 --- a/apps/application/flow/step_node/form_node/i_form_node.py +++ b/apps/application/flow/step_node/form_node/i_form_node.py @@ -17,6 +17,7 @@ from common.util.field_message import ErrMessage class FormNodeParamsSerializer(serializers.Serializer): form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list("表单配置")) form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char('表单输出内容')) + form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据")) class IFormNode(INode): @@ -29,5 +30,5 @@ class IFormNode(INode): def _run(self): return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) - def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult: + def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py index ae86b11d6..db179f588 100644 --- a/apps/application/flow/step_node/form_node/impl/base_form_node.py +++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py @@ -42,7 +42,12 @@ class BaseFormNode(IFormNode): for key in form_data: self.context[key] = form_data[key] - def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult: + def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult: + if form_data is not None: + self.context['is_submit'] = True + self.context['form_data'] = form_data + else: + self.context['is_submit'] = False form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), "is_submit": self.context.get("is_submit", False)} @@ -63,7 +68,8 @@ class BaseFormNode(IFormNode): form = f'{json.dumps(form_setting)}' prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') value = prompt_template.format(form=form) - return value + return {'content': value, 'runtime_node_id': self.runtime_node_id, + 'chat_record_id': self.workflow_params['chat_record_id']} def get_details(self, index: int, **kwargs): form_content_format = self.context.get('form_content_format') diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index c62612c0a..60fad57c9 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -244,15 +244,15 @@ class WorkflowManage: base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, document_list=None, start_node_id=None, - start_node_data=None, chat_record=None): + start_node_data=None, chat_record=None, child_node=None): if form_data is None: form_data = {} if image_list is None: image_list = [] if document_list is None: document_list = [] + self.start_node_id = start_node_id self.start_node = None - self.start_node_result_future = None self.form_data = form_data self.image_list = image_list self.document_list = document_list @@ -270,6 +270,7 @@ class WorkflowManage: self.base_to_response = base_to_response self.chat_record = chat_record self.await_future_map = {} + self.child_node = child_node if start_node_id is not None: self.load_node(chat_record, start_node_id, start_node_data) else: @@ -290,11 +291,17 @@ class WorkflowManage: for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')): node_id = node_details.get('node_id') if node_details.get('runtime_node_id') == start_node_id: - self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list')) - self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params) - self.start_node.save_context(node_details, self) - node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {}) - self.start_node_result_future = NodeResultFuture(node_result, None) + def get_node_params(n): + is_result = False + if n.type == 'application-node': + is_result = True + return {**n.properties.get('node_data'), 'form_data': start_node_data, 'node_data': start_node_data, + 'child_node': self.child_node, 'is_result': is_result} + + self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'), + get_node_params=get_node_params) + self.start_node.valid_args( + {**self.start_node.node_params, 'form_data': start_node_data}, self.start_node.workflow_params) self.node_context.append(self.start_node) continue @@ -306,7 +313,7 @@ class WorkflowManage: def run(self): if self.params.get('stream'): - return self.run_stream(self.start_node, self.start_node_result_future) + return self.run_stream(self.start_node, None) return self.run_block() def run_block(self): @@ -429,22 +436,41 @@ class WorkflowManage: if result is not None: if self.is_result(current_node, current_result): self.node_chunk_manage.add_node_chunk(node_chunk) + child_node = {} + real_node_id = current_node.runtime_node_id for r in result: + content = r + child_node = {} + node_is_end = False + if isinstance(r, dict): + content = r.get('content') + child_node = {'runtime_node_id': r.get('runtime_node_id'), + 'chat_record_id': r.get('chat_record_id') + , 'child_node': r.get('child_node')} + real_node_id = r.get('real_node_id') + node_is_end = r.get('node_is_end') chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'], current_node.id, current_node.up_node_id_list, - r, False, 0, 0, + content, False, 0, 0, {'node_type': current_node.type, - 'view_type': current_node.view_type}) + 'runtime_node_id': current_node.runtime_node_id, + 'view_type': current_node.view_type, + 'child_node': child_node, + 'node_is_end': node_is_end, + 'real_node_id': real_node_id}) node_chunk.add_chunk(chunk) chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'], current_node.id, current_node.up_node_id_list, '', False, 0, 0, {'node_is_end': True, + 'runtime_node_id': current_node.runtime_node_id, 'node_type': current_node.type, - 'view_type': current_node.view_type}) + 'view_type': current_node.view_type, + 'child_node': child_node, + 'real_node_id': real_node_id}) node_chunk.end(chunk) else: list(result) @@ -554,9 +580,9 @@ class WorkflowManage: else: if len(result) > 0: exec_index = len(result) - 1 - content = result[exec_index] - result[exec_index] += answer_text if len( - content) == 0 else ('\n\n' + answer_text) + content = result[exec_index]['content'] + result[exec_index]['content'] += answer_text['content'] if len( + content) == 0 else ('\n\n' + answer_text['content']) else: answer_text = node.get_answer_text() result.insert(0, answer_text) @@ -613,8 +639,8 @@ class WorkflowManage: @param current_node_result: 当前可执行节点结果 @return: 可执行节点列表 """ - - if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable: + # 判断是否中断执行 + if current_node_result.is_interrupt_exec(current_node): return [] node_list = [] if current_node_result is not None and current_node_result.is_assertion_result(): @@ -689,11 +715,12 @@ class WorkflowManage: base_node_list = [node for node in self.flow.nodes if node.type == 'base-node'] return base_node_list[0] - def get_node_cls_by_id(self, node_id, up_node_id_list=None): + def get_node_cls_by_id(self, node_id, up_node_id_list=None, + get_node_params=lambda node: node.properties.get('node_data')): for node in self.flow.nodes: if node.id == node_id: node_instance = get_node(node.type)(node, - self.params, self, up_node_id_list) + self.params, self, up_node_id_list, get_node_params) return node_instance return None diff --git a/apps/application/migrations/0019_application_file_upload_enable_and_more.py b/apps/application/migrations/0019_application_file_upload_enable_and_more.py index f8c33c2c1..7b12321f1 100644 --- a/apps/application/migrations/0019_application_file_upload_enable_and_more.py +++ b/apps/application/migrations/0019_application_file_upload_enable_and_more.py @@ -4,8 +4,8 @@ import django.contrib.postgres.fields from django.db import migrations, models sql = """ -UPDATE "public".application_chat_record -SET "answer_text_list" = ARRAY[answer_text]; +UPDATE application_chat_record +SET answer_text_list=ARRAY[jsonb_build_object('content',answer_text)] """ @@ -28,8 +28,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name='chatrecord', name='answer_text_list', - field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=40960), default=list, - size=None, verbose_name='改进标注列表'), + field=django.contrib.postgres.fields.ArrayField(base_field=models.JSONField(), default=list, size=None, verbose_name='改进标注列表') ), migrations.RunSQL(sql) ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index c90d1325c..6ed33c48c 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -69,7 +69,6 @@ class Application(AppModelMixin): file_upload_enable = models.BooleanField(verbose_name="文件上传是否启用", default=False) file_upload_setting = models.JSONField(verbose_name="文件上传相关设置", default=dict) - @staticmethod def get_default_model_prompt(): return ('已知信息:' @@ -148,7 +147,7 @@ class ChatRecord(AppModelMixin): problem_text = models.CharField(max_length=10240, verbose_name="问题") answer_text = models.CharField(max_length=40960, verbose_name="答案") answer_text_list = ArrayField(verbose_name="改进标注列表", - base_field=models.CharField(max_length=40960) + base_field=models.JSONField() , default=list) message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 5f38cfc7d..c6374c914 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -238,13 +238,14 @@ class ChatMessageSerializer(serializers.Serializer): runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, error_messages=ErrMessage.char("运行时节点id")) - node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数")) + node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.char("节点参数")) application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量")) image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片")) document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档")) + child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点")) def is_valid_application_workflow(self, *, raise_exception=False): self.is_valid_intraday_access_num() @@ -353,7 +354,7 @@ class ChatMessageSerializer(serializers.Serializer): 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), base_to_response, form_data, image_list, document_list, self.data.get('runtime_node_id'), - self.data.get('node_data'), chat_record) + self.data.get('node_data'), chat_record, self.data.get('child_node')) r = work_flow_manage.run() return r diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index a3ef50766..ccd727c8b 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -138,7 +138,8 @@ class ChatView(APIView): 'node_id': request.data.get('node_id', None), 'runtime_node_id': request.data.get('runtime_node_id', None), 'node_data': request.data.get('node_data', {}), - 'chat_record_id': request.data.get('chat_record_id')} + 'chat_record_id': request.data.get('chat_record_id'), + 'child_node': request.data.get('child_node')} ).chat() @action(methods=['GET'], detail=False) diff --git a/apps/common/handle/impl/response/openai_to_response.py b/apps/common/handle/impl/response/openai_to_response.py index 0ce605514..7897a8c60 100644 --- a/apps/common/handle/impl/response/openai_to_response.py +++ b/apps/common/handle/impl/response/openai_to_response.py @@ -35,7 +35,7 @@ class OpenaiToResponse(BaseToResponse): def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, completion_tokens, prompt_tokens, other_params: dict = None): chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk', - created=datetime.datetime.now().second, choices=[ + created=datetime.datetime.now().second,choices=[ Choice(delta=ChoiceDelta(content=content, chat_id=chat_id), finish_reason='stop' if is_end else None, index=0)], usage=CompletionUsage(completion_tokens=completion_tokens, diff --git a/apps/common/handle/impl/response/system_to_response.py b/apps/common/handle/impl/response/system_to_response.py index 7c374ef01..5999f1297 100644 --- a/apps/common/handle/impl/response/system_to_response.py +++ b/apps/common/handle/impl/response/system_to_response.py @@ -28,7 +28,7 @@ class SystemToResponse(BaseToResponse): prompt_tokens, other_params: dict = None): if other_params is None: other_params = {} - chunk = json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + chunk = json.dumps({'chat_id': str(chat_id), 'chat_record_id': str(chat_record_id), 'operate': True, 'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, 'is_end': is_end, 'usage': {'completion_tokens': completion_tokens, 'prompt_tokens': prompt_tokens, diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py index 7087636c2..7c39ebcdd 100644 --- a/apps/embedding/task/embedding.py +++ b/apps/embedding/task/embedding.py @@ -6,7 +6,6 @@ @date:2024/8/19 14:13 @desc: """ -import datetime import logging import traceback from typing import List @@ -17,7 +16,7 @@ from django.db.models import QuerySet from common.config.embedding_config import ModelManage from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \ UpdateEmbeddingDocumentIdArgs -from dataset.models import Document, Status, TaskType, State +from dataset.models import Document, TaskType, State from ops import celery_app from setting.models import Model from setting.models_provider import get_model diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index b45a6783f..95812bbab 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -23,8 +23,9 @@ interface ApplicationFormType { tts_type?: string } interface Chunk { + real_node_id: string chat_id: string - id: string + chat_record_id: string content: string node_id: string up_node_id: string @@ -32,13 +33,20 @@ interface Chunk { node_is_end: boolean node_type: string view_type: string + runtime_node_id: string + child_node: any } interface chatType { id: string problem_text: string answer_text: string buffer: Array - answer_text_list: Array + answer_text_list: Array<{ + content: string + chat_record_id?: string + runtime_node_id?: string + child_node?: any + }> /** * 是否写入结束 */ @@ -92,15 +100,24 @@ export class ChatRecordManage { this.write_ed = false this.node_list = [] } - append_answer(chunk_answer: string, index?: number) { - this.chat.answer_text_list[index != undefined ? index : this.chat.answer_text_list.length - 1] = - this.chat.answer_text_list[ - index !== undefined ? index : this.chat.answer_text_list.length - 1 - ] - ? this.chat.answer_text_list[ - index !== undefined ? index : this.chat.answer_text_list.length - 1 - ] + chunk_answer - : chunk_answer + append_answer( + chunk_answer: string, + index?: number, + chat_record_id?: string, + runtime_node_id?: string, + child_node?: any + ) { + const set_index = index != undefined ? index : this.chat.answer_text_list.length - 1 + const content = this.chat.answer_text_list[set_index] + ? this.chat.answer_text_list[set_index].content + chunk_answer + : chunk_answer + this.chat.answer_text_list[set_index] = { + content: content, + chat_record_id, + runtime_node_id, + child_node + } + this.chat.answer_text = this.chat.answer_text + chunk_answer } @@ -127,14 +144,22 @@ export class ChatRecordManage { run_node.view_type == 'single_view' || (run_node.view_type == 'many_view' && current_up_node.view_type == 'single_view') ) { - const none_index = this.chat.answer_text_list.indexOf('') + const none_index = this.findIndex( + this.chat.answer_text_list, + (item) => item.content == '', + 'index' + ) if (none_index > -1) { answer_text_list_index = none_index } else { answer_text_list_index = this.chat.answer_text_list.length } } else { - const none_index = this.chat.answer_text_list.indexOf('') + const none_index = this.findIndex( + this.chat.answer_text_list, + (item) => item.content === '', + 'index' + ) if (none_index > -1) { answer_text_list_index = none_index } else { @@ -152,6 +177,19 @@ export class ChatRecordManage { } return undefined } + findIndex(array: Array, find: (item: T) => boolean, type: 'last' | 'index') { + let set_index = -1 + for (let index = 0; index < array.length; index++) { + const element = array[index] + if (find(element)) { + set_index = index + if (type == 'index') { + break + } + } + } + return set_index + } closeInterval() { this.chat.write_ed = true this.write_ed = true @@ -161,7 +199,11 @@ export class ChatRecordManage { if (this.id) { clearInterval(this.id) } - const last_index = this.chat.answer_text_list.lastIndexOf('') + const last_index = this.findIndex( + this.chat.answer_text_list, + (item) => item.content == '', + 'last' + ) if (last_index > 0) { this.chat.answer_text_list.splice(last_index, 1) } @@ -193,19 +235,29 @@ export class ChatRecordManage { ) this.append_answer( (divider_content ? divider_content.splice(0).join('') : '') + context.join(''), - answer_text_list_index + answer_text_list_index, + current_node.chat_record_id, + current_node.runtime_node_id, + current_node.child_node ) } else if (this.is_close) { while (true) { const node_info = this.get_run_node() + if (node_info == undefined) { break } this.append_answer( (node_info.divider_content ? node_info.divider_content.splice(0).join('') : '') + node_info.current_node.buffer.splice(0).join(''), - node_info.answer_text_list_index + node_info.answer_text_list_index, + current_node.chat_record_id, + current_node.runtime_node_id, + current_node.child_node ) + if (node_info.current_node.buffer.length == 0) { + node_info.current_node.is_end = true + } } this.closeInterval() } else { @@ -213,7 +265,10 @@ export class ChatRecordManage { if (s !== undefined) { this.append_answer( (divider_content ? divider_content.splice(0).join('') : '') + s, - answer_text_list_index + answer_text_list_index, + current_node.chat_record_id, + current_node.runtime_node_id, + current_node.child_node ) } } @@ -235,16 +290,18 @@ export class ChatRecordManage { this.is_stop = false } appendChunk(chunk: Chunk) { - let n = this.node_list.find( - (item) => item.node_id == chunk.node_id && item.up_node_id === chunk.up_node_id - ) + let n = this.node_list.find((item) => item.real_node_id == chunk.real_node_id) if (n) { n.buffer.push(...chunk.content) } else { n = { buffer: [...chunk.content], + real_node_id: chunk.real_node_id, node_id: chunk.node_id, + chat_record_id: chunk.chat_record_id, up_node_id: chunk.up_node_id, + runtime_node_id: chunk.runtime_node_id, + child_node: chunk.child_node, node_type: chunk.node_type, index: this.node_list.length, view_type: chunk.view_type, @@ -257,9 +314,12 @@ export class ChatRecordManage { } } append(answer_text_block: string) { - const index =this.chat.answer_text_list.indexOf("") - this.chat.answer_text_list[index]=answer_text_block - + let set_index = this.findIndex( + this.chat.answer_text_list, + (item) => item.content == '', + 'index' + ) + this.chat.answer_text_list[set_index] = { content: answer_text_block } } } diff --git a/ui/src/components/ai-chat/component/answer-content/index.vue b/ui/src/components/ai-chat/component/answer-content/index.vue index ec4f13f2b..4d1c7a684 100644 --- a/ui/src/components/ai-chat/component/answer-content/index.vue +++ b/ui/src/components/ai-chat/component/answer-content/index.vue @@ -1,6 +1,6 @@