diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
index b3437390b..534b9d409 100644
--- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -15,9 +15,9 @@ from rest_framework import serializers
 
 from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
+from application.serializers.application_serializers import NoReferencesSetting
 from common.field.common import InstanceField
 from common.util.field_message import ErrMessage
-from dataset.models import Paragraph
 
 
 class ModelField(serializers.Field):
@@ -70,6 +70,8 @@ class IChatStep(IBaseChatPipelineStep):
         stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
         client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
         client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
+        # 未查询到引用分段
+        no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
 
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
@@ -92,5 +94,6 @@ class IChatStep(IBaseChatPipelineStep):
                 chat_model: BaseChatModel = None,
                 paragraph_list=None,
                 manage: PipelineManage = None,
-                padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs):
+                padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
+                no_references_setting=None, **kwargs):
         pass
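Note on the interface change above: the chat step's serializer now requires a `no_references_setting` object alongside `client_id`/`client_type`. A minimal sketch of the shape it enforces, in plain Python rather than the DRF field itself (the helper name here is hypothetical):

```python
# Hypothetical stand-in for the NoReferencesSetting serializer's contract:
# 'status' must be one of the two choices, 'value' a non-empty prompt/answer string.
VALID_STATUSES = {'ai_questioning', 'designated_answer'}


def validate_no_references_setting(setting: dict) -> dict:
    if setting.get('status') not in VALID_STATUSES:
        raise ValueError("status must be 'ai_questioning' or 'designated_answer'")
    if not isinstance(setting.get('value'), str) or not setting['value']:
        raise ValueError("value must be a non-empty string")
    return setting


validate_no_references_setting({'status': 'ai_questioning', 'value': '{question}'})
```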
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index 8beeab16e..85485706d 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -17,7 +17,8 @@ from django.db.models import QuerySet
 from django.http import StreamingHttpResponse
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage
-from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
+from langchain.schema.messages import HumanMessage, AIMessage
+from langchain_core.messages import AIMessageChunk
 
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
@@ -47,7 +48,8 @@ def event_content(response,
                   message_list: List[BaseMessage],
                   problem_text: str,
                   padding_problem_text: str = None,
-                  client_id=None, client_type=None):
+                  client_id=None, client_type=None,
+                  is_ai_chat: bool = None):
     all_text = ''
     try:
         for chunk in response:
@@ -56,8 +58,12 @@ def event_content(response,
                                          'content': chunk.content, 'is_end': False}) + "\n\n"
         # 获取token
-        request_token = chat_model.get_num_tokens_from_messages(message_list)
-        response_token = chat_model.get_num_tokens(all_text)
+        if is_ai_chat:
+            request_token = chat_model.get_num_tokens_from_messages(message_list)
+            response_token = chat_model.get_num_tokens(all_text)
+        else:
+            request_token = 0
+            response_token = 0
         step.context['message_tokens'] = request_token
         step.context['answer_tokens'] = response_token
         current_time = time.time()
@@ -88,15 +94,16 @@ class BaseChatStep(IChatStep):
                 padding_problem_text: str = None,
                 stream: bool = True,
                 client_id=None, client_type=None,
+                no_references_setting=None,
                 **kwargs):
         if stream:
             return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                        paragraph_list,
-                                       manage, padding_problem_text, client_id, client_type)
+                                       manage, padding_problem_text, client_id, client_type, no_references_setting)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
-                                      manage, padding_problem_text, client_id, client_type)
+                                      manage, padding_problem_text, client_id, client_type, no_references_setting)
 
     def get_details(self, manage, **kwargs):
         return {
@@ -127,19 +134,26 @@ class BaseChatStep(IChatStep):
                        paragraph_list=None,
                        manage: PipelineManage = None,
                        padding_problem_text: str = None,
-                       client_id=None, client_type=None):
+                       client_id=None, client_type=None,
+                       no_references_setting=None):
+        is_ai_chat = False
         # 调用模型
         if chat_model is None:
             chat_result = iter(
-                [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
+                [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
         else:
-            chat_result = chat_model.stream(message_list)
+            if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
+                    'status') == 'designated_answer':
+                chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
+            else:
+                chat_result = chat_model.stream(message_list)
+                is_ai_chat = True
         chat_record_id = uuid.uuid1()
         r = StreamingHttpResponse(
             streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
                                             post_response_handler, manage, self, chat_model, message_list, problem_text,
-                                            padding_problem_text, client_id, client_type),
+                                            padding_problem_text, client_id, client_type, is_ai_chat),
             content_type='text/event-stream;charset=utf-8')
 
         r['Cache-Control'] = 'no-cache'
@@ -153,16 +167,26 @@ class BaseChatStep(IChatStep):
                       paragraph_list=None,
                       manage: PipelineManage = None,
                       padding_problem_text: str = None,
-                      client_id=None, client_type=None):
+                      client_id=None, client_type=None, no_references_setting=None):
+        is_ai_chat = False
         # 调用模型
         if chat_model is None:
             chat_result = AIMessage(
                 content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
         else:
-            chat_result = chat_model.invoke(message_list)
+            if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
+                    'status') == 'designated_answer':
+                chat_result = AIMessage(content=no_references_setting.get('value'))
+            else:
+                chat_result = chat_model.invoke(message_list)
+                is_ai_chat = True
         chat_record_id = uuid.uuid1()
-        request_token = chat_model.get_num_tokens_from_messages(message_list)
-        response_token = chat_model.get_num_tokens(chat_result.content)
+        if is_ai_chat:
+            request_token = chat_model.get_num_tokens_from_messages(message_list)
+            response_token = chat_model.get_num_tokens(chat_result.content)
+        else:
+            request_token = 0
+            response_token = 0
         self.context['message_tokens'] = request_token
         self.context['answer_tokens'] = response_token
         current_time = time.time()
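The key behavioral change in `base_chat_step.py`: when retrieval returns no paragraphs and the setting's status is `designated_answer`, the configured text is wrapped in a single `AIMessageChunk` (stream) or `AIMessage` (block) instead of calling the model, and `is_ai_chat` stays `False` so token counting is skipped. A runnable sketch of the streaming short-circuit (the setting values here are examples):

```python
from langchain_core.messages import AIMessageChunk

no_references_setting = {'status': 'designated_answer', 'value': '知识库中没有相关内容,请联系管理员。'}

# Same shape execute_stream builds: a one-chunk iterator that the SSE generator
# (event_content) can consume exactly like chat_model.stream(message_list).
chat_result = iter([AIMessageChunk(content=no_references_setting['value'])])

for chunk in chat_result:
    print(chunk.content)  # is_ai_chat is False here, so request/response tokens stay 0
```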
diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py
index 350151659..ca2d00e0b 100644
--- a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py
+++ b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py
@@ -15,9 +15,9 @@ from rest_framework import serializers
 from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from application.models import ChatRecord
+from application.serializers.application_serializers import NoReferencesSetting
 from common.field.common import InstanceField
 from common.util.field_message import ErrMessage
-from dataset.models import Paragraph
 
 
 class IGenerateHumanMessageStep(IBaseChatPipelineStep):
@@ -39,6 +39,8 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
         prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
         # 补齐问题
         padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题"))
+        # 未查询到引用分段
+        no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
 
     def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
         return self.InstanceSerializer
@@ -56,6 +58,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
                 max_paragraph_char_number: int,
                 prompt: str,
                 padding_problem_text: str = None,
+                no_references_setting=None,
                 **kwargs) -> List[BaseMessage]:
         """
 
@@ -67,6 +70,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
         :param prompt: 模板
         :param padding_problem_text 用户修改文本
         :param kwargs: 其他参数
+        :param no_references_setting: 无引用分段设置
         :return:
         """
         pass
diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
index 86971d03d..bff3aa2b2 100644
--- a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
+++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
@@ -6,7 +6,7 @@
 @date:2024/1/10 17:50
 @desc:
 """
-from typing import List
+from typing import List, Dict
 
 from langchain.schema import BaseMessage, HumanMessage
 
@@ -26,22 +26,31 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
                 max_paragraph_char_number: int,
                 prompt: str,
                 padding_problem_text: str = None,
+                no_references_setting=None,
                 **kwargs) -> List[BaseMessage]:
+        prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) \
+            else no_references_setting.get('value')
         exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
         start_index = len(history_chat_record) - dialogue_number
         history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
                            for index in
                            range(start_index if start_index > 0 else 0, len(history_chat_record))]
         return [*flat_map(history_message),
-                self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list)]
+                self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
+                                      no_references_setting)]
 
     @staticmethod
     def to_human_message(prompt: str,
                          problem: str,
                          max_paragraph_char_number: int,
-                         paragraph_list: List[ParagraphPipelineModel]):
+                         paragraph_list: List[ParagraphPipelineModel],
+                         no_references_setting: Dict):
         if paragraph_list is None or len(paragraph_list) == 0:
-            return HumanMessage(content=prompt.format(**{'data': "", 'question': problem}))
+            if no_references_setting.get('status') == 'ai_questioning':
+                return HumanMessage(
+                    content=no_references_setting.get('value').format(**{'question': problem}))
+            else:
+                return HumanMessage(content=prompt.format(**{'data': "", 'question': problem}))
         temp_data = ""
         data_list = []
         for p in paragraph_list:
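Editorial note on the hunk above: as originally written, the `prompt` override was gated on `status == 'designated_answer'`, which would replace the prompt with the no-references value even when paragraphs were retrieved; it is rewritten here to key off the paragraph count, matching the branching in `to_human_message`. The fallback itself is easiest to see with concrete strings: under `ai_questioning` the setting's `value` is itself a template over `{question}`, while under `designated_answer` the fixed answer is injected later, in the chat step. The templates below are illustrative:

```python
problem = '如何重置密码?'

# status == 'ai_questioning': value is a prompt template over {question}.
ai_value = '{question}'
print(ai_value.format(**{'question': problem}))           # -> 如何重置密码?

# status == 'designated_answer': the regular prompt is rendered with empty data;
# execute_stream/execute_block return the fixed answer without calling the model.
prompt = '已知信息:{data}\n问题:{question}'
print(prompt.format(**{'data': '', 'question': problem}))
```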
diff --git a/apps/application/models/application.py b/apps/application/models/application.py
index 101816d38..774e1bdc2 100644
--- a/apps/application/models/application.py
+++ b/apps/application/models/application.py
@@ -19,7 +19,11 @@ from users.models import User
 
 
 def get_dataset_setting_dict():
-    return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding'}
+    return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding',
+            'no_references_setting': {
+                'status': 'ai_questioning',
+                'value': '{question}'
+            }}
 
 
 def get_model_setting_dict():
diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py
index 30af7f56e..e60aaa320 100644
--- a/apps/application/serializers/application_serializers.py
+++ b/apps/application/serializers/application_serializers.py
@@ -73,6 +73,18 @@ class ApplicationSerializerModel(serializers.ModelSerializer):
         fields = "__all__"
 
 
+class NoReferencesChoices(models.TextChoices):
+    """无引用分段设置类型"""
+    ai_questioning = 'ai_questioning', 'ai回答'
+    designated_answer = 'designated_answer', '指定回答'
+
+
+class NoReferencesSetting(serializers.Serializer):
+    status = serializers.ChoiceField(required=True, choices=NoReferencesChoices.choices,
+                                     error_messages=ErrMessage.char("无引用状态"))
+    value = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
+
+
 class DatasetSettingSerializer(serializers.Serializer):
     top_n = serializers.FloatField(required=True, max_value=100, min_value=1,
                                    error_messages=ErrMessage.float("引用分段数"))
@@ -85,6 +97,8 @@ class DatasetSettingSerializer(serializers.Serializer):
                                                       message="类型只支持register|reset_password", code=500)
                                    ], error_messages=ErrMessage.char("检索模式"))
 
+    no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
+
 
 class ModelSettingSerializer(serializers.Serializer):
     prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
@@ -383,7 +397,9 @@ class ApplicationSerializer(serializers.Serializer):
             application['multiple_rounds_dialogue'] = True if application.get('dialogue_number') > 0 else False
             del application['dialogue_number']
             if 'dataset_setting' in application:
-                application['dataset_setting'] = {**application['dataset_setting'], 'search_mode': 'embedding'}
+                application['dataset_setting'] = {'search_mode': 'embedding', 'no_references_setting': {
+                    'status': 'ai_questioning',
+                    'value': '{question}'}, **application['dataset_setting']}
             return application
 
     def page(self, current_page: int, page_size: int, with_valid=True):
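Both `get_dataset_setting_dict` and the spread in the serializer hunk above exist to backfill rows saved before this change: the defaults are spread first, so a stored `no_references_setting` wins whenever it is present. A small sketch of that merge order (the sample rows are made up):

```python
DEFAULT_NO_REFERENCES = {'status': 'ai_questioning', 'value': '{question}'}


def with_defaults(dataset_setting: dict) -> dict:
    # Defaults first, stored values second: stored keys override the defaults.
    return {'search_mode': 'embedding',
            'no_references_setting': DEFAULT_NO_REFERENCES,
            **dataset_setting}


old_row = {'top_n': 3, 'similarity': 0.6}  # saved before the upgrade
new_row = {'top_n': 3, 'no_references_setting': {'status': 'designated_answer', 'value': '请联系管理员'}}

print(with_defaults(old_row)['no_references_setting']['status'])  # ai_questioning
print(with_defaults(new_row)['no_references_setting']['status'])  # designated_answer
```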
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index daf850e76..ddfea22af 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -78,7 +78,11 @@ class ChatInfo:
             'problem_optimization': self.application.problem_optimization,
             'stream': True,
             'search_mode': self.application.dataset_setting.get(
-                'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding'
+                'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
+            'no_references_setting': self.application.dataset_setting.get(
+                'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
+                'status': 'ai_questioning',
+                'value': '{question}'}
         }
diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py
index 542fd9f75..e92d67801 100644
--- a/apps/application/swagger_api/application_api.py
+++ b/apps/application/swagger_api/application_api.py
@@ -176,6 +176,18 @@ class ApplicationApi(ApiMixin):
                                                           description="最多引用字符数", default=3000),
                     'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title='检索模式',
                                                   description="embedding|keywords|blend", default='embedding'),
+                    'no_references_setting': openapi.Schema(type=openapi.TYPE_OBJECT, title='无引用分段设置',
+                                                            required=['status', 'value'],
+                                                            properties={
+                                                                'status': openapi.Schema(type=openapi.TYPE_STRING,
+                                                                                         title="状态",
+                                                                                         description="ai作答:ai_questioning,指定回答:designated_answer",
+                                                                                         default='ai_questioning'),
+                                                                'value': openapi.Schema(type=openapi.TYPE_STRING,
+                                                                                        title="值",
+                                                                                        description="ai作答:就是提示词,指定回答:就是指定回答内容",
+                                                                                        default='{question}'),
+                                                            }),
                 }
             )
diff --git a/ui/src/styles/element-plus.scss b/ui/src/styles/element-plus.scss
index fc45acca0..b85ad4f77 100644
--- a/ui/src/styles/element-plus.scss
+++ b/ui/src/styles/element-plus.scss
@@ -335,3 +335,19 @@
 .auto-tooltip-popper {
   max-width: 500px;
 }
+
+// radio 一行一个样式
+.radio-block {
+  width: 100%;
+  display: block;
+  .el-radio {
+    align-items: flex-start;
+    height: 100%;
+    width: 100%;
+  }
+  .el-radio__label {
+    width: 100%;
+    margin-top: -8px;
+    line-height: 30px;
+  }
+}
diff --git a/ui/src/views/application-overview/component/EditAvatarDialog.vue b/ui/src/views/application-overview/component/EditAvatarDialog.vue
index 53b271542..35f78a802 100644
--- a/ui/src/views/application-overview/component/EditAvatarDialog.vue
+++ b/ui/src/views/application-overview/component/EditAvatarDialog.vue
@@ -1,6 +1,6 @@
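Putting the API surface together, a `dataset_setting` payload accepted after this change looks roughly like the following; the field names come from the serializer and swagger schema above, while the concrete values are illustrative:

```python
dataset_setting = {
    'top_n': 3,
    'similarity': 0.6,
    'max_paragraph_char_number': 5000,
    'search_mode': 'embedding',          # embedding | keywords | blend
    'no_references_setting': {
        'status': 'designated_answer',   # or 'ai_questioning'
        'value': '抱歉,知识库中暂时没有相关内容。',
    },
}
```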