From 7b5ccd9089580c070c3e7bc0e608000a9ddbe7bf Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Sun, 28 Apr 2024 17:09:12 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E5=BA=94=E7=94=A8=E7=9A=84AI=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E4=BF=AE=E6=94=B9=E4=B8=BA=E4=B8=8D=E5=BF=85=E5=A1=AB?= =?UTF-8?q?=20(#297)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../step/chat_step/i_chat_step.py | 2 +- .../step/chat_step/impl/base_chat_step.py | 88 +++++++++---------- .../i_reset_problem_step.py | 2 +- .../impl/base_reset_problem_step.py | 2 + .../serializers/application_serializers.py | 29 +++--- .../serializers/chat_message_serializers.py | 4 +- .../serializers/chat_serializers.py | 20 +++-- ui/src/components/ai-chat/index.vue | 18 ++-- ui/src/views/application/CreateAndSetting.vue | 5 +- 9 files changed, 91 insertions(+), 79 deletions(-) diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index 534b9d409..8fbac34c9 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep): message_list = serializers.ListField(required=True, child=MessageField(required=True), error_messages=ErrMessage.list("对话列表")) # 大语言模型 - chat_model = ModelField(error_messages=ErrMessage.list("大语言模型")) + chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型")) # 段落列表 paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表")) # 对话id diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 0919abbe2..b53f12941 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -126,6 +126,26 @@ class BaseChatStep(IChatStep): result.append({'role': 'ai', 'content': answer_text}) return result + @staticmethod + def get_stream_result(message_list: List[BaseMessage], + chat_model: BaseChatModel = None, + paragraph_list=None, + no_references_setting=None): + if paragraph_list is None: + paragraph_list = [] + directly_return_chunk_list = [AIMessageChunk(content=paragraph.content) + for paragraph in paragraph_list if + paragraph.hit_handling_method == 'directly_return'] + if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: + return iter(directly_return_chunk_list), False + elif no_references_setting.get( + 'status') == 'designated_answer': + return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False + if chat_model is None: + return iter([AIMessageChunk('抱歉,没有在知识库中查询到相关信息。')]), False + else: + return chat_model.stream(message_list), True + def execute_stream(self, message_list: List[BaseMessage], chat_id, problem_text, @@ -136,29 +156,8 @@ class BaseChatStep(IChatStep): padding_problem_text: str = None, client_id=None, client_type=None, no_references_setting=None): - is_ai_chat = False - # 调用模型 - if chat_model is None: - chat_result = iter( - [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list]) - else: - if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get( - 'status') == 'designated_answer': - chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))]) - else: - if paragraph_list is not None and len(paragraph_list) > 0: - directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) - for paragraph in paragraph_list if - paragraph.hit_handling_method == 'directly_return'] - if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: - chat_result = iter(directly_return_chunk_list) - else: - chat_result = chat_model.stream(message_list) - is_ai_chat = True - else: - chat_result = chat_model.stream(message_list) - is_ai_chat = True - + chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list, + no_references_setting) chat_record_id = uuid.uuid1() r = StreamingHttpResponse( streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, @@ -169,6 +168,27 @@ class BaseChatStep(IChatStep): r['Cache-Control'] = 'no-cache' return r + @staticmethod + def get_block_result(message_list: List[BaseMessage], + chat_model: BaseChatModel = None, + paragraph_list=None, + no_references_setting=None): + if paragraph_list is None: + paragraph_list = [] + + directly_return_chunk_list = [AIMessage(content=paragraph.content) + for paragraph in paragraph_list if + paragraph.hit_handling_method == 'directly_return'] + if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: + return directly_return_chunk_list[0], False + elif no_references_setting.get( + 'status') == 'designated_answer': + return AIMessage(no_references_setting.get('value')), False + if chat_model is None: + return AIMessage('抱歉,没有在知识库中查询到相关信息。'), False + else: + return chat_model.invoke(message_list), True + def execute_block(self, message_list: List[BaseMessage], chat_id, problem_text, @@ -178,28 +198,8 @@ class BaseChatStep(IChatStep): manage: PipelineManage = None, padding_problem_text: str = None, client_id=None, client_type=None, no_references_setting=None): - is_ai_chat = False # 调用模型 - if chat_model is None: - chat_result = AIMessage( - content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list])) - else: - if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get( - 'status') == 'designated_answer': - chat_result = AIMessage(content=no_references_setting.get('value')) - else: - if paragraph_list is not None and len(paragraph_list) > 0: - directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) - for paragraph in paragraph_list if - paragraph.hit_handling_method == 'directly_return'] - if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: - chat_result = iter(directly_return_chunk_list) - else: - chat_result = chat_model.invoke(message_list) - is_ai_chat = True - else: - chat_result = chat_model.invoke(message_list) - is_ai_chat = True + chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting) chat_record_id = uuid.uuid1() if is_ai_chat: request_token = chat_model.get_num_tokens_from_messages(message_list) diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py index 930fd2482..ce30d96af 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep): history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), error_messages=ErrMessage.list("历史对答")) # 大语言模型 - chat_model = ModelField(error_messages=ErrMessage.base("大语言模型")) + chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型")) def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: return self.InstanceSerializer diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py index 2386be4fe..aad66446c 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -22,6 +22,8 @@ prompt = ( class BaseResetProblemStep(IResetProblemStep): def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, **kwargs) -> str: + if chat_model is None: + return problem_text start_index = len(history_chat_record) - 3 history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] for index in diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 3307d873a..164e0083f 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache'] class ModelDatasetAssociation(serializers.Serializer): user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) - model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) + model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("模型id")) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True, error_messages=ErrMessage.uuid( "知识库id")), @@ -57,8 +58,9 @@ class ModelDatasetAssociation(serializers.Serializer): super().is_valid(raise_exception=True) model_id = self.data.get('model_id') user_id = self.data.get('user_id') - if not QuerySet(Model).filter(id=model_id).exists(): - raise AppApiException(500, f'模型不存在【{model_id}】') + if model_id is not None and len(model_id) > 0: + if not QuerySet(Model).filter(id=model_id).exists(): + raise AppApiException(500, f'模型不存在【{model_id}】') dataset_id_list = list(set(self.data.get('dataset_id_list'))) exist_dataset_id_list = [str(dataset.id) for dataset in QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)] @@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer): desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=256, min_length=1, error_messages=ErrMessage.char("应用描述")) - model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型")) + model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("模型")) multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话")) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, error_messages=ErrMessage.char("开场白")) @@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer): error_messages=ErrMessage.char("应用名称")) desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True, error_messages=ErrMessage.char("应用描述")) - model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型")) + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char("模型")) multiple_rounds_dialogue = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("多轮会话")) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, @@ -494,13 +498,14 @@ class ApplicationSerializer(serializers.Serializer): application_id = self.data.get("application_id") application = QuerySet(Application).get(id=application_id) - - model = QuerySet(Model).filter( - id=instance.get('model_id') if 'model_id' in instance else application.model_id, - user_id=application.user_id).first() - if model is None: - raise AppApiException(500, "模型不存在") - + if instance.get('model_id') is None or len(instance.get('model_id')) == 0: + application.model_id = None + else: + model = QuerySet(Model).filter( + id=instance.get('model_id'), + user_id=application.user_id).first() + if model is None: + raise AppApiException(500, "模型不存在") update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', 'api_key_is_active', 'icon'] diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 47b905a77..7c89e9de4 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer): chat_cache.set(chat_id, chat_info, timeout=60 * 30) model = chat_info.application.model + if model is None: + return chat_info model = QuerySet(Model).filter(id=model.id).first() if model is None: - raise AppApiException(500, "模型不存在") + return chat_info if model.status == Status.ERROR: raise AppApiException(500, "当前模型不可用") if model.status == Status.DOWNLOAD: diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 5f58fe876..7476219dd 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -213,7 +213,8 @@ class ChatSerializers(serializers.Serializer): id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) - model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) + model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.uuid("模型id")) multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("多轮会话")) @@ -246,14 +247,17 @@ class ChatSerializers(serializers.Serializer): def open(self): user_id = self.is_valid(raise_exception=True) chat_id = str(uuid.uuid1()) - model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() - if model is None: - raise AppApiException(500, "模型不存在") + model_id = self.data.get('model_id') + if model_id is not None and len(model_id) > 0: + model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() + chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, + json.loads( + decrypt(model.credential)), + streaming=True) + else: + model = None + chat_model = None dataset_id_list = self.data.get('dataset_id_list') - chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, - json.loads( - decrypt(model.credential)), - streaming=True) application = Application(id=None, dialogue_number=3, model=model, dataset_setting=self.data.get('dataset_setting'), model_setting=self.data.get('model_setting'), diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 8eb60639e..48ffc1151 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -224,7 +224,7 @@ const chartOpenId = ref('') const chatList = ref([]) const isDisabledChart = computed( - () => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id))) + () => !(inputValue.value.trim() && (props.appId || props.data?.name)) ) const isMdArray = (val: string) => val.match(/^-\s.*/m) const prologueList = computed(() => { @@ -509,16 +509,14 @@ function regenerationChart(item: chatType) { } function getSourceDetail(row: any) { - logApi - .getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading) - .then((res) => { - const exclude_keys = ['answer_text', 'id'] - Object.keys(res.data).forEach((key) => { - if (!exclude_keys.includes(key)) { - row[key] = res.data[key] - } - }) + logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => { + const exclude_keys = ['answer_text', 'id'] + Object.keys(res.data).forEach((key) => { + if (!exclude_keys.includes(key)) { + row[key] = res.data[key] + } }) + }) return true } diff --git a/ui/src/views/application/CreateAndSetting.vue b/ui/src/views/application/CreateAndSetting.vue index 5612f5561..ee666229c 100644 --- a/ui/src/views/application/CreateAndSetting.vue +++ b/ui/src/views/application/CreateAndSetting.vue @@ -48,7 +48,7 @@ >({ name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }], model_id: [ { - required: true, + required: false, message: '请选择模型', trigger: 'change' }