diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index d3c6207aa..caf25836f 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -40,6 +40,10 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): node.context['run_time'] = time.time() - node.context['start_time'] +def is_interrupt(node, step_variable: Dict, global_variable: Dict): + return node.type == 'form-node' and not node.context.get('is_submit', False) + + class WorkFlowPostHandler: def __init__(self, chat_info, client_id, client_type): self.chat_info = chat_info @@ -57,7 +61,7 @@ class WorkFlowPostHandler: answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row and row.get('answer_tokens') is not None]) answer_text_list = workflow.get_answer_text_list() - answer_text = '\n\n'.join(answer_text_list) + answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list) if workflow.chat_record is not None: chat_record = workflow.chat_record chat_record.answer_text = answer_text @@ -91,10 +95,11 @@ class WorkFlowPostHandler: class NodeResult: def __init__(self, node_variable: Dict, workflow_variable: Dict, - _write_context=write_context): + _write_context=write_context, _is_interrupt=is_interrupt): self._write_context = _write_context self.node_variable = node_variable self.workflow_variable = workflow_variable + self._is_interrupt = _is_interrupt def write_context(self, node, workflow): return self._write_context(self.node_variable, self.workflow_variable, node, workflow) @@ -102,6 +107,14 @@ class NodeResult: def is_assertion_result(self): return 'branch_id' in self.node_variable + def is_interrupt_exec(self, current_node): + """ + 是否中断执行 + @param current_node: + @return: + """ + return self._is_interrupt(current_node, self.node_variable, self.workflow_variable) + class ReferenceAddressSerializer(serializers.Serializer): node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id")) @@ -139,14 +152,18 @@ class INode: pass def get_answer_text(self): - return self.answer_text + if self.answer_text is None: + return None + return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id, + 'chat_record_id': self.workflow_params['chat_record_id']} - def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None): + def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None, + get_node_params=lambda node: node.properties.get('node_data')): # 当前步骤上下文,用于存储当前步骤信息 self.status = 200 self.err_message = '' self.node = node - self.node_params = node.properties.get('node_data') + self.node_params = get_node_params(node) self.workflow_params = workflow_params self.workflow_manage = workflow_manage self.node_params_serializer = None diff --git a/apps/application/flow/step_node/application_node/i_application_node.py b/apps/application/flow/step_node/application_node/i_application_node.py index d0184d4a3..8c4675ea7 100644 --- a/apps/application/flow/step_node/application_node/i_application_node.py +++ b/apps/application/flow/step_node/application_node/i_application_node.py @@ -14,6 +14,8 @@ class ApplicationNodeSerializer(serializers.Serializer): user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段")) image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片")) document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档")) + child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点")) + node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据")) class IApplicationNode(INode): @@ -55,5 +57,5 @@ class IApplicationNode(INode): message=str(question), **kwargs) def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, - app_document_list=None, app_image_list=None, **kwargs) -> NodeResult: + app_document_list=None, app_image_list=None, child_node=None, node_data=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py index 30954a2b6..532a81e52 100644 --- a/apps/application/flow/step_node/application_node/impl/base_application_node.py +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -2,19 +2,25 @@ import json import time import uuid -from typing import List, Dict +from typing import Dict + from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.application_node.i_application_node import IApplicationNode from application.models import Chat -from common.handle.impl.response.openai_to_response import OpenaiToResponse def string_to_uuid(input_str): return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str)) +def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict): + return node_variable.get('is_interrupt_exec', False) + + def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): result = node_variable.get('result') + node.context['child_node'] = node_variable['child_node'] + node.context['is_interrupt_exec'] = node_variable['is_interrupt_exec'] node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0) node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0) node.context['answer'] = answer @@ -36,17 +42,34 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo response = node_variable.get('result') answer = '' usage = {} + node_child_node = {} + is_interrupt_exec = False for chunk in response: # 先把流转成字符串 response_content = chunk.decode('utf-8')[6:] response_content = json.loads(response_content) - choices = response_content.get('choices') - if choices and isinstance(choices, list) and len(choices) > 0: - content = choices[0].get('delta', {}).get('content', '') - answer += content - yield content + content = response_content.get('content', '') + runtime_node_id = response_content.get('runtime_node_id', '') + chat_record_id = response_content.get('chat_record_id', '') + child_node = response_content.get('child_node') + node_type = response_content.get('node_type') + real_node_id = response_content.get('real_node_id') + node_is_end = response_content.get('node_is_end', False) + if node_type == 'form-node': + is_interrupt_exec = True + answer += content + node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, + 'child_node': child_node} + yield {'content': content, + 'node_type': node_type, + 'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id, + 'child_node': child_node, + 'real_node_id': real_node_id, + 'node_is_end': node_is_end} usage = response_content.get('usage', {}) node_variable['result'] = {'usage': usage} + node_variable['is_interrupt_exec'] = is_interrupt_exec + node_variable['child_node'] = node_child_node _write_context(node_variable, workflow_variable, node, workflow, answer) @@ -64,6 +87,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor class BaseApplicationNode(IApplicationNode): + def get_answer_text(self): + if self.answer_text is None: + return None + return {'content': self.answer_text, 'runtime_node_id': self.runtime_node_id, + 'chat_record_id': self.workflow_params['chat_record_id'], 'child_node': self.context.get('child_node')} def save_context(self, details, workflow_manage): self.context['answer'] = details.get('answer') @@ -72,7 +100,7 @@ class BaseApplicationNode(IApplicationNode): self.answer_text = details.get('answer') def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, - app_document_list=None, app_image_list=None, + app_document_list=None, app_image_list=None, child_node=None, node_data=None, **kwargs) -> NodeResult: from application.serializers.chat_message_serializers import ChatMessageSerializer # 生成嵌入应用的chat_id @@ -85,6 +113,14 @@ class BaseApplicationNode(IApplicationNode): app_document_list = [] if app_image_list is None: app_image_list = [] + runtime_node_id = None + record_id = None + child_node_value = None + if child_node is not None: + runtime_node_id = child_node.get('runtime_node_id') + record_id = child_node.get('chat_record_id') + child_node_value = child_node.get('child_node') + response = ChatMessageSerializer( data={'chat_id': current_chat_id, 'message': message, 're_chat': re_chat, @@ -94,16 +130,20 @@ class BaseApplicationNode(IApplicationNode): 'client_type': client_type, 'document_list': app_document_list, 'image_list': app_image_list, - 'form_data': kwargs}).chat(base_to_response=OpenaiToResponse()) + 'runtime_node_id': runtime_node_id, + 'chat_record_id': record_id, + 'child_node': child_node_value, + 'node_data': node_data, + 'form_data': kwargs}).chat() if response.status_code == 200: if stream: content_generator = response.streaming_content return NodeResult({'result': content_generator, 'question': message}, {}, - _write_context=write_context_stream) + _write_context=write_context_stream, _is_interrupt=_is_interrupt_exec) else: data = json.loads(response.content) return NodeResult({'result': data, 'question': message}, {}, - _write_context=write_context) + _write_context=write_context, _is_interrupt=_is_interrupt_exec) def get_details(self, index: int, **kwargs): global_fields = [] diff --git a/apps/application/flow/step_node/form_node/i_form_node.py b/apps/application/flow/step_node/form_node/i_form_node.py index b793a5b78..cfd178ded 100644 --- a/apps/application/flow/step_node/form_node/i_form_node.py +++ b/apps/application/flow/step_node/form_node/i_form_node.py @@ -17,6 +17,7 @@ from common.util.field_message import ErrMessage class FormNodeParamsSerializer(serializers.Serializer): form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list("表单配置")) form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char('表单输出内容')) + form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("表单数据")) class IFormNode(INode): @@ -29,5 +30,5 @@ class IFormNode(INode): def _run(self): return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) - def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult: + def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py index ae86b11d6..db179f588 100644 --- a/apps/application/flow/step_node/form_node/impl/base_form_node.py +++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py @@ -42,7 +42,12 @@ class BaseFormNode(IFormNode): for key in form_data: self.context[key] = form_data[key] - def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult: + def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult: + if form_data is not None: + self.context['is_submit'] = True + self.context['form_data'] = form_data + else: + self.context['is_submit'] = False form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), "is_submit": self.context.get("is_submit", False)} @@ -63,7 +68,8 @@ class BaseFormNode(IFormNode): form = f'{json.dumps(form_setting)}' prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') value = prompt_template.format(form=form) - return value + return {'content': value, 'runtime_node_id': self.runtime_node_id, + 'chat_record_id': self.workflow_params['chat_record_id']} def get_details(self, index: int, **kwargs): form_content_format = self.context.get('form_content_format') diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index c62612c0a..37efc5467 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -244,15 +244,15 @@ class WorkflowManage: base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, document_list=None, start_node_id=None, - start_node_data=None, chat_record=None): + start_node_data=None, chat_record=None, child_node=None): if form_data is None: form_data = {} if image_list is None: image_list = [] if document_list is None: document_list = [] + self.start_node_id = start_node_id self.start_node = None - self.start_node_result_future = None self.form_data = form_data self.image_list = image_list self.document_list = document_list @@ -270,6 +270,7 @@ class WorkflowManage: self.base_to_response = base_to_response self.chat_record = chat_record self.await_future_map = {} + self.child_node = child_node if start_node_id is not None: self.load_node(chat_record, start_node_id, start_node_data) else: @@ -290,11 +291,17 @@ class WorkflowManage: for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')): node_id = node_details.get('node_id') if node_details.get('runtime_node_id') == start_node_id: - self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list')) - self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params) - self.start_node.save_context(node_details, self) - node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {}) - self.start_node_result_future = NodeResultFuture(node_result, None) + def get_node_params(n): + is_result = False + if n.type == 'application-node': + is_result = True + return {**n.properties.get('node_data'), 'form_data': start_node_data, 'node_data': start_node_data, + 'child_node': self.child_node, 'is_result': is_result} + + self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'), + get_node_params=get_node_params) + self.start_node.valid_args( + {**self.start_node.node_params, 'form_data': start_node_data}, self.start_node.workflow_params) self.node_context.append(self.start_node) continue @@ -306,7 +313,7 @@ class WorkflowManage: def run(self): if self.params.get('stream'): - return self.run_stream(self.start_node, self.start_node_result_future) + return self.run_stream(self.start_node, None) return self.run_block() def run_block(self): @@ -352,9 +359,19 @@ class WorkflowManage: break yield chunk finally: + details = self.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], self.answer, self) + yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + '', + [], + '', True, message_tokens, answer_tokens, {}) def run_chain_async(self, current_node, node_result_future): future = executor.submit(self.run_chain, current_node, node_result_future) @@ -423,6 +440,8 @@ class WorkflowManage: def hand_event_node_result(self, current_node, node_result_future): node_chunk = NodeChunk() + real_node_id = current_node.runtime_node_id + child_node = {} try: current_result = node_result_future.result() result = current_result.write_context(current_node, self) @@ -430,21 +449,38 @@ class WorkflowManage: if self.is_result(current_node, current_result): self.node_chunk_manage.add_node_chunk(node_chunk) for r in result: + content = r + child_node = {} + node_is_end = False + if isinstance(r, dict): + content = r.get('content') + child_node = {'runtime_node_id': r.get('runtime_node_id'), + 'chat_record_id': r.get('chat_record_id') + , 'child_node': r.get('child_node')} + real_node_id = r.get('real_node_id') + node_is_end = r.get('node_is_end') chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'], current_node.id, current_node.up_node_id_list, - r, False, 0, 0, + content, False, 0, 0, {'node_type': current_node.type, - 'view_type': current_node.view_type}) + 'runtime_node_id': current_node.runtime_node_id, + 'view_type': current_node.view_type, + 'child_node': child_node, + 'node_is_end': node_is_end, + 'real_node_id': real_node_id}) node_chunk.add_chunk(chunk) chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'], current_node.id, current_node.up_node_id_list, '', False, 0, 0, {'node_is_end': True, + 'runtime_node_id': current_node.runtime_node_id, 'node_type': current_node.type, - 'view_type': current_node.view_type}) + 'view_type': current_node.view_type, + 'child_node': child_node, + 'real_node_id': real_node_id}) node_chunk.end(chunk) else: list(result) @@ -461,8 +497,12 @@ class WorkflowManage: current_node.id, current_node.up_node_id_list, str(e), False, 0, 0, - {'node_is_end': True, 'node_type': current_node.type, - 'view_type': current_node.view_type}) + {'node_is_end': True, + 'runtime_node_id': current_node.runtime_node_id, + 'node_type': current_node.type, + 'view_type': current_node.view_type, + 'child_node': {}, + 'real_node_id': real_node_id}) if not self.node_chunk_manage.contains(node_chunk): self.node_chunk_manage.add_node_chunk(node_chunk) node_chunk.end(chunk) @@ -554,9 +594,9 @@ class WorkflowManage: else: if len(result) > 0: exec_index = len(result) - 1 - content = result[exec_index] - result[exec_index] += answer_text if len( - content) == 0 else ('\n\n' + answer_text) + content = result[exec_index]['content'] + result[exec_index]['content'] += answer_text['content'] if len( + content) == 0 else ('\n\n' + answer_text['content']) else: answer_text = node.get_answer_text() result.insert(0, answer_text) @@ -613,8 +653,8 @@ class WorkflowManage: @param current_node_result: 当前可执行节点结果 @return: 可执行节点列表 """ - - if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable: + # 判断是否中断执行 + if current_node_result.is_interrupt_exec(current_node): return [] node_list = [] if current_node_result is not None and current_node_result.is_assertion_result(): @@ -689,11 +729,12 @@ class WorkflowManage: base_node_list = [node for node in self.flow.nodes if node.type == 'base-node'] return base_node_list[0] - def get_node_cls_by_id(self, node_id, up_node_id_list=None): + def get_node_cls_by_id(self, node_id, up_node_id_list=None, + get_node_params=lambda node: node.properties.get('node_data')): for node in self.flow.nodes: if node.id == node_id: node_instance = get_node(node.type)(node, - self.params, self, up_node_id_list) + self.params, self, up_node_id_list, get_node_params) return node_instance return None diff --git a/apps/application/migrations/0019_application_file_upload_enable_and_more.py b/apps/application/migrations/0019_application_file_upload_enable_and_more.py index f8c33c2c1..7b12321f1 100644 --- a/apps/application/migrations/0019_application_file_upload_enable_and_more.py +++ b/apps/application/migrations/0019_application_file_upload_enable_and_more.py @@ -4,8 +4,8 @@ import django.contrib.postgres.fields from django.db import migrations, models sql = """ -UPDATE "public".application_chat_record -SET "answer_text_list" = ARRAY[answer_text]; +UPDATE application_chat_record +SET answer_text_list=ARRAY[jsonb_build_object('content',answer_text)] """ @@ -28,8 +28,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name='chatrecord', name='answer_text_list', - field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=40960), default=list, - size=None, verbose_name='改进标注列表'), + field=django.contrib.postgres.fields.ArrayField(base_field=models.JSONField(), default=list, size=None, verbose_name='改进标注列表') ), migrations.RunSQL(sql) ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index c90d1325c..6ed33c48c 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -69,7 +69,6 @@ class Application(AppModelMixin): file_upload_enable = models.BooleanField(verbose_name="文件上传是否启用", default=False) file_upload_setting = models.JSONField(verbose_name="文件上传相关设置", default=dict) - @staticmethod def get_default_model_prompt(): return ('已知信息:' @@ -148,7 +147,7 @@ class ChatRecord(AppModelMixin): problem_text = models.CharField(max_length=10240, verbose_name="问题") answer_text = models.CharField(max_length=40960, verbose_name="答案") answer_text_list = ArrayField(verbose_name="改进标注列表", - base_field=models.CharField(max_length=40960) + base_field=models.JSONField() , default=list) message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 5f38cfc7d..c6374c914 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -238,13 +238,14 @@ class ChatMessageSerializer(serializers.Serializer): runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, error_messages=ErrMessage.char("运行时节点id")) - node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数")) + node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.char("节点参数")) application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量")) image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片")) document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档")) + child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点")) def is_valid_application_workflow(self, *, raise_exception=False): self.is_valid_intraday_access_num() @@ -353,7 +354,7 @@ class ChatMessageSerializer(serializers.Serializer): 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), base_to_response, form_data, image_list, document_list, self.data.get('runtime_node_id'), - self.data.get('node_data'), chat_record) + self.data.get('node_data'), chat_record, self.data.get('child_node')) r = work_flow_manage.run() return r diff --git a/apps/application/template/embed.js b/apps/application/template/embed.js index f05cf1c1d..f75b590a4 100644 --- a/apps/application/template/embed.js +++ b/apps/application/template/embed.js @@ -169,7 +169,7 @@ function initMaxkbStyle(root){ position: absolute; {{x_type}}: {{x_value}}px; {{y_type}}: {{y_value}}px; - z-index: 1000; + z-index: 10001; } #maxkb .maxkb-tips { position: fixed; @@ -180,7 +180,7 @@ function initMaxkbStyle(root){ color: #ffffff; font-size: 14px; background: #3370FF; - z-index: 1000; + z-index: 10001; } #maxkb .maxkb-tips .maxkb-arrow { position: absolute; diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index a3ef50766..ccd727c8b 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -138,7 +138,8 @@ class ChatView(APIView): 'node_id': request.data.get('node_id', None), 'runtime_node_id': request.data.get('runtime_node_id', None), 'node_data': request.data.get('node_data', {}), - 'chat_record_id': request.data.get('chat_record_id')} + 'chat_record_id': request.data.get('chat_record_id'), + 'child_node': request.data.get('child_node')} ).chat() @action(methods=['GET'], detail=False) diff --git a/apps/common/event/__init__.py b/apps/common/event/__init__.py index 6b6d0541a..2ae40525b 100644 --- a/apps/common/event/__init__.py +++ b/apps/common/event/__init__.py @@ -10,8 +10,14 @@ import setting.models from setting.models import Model from .listener_manage import * +update_document_status_sql = """ +UPDATE "public"."document" +SET status ="replace"("replace"("replace"(status, '1', '3'), '0', '3'), '4', '3') +""" + def run(): # QuerySet(Document).filter(status__in=[Status.embedding, Status.queue_up]).update(**{'status': Status.error}) QuerySet(Model).filter(status=setting.models.Status.DOWNLOAD).update(status=setting.models.Status.ERROR, meta={'message': "下载程序被中断,请重试"}) + update_execute(update_document_status_sql, []) diff --git a/apps/common/handle/impl/response/openai_to_response.py b/apps/common/handle/impl/response/openai_to_response.py index 0ce605514..7897a8c60 100644 --- a/apps/common/handle/impl/response/openai_to_response.py +++ b/apps/common/handle/impl/response/openai_to_response.py @@ -35,7 +35,7 @@ class OpenaiToResponse(BaseToResponse): def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, completion_tokens, prompt_tokens, other_params: dict = None): chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk', - created=datetime.datetime.now().second, choices=[ + created=datetime.datetime.now().second,choices=[ Choice(delta=ChoiceDelta(content=content, chat_id=chat_id), finish_reason='stop' if is_end else None, index=0)], usage=CompletionUsage(completion_tokens=completion_tokens, diff --git a/apps/common/handle/impl/response/system_to_response.py b/apps/common/handle/impl/response/system_to_response.py index 7c374ef01..5999f1297 100644 --- a/apps/common/handle/impl/response/system_to_response.py +++ b/apps/common/handle/impl/response/system_to_response.py @@ -28,7 +28,7 @@ class SystemToResponse(BaseToResponse): prompt_tokens, other_params: dict = None): if other_params is None: other_params = {} - chunk = json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + chunk = json.dumps({'chat_id': str(chat_id), 'chat_record_id': str(chat_record_id), 'operate': True, 'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, 'is_end': is_end, 'usage': {'completion_tokens': completion_tokens, 'prompt_tokens': prompt_tokens, diff --git a/apps/common/util/fork.py b/apps/common/util/fork.py index ee30f696e..4405b9b76 100644 --- a/apps/common/util/fork.py +++ b/apps/common/util/fork.py @@ -142,7 +142,10 @@ class Fork: if len(charset_list) > 0: charset = charset_list[0] if charset != encoding: - html_content = response.content.decode(charset) + try: + html_content = response.content.decode(charset) + except Exception as e: + logging.getLogger("max_kb").error(f'{e}') return BeautifulSoup(html_content, "html.parser") return beautiful_soup diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index ac2006a52..5c8457302 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -18,7 +18,7 @@ import openpyxl from celery_once import AlreadyQueued from django.core import validators from django.db import transaction -from django.db.models import QuerySet +from django.db.models import QuerySet, Count from django.db.models.functions import Substr, Reverse from django.http import HttpResponse from drf_yasg import openapi @@ -56,6 +56,7 @@ from embedding.task.embedding import embedding_by_document, delete_embedding_by_ delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \ embedding_by_document_list from smartdoc.conf import PROJECT_DIR +from django.db import models parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()] parse_table_handle_list = [CsvSplitHandle(), XlsSplitHandle(), XlsxSplitHandle()] @@ -442,6 +443,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): QuerySet(model=Paragraph).filter(document_id=document_id).delete() # 删除问题 QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete() + delete_problems_and_mappings([document_id]) # 删除向量库 delete_embedding_by_document(document_id) paragraphs = get_split_model('web.md').parse(result.content) @@ -660,7 +662,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): # 删除段落 QuerySet(model=Paragraph).filter(document_id=document_id).delete() # 删除问题 - QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete() + delete_problems_and_mappings([document_id]) # 删除向量库 delete_embedding_by_document(document_id) return True @@ -987,7 +989,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): document_id_list = instance.get("id_list") QuerySet(Document).filter(id__in=document_id_list).delete() QuerySet(Paragraph).filter(document_id__in=document_id_list).delete() - QuerySet(ProblemParagraphMapping).filter(document_id__in=document_id_list).delete() + delete_problems_and_mappings(document_id_list) # 删除向量库 delete_embedding_by_document_list(document_id_list) return True @@ -1086,3 +1088,18 @@ def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int): if split_handle.support(file, get_buffer): return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image) return default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image) + + +def delete_problems_and_mappings(document_ids): + problem_paragraph_mappings = ProblemParagraphMapping.objects.filter(document_id__in=document_ids) + problem_ids = set(problem_paragraph_mappings.values_list('problem_id', flat=True)) + + if problem_ids: + problem_paragraph_mappings.delete() + remaining_problem_counts = ProblemParagraphMapping.objects.filter(problem_id__in=problem_ids).values( + 'problem_id').annotate(count=Count('problem_id')) + remaining_problem_ids = {pc['problem_id'] for pc in remaining_problem_counts} + problem_ids_to_delete = problem_ids - remaining_problem_ids + Problem.objects.filter(id__in=problem_ids_to_delete).delete() + else: + problem_paragraph_mappings.delete() diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index a115e544b..10c352900 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -11,7 +11,7 @@ from typing import Dict from celery_once import AlreadyQueued from django.db import transaction -from django.db.models import QuerySet +from django.db.models import QuerySet, Count from drf_yasg import openapi from rest_framework import serializers @@ -291,7 +291,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): self.is_valid(raise_exception=True) paragraph_id_list = instance.get("id_list") QuerySet(Paragraph).filter(id__in=paragraph_id_list).delete() - QuerySet(ProblemParagraphMapping).filter(paragraph_id__in=paragraph_id_list).delete() + delete_problems_and_mappings(paragraph_id_list) update_document_char_length(self.data.get('document_id')) # 删除向量库 delete_embedding_by_paragraph_ids(paragraph_id_list) @@ -541,14 +541,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): self.is_valid(raise_exception=True) paragraph_id = self.data.get('paragraph_id') Paragraph.objects.filter(id=paragraph_id).delete() - - problem_id = ProblemParagraphMapping.objects.filter(paragraph_id=paragraph_id).values_list('problem_id', - flat=True).first() - - if problem_id is not None: - if ProblemParagraphMapping.objects.filter(problem_id=problem_id).count() == 1: - Problem.objects.filter(id=problem_id).delete() - ProblemParagraphMapping.objects.filter(paragraph_id=paragraph_id).delete() + delete_problems_and_mappings([paragraph_id]) update_document_char_length(self.data.get('document_id')) delete_embedding_by_paragraph(paragraph_id) @@ -755,3 +748,18 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): prompt) except AlreadyQueued as e: raise AppApiException(500, "任务正在执行中,请勿重复下发") + + +def delete_problems_and_mappings(paragraph_ids): + problem_paragraph_mappings = ProblemParagraphMapping.objects.filter(paragraph_id__in=paragraph_ids) + problem_ids = set(problem_paragraph_mappings.values_list('problem_id', flat=True)) + + if problem_ids: + problem_paragraph_mappings.delete() + remaining_problem_counts = ProblemParagraphMapping.objects.filter(problem_id__in=problem_ids).values( + 'problem_id').annotate(count=Count('problem_id')) + remaining_problem_ids = {pc['problem_id'] for pc in remaining_problem_counts} + problem_ids_to_delete = problem_ids - remaining_problem_ids + Problem.objects.filter(id__in=problem_ids_to_delete).delete() + else: + problem_paragraph_mappings.delete() diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py index 7087636c2..7c39ebcdd 100644 --- a/apps/embedding/task/embedding.py +++ b/apps/embedding/task/embedding.py @@ -6,7 +6,6 @@ @date:2024/8/19 14:13 @desc: """ -import datetime import logging import traceback from typing import List @@ -17,7 +16,7 @@ from django.db.models import QuerySet from common.config.embedding_config import ModelManage from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \ UpdateEmbeddingDocumentIdArgs -from dataset.models import Document, Status, TaskType, State +from dataset.models import Document, TaskType, State from ops import celery_app from setting.models import Model from setting.models_provider import get_model diff --git a/apps/setting/migrations/0009_set_default_model_params_form.py b/apps/setting/migrations/0009_set_default_model_params_form.py new file mode 100644 index 000000000..6b4d4b453 --- /dev/null +++ b/apps/setting/migrations/0009_set_default_model_params_form.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.15 on 2024-10-15 14:49 + +from django.db import migrations, models + +sql = """ +UPDATE "public"."model" +SET "model_params_form" = '[{"attrs": {"max": 1, "min": 0.1, "step": 0.01, "precision": 2, "show-input": true, "show-input-controls": false}, "field": "temperature", "label": {"attrs": {"tooltip": "较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定"}, "label": "温度", "input_type": "TooltipLabel", "props_info": {}}, "required": true, "input_type": "Slider", "props_info": {}, "trigger_type": "OPTION_LIST", "default_value": 0.5, "relation_show_field_dict": {}, "relation_trigger_field_dict": {}}, {"attrs": {"max": 100000, "min": 1, "step": 1, "precision": 0, "show-input": true, "show-input-controls": false}, "field": "max_tokens", "label": {"attrs": {"tooltip": "指定模型可生成的最大token个数"}, "label": "输出最大Tokens", "input_type": "TooltipLabel", "props_info": {}}, "required": true, "input_type": "Slider", "props_info": {}, "trigger_type": "OPTION_LIST", "default_value": 4096, "relation_show_field_dict": {}, "relation_trigger_field_dict": {}}]' +WHERE jsonb_array_length(model_params_form)=0 +""" + + +class Migration(migrations.Migration): + dependencies = [ + ('setting', '0008_modelparam'), + ] + + operations = [ + migrations.RunSQL(sql) + ] diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index e76e67d2e..e732087b0 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -79,7 +79,6 @@ class ModelSerializer(serializers.Serializer): create_user = serializers.CharField(required=False, error_messages=ErrMessage.char("创建者")) - def list(self, with_valid): if with_valid: self.is_valid(raise_exception=True) @@ -92,7 +91,8 @@ class ModelSerializer(serializers.Serializer): model_query_set = QuerySet(Model).filter(Q(user_id=create_user)) # 当前用户能查看其他人的模型,只能查看公开的 else: - model_query_set = QuerySet(Model).filter((Q(user_id=self.data.get('create_user')) & Q(permission_type='PUBLIC'))) + model_query_set = QuerySet(Model).filter( + (Q(user_id=self.data.get('create_user')) & Q(permission_type='PUBLIC'))) else: model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC'))) query_params = {} @@ -107,11 +107,11 @@ class ModelSerializer(serializers.Serializer): if self.data.get('permission_type') is not None: query_params['permission_type'] = self.data.get('permission_type') - return [ {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, 'model_name': model.model_name, 'status': model.status, 'meta': model.meta, - 'permission_type': model.permission_type, 'user_id': model.user_id, 'username': model.user.username} for model in + 'permission_type': model.permission_type, 'user_id': model.user_id, 'username': model.user.username} + for model in model_query_set.filter(**query_params).order_by("-create_time")] class Edit(serializers.Serializer): @@ -243,14 +243,7 @@ class ModelSerializer(serializers.Serializer): self.is_valid(raise_exception=True) model_id = self.data.get('id') model = QuerySet(Model).filter(id=model_id).first() - credential = get_model_credential(model.provider, model.model_type, model.model_name) # 已经保存过的模型参数表单 - if model.model_params_form is not None and len(model.model_params_form) > 0: - return model.model_params_form - # 没有保存过的LLM类型的 - if credential.get_model_params_setting_form(model.model_name) is not None: - return credential.get_model_params_setting_form(model.model_name).to_form_list() - # 其他的 return model.model_params_form class ModelParamsForm(serializers.Serializer): diff --git a/apps/smartdoc/settings/lib.py b/apps/smartdoc/settings/lib.py index a4c1aaabb..fc1d3244f 100644 --- a/apps/smartdoc/settings/lib.py +++ b/apps/smartdoc/settings/lib.py @@ -7,6 +7,7 @@ @desc: """ import os +import shutil from smartdoc.const import CONFIG, PROJECT_DIR @@ -34,6 +35,11 @@ CELERY_TASK_SOFT_TIME_LIMIT = 3600 CELERY_WORKER_CANCEL_LONG_RUNNING_TASKS_ON_CONNECTION_LOSS = True CELERY_ACKS_LATE = True celery_once_path = os.path.join(celery_data_dir, "celery_once") +try: + if os.path.exists(celery_once_path) and os.path.isdir(celery_once_path): + shutil.rmtree(celery_once_path) +except Exception as e: + pass CELERY_ONCE = { 'backend': 'celery_once.backends.File', 'settings': {'location': celery_once_path} diff --git a/ui/package.json b/ui/package.json index 441315fe3..4a5806501 100644 --- a/ui/package.json +++ b/ui/package.json @@ -4,7 +4,7 @@ "private": true, "scripts": { "dev": "vite", - "build": "run-p --max_old_space_size=4096 type-check build-only", + "build": "set NODE_OPTIONS=--max_old_space_size=4096 && run-p type-check build-only", "preview": "vite preview", "test:unit": "vitest", "build-only": "vite build", diff --git a/ui/public/MaxKB.gif b/ui/public/MaxKB.gif index f18b93b84..055d49a6a 100644 Binary files a/ui/public/MaxKB.gif and b/ui/public/MaxKB.gif differ diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index b45a6783f..95812bbab 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -23,8 +23,9 @@ interface ApplicationFormType { tts_type?: string } interface Chunk { + real_node_id: string chat_id: string - id: string + chat_record_id: string content: string node_id: string up_node_id: string @@ -32,13 +33,20 @@ interface Chunk { node_is_end: boolean node_type: string view_type: string + runtime_node_id: string + child_node: any } interface chatType { id: string problem_text: string answer_text: string buffer: Array - answer_text_list: Array + answer_text_list: Array<{ + content: string + chat_record_id?: string + runtime_node_id?: string + child_node?: any + }> /** * 是否写入结束 */ @@ -92,15 +100,24 @@ export class ChatRecordManage { this.write_ed = false this.node_list = [] } - append_answer(chunk_answer: string, index?: number) { - this.chat.answer_text_list[index != undefined ? index : this.chat.answer_text_list.length - 1] = - this.chat.answer_text_list[ - index !== undefined ? index : this.chat.answer_text_list.length - 1 - ] - ? this.chat.answer_text_list[ - index !== undefined ? index : this.chat.answer_text_list.length - 1 - ] + chunk_answer - : chunk_answer + append_answer( + chunk_answer: string, + index?: number, + chat_record_id?: string, + runtime_node_id?: string, + child_node?: any + ) { + const set_index = index != undefined ? index : this.chat.answer_text_list.length - 1 + const content = this.chat.answer_text_list[set_index] + ? this.chat.answer_text_list[set_index].content + chunk_answer + : chunk_answer + this.chat.answer_text_list[set_index] = { + content: content, + chat_record_id, + runtime_node_id, + child_node + } + this.chat.answer_text = this.chat.answer_text + chunk_answer } @@ -127,14 +144,22 @@ export class ChatRecordManage { run_node.view_type == 'single_view' || (run_node.view_type == 'many_view' && current_up_node.view_type == 'single_view') ) { - const none_index = this.chat.answer_text_list.indexOf('') + const none_index = this.findIndex( + this.chat.answer_text_list, + (item) => item.content == '', + 'index' + ) if (none_index > -1) { answer_text_list_index = none_index } else { answer_text_list_index = this.chat.answer_text_list.length } } else { - const none_index = this.chat.answer_text_list.indexOf('') + const none_index = this.findIndex( + this.chat.answer_text_list, + (item) => item.content === '', + 'index' + ) if (none_index > -1) { answer_text_list_index = none_index } else { @@ -152,6 +177,19 @@ export class ChatRecordManage { } return undefined } + findIndex(array: Array, find: (item: T) => boolean, type: 'last' | 'index') { + let set_index = -1 + for (let index = 0; index < array.length; index++) { + const element = array[index] + if (find(element)) { + set_index = index + if (type == 'index') { + break + } + } + } + return set_index + } closeInterval() { this.chat.write_ed = true this.write_ed = true @@ -161,7 +199,11 @@ export class ChatRecordManage { if (this.id) { clearInterval(this.id) } - const last_index = this.chat.answer_text_list.lastIndexOf('') + const last_index = this.findIndex( + this.chat.answer_text_list, + (item) => item.content == '', + 'last' + ) if (last_index > 0) { this.chat.answer_text_list.splice(last_index, 1) } @@ -193,19 +235,29 @@ export class ChatRecordManage { ) this.append_answer( (divider_content ? divider_content.splice(0).join('') : '') + context.join(''), - answer_text_list_index + answer_text_list_index, + current_node.chat_record_id, + current_node.runtime_node_id, + current_node.child_node ) } else if (this.is_close) { while (true) { const node_info = this.get_run_node() + if (node_info == undefined) { break } this.append_answer( (node_info.divider_content ? node_info.divider_content.splice(0).join('') : '') + node_info.current_node.buffer.splice(0).join(''), - node_info.answer_text_list_index + node_info.answer_text_list_index, + current_node.chat_record_id, + current_node.runtime_node_id, + current_node.child_node ) + if (node_info.current_node.buffer.length == 0) { + node_info.current_node.is_end = true + } } this.closeInterval() } else { @@ -213,7 +265,10 @@ export class ChatRecordManage { if (s !== undefined) { this.append_answer( (divider_content ? divider_content.splice(0).join('') : '') + s, - answer_text_list_index + answer_text_list_index, + current_node.chat_record_id, + current_node.runtime_node_id, + current_node.child_node ) } } @@ -235,16 +290,18 @@ export class ChatRecordManage { this.is_stop = false } appendChunk(chunk: Chunk) { - let n = this.node_list.find( - (item) => item.node_id == chunk.node_id && item.up_node_id === chunk.up_node_id - ) + let n = this.node_list.find((item) => item.real_node_id == chunk.real_node_id) if (n) { n.buffer.push(...chunk.content) } else { n = { buffer: [...chunk.content], + real_node_id: chunk.real_node_id, node_id: chunk.node_id, + chat_record_id: chunk.chat_record_id, up_node_id: chunk.up_node_id, + runtime_node_id: chunk.runtime_node_id, + child_node: chunk.child_node, node_type: chunk.node_type, index: this.node_list.length, view_type: chunk.view_type, @@ -257,9 +314,12 @@ export class ChatRecordManage { } } append(answer_text_block: string) { - const index =this.chat.answer_text_list.indexOf("") - this.chat.answer_text_list[index]=answer_text_block - + let set_index = this.findIndex( + this.chat.answer_text_list, + (item) => item.content == '', + 'index' + ) + this.chat.answer_text_list[set_index] = { content: answer_text_block } } } diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index 641d7eca9..6cb066bb9 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -342,7 +342,7 @@