diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py index fe8068119..e12fd082d 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -28,7 +28,8 @@ class IResetProblemStep(IBaseChatPipelineStep): history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), error_messages=ErrMessage.list("历史对答")) # 大语言模型 - chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型")) + model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("模型id")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) problem_optimization_prompt = serializers.CharField(required=False, max_length=102400, error_messages=ErrMessage.char("问题补全提示词")) @@ -48,7 +49,8 @@ class IResetProblemStep(IBaseChatPipelineStep): manage.context['answer_tokens'] = manage.context['answer_tokens'] + self.context.get('answer_tokens') @abstractmethod - def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, + def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, problem_optimization_prompt=None, + user_id=None, **kwargs): pass diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py index 3a32bbf02..2d631e076 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -8,25 +8,23 @@ """ from typing import List -from langchain.chat_models.base import BaseChatModel from langchain.schema import HumanMessage from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep from application.models import ChatRecord from common.util.split_model import flat_map +from setting.models_provider.tools import get_model_instance_by_model_user_id prompt = ( '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中') class BaseResetProblemStep(IResetProblemStep): - def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, + def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, problem_optimization_prompt=None, + user_id=None, **kwargs) -> str: - if chat_model is None: - self.context['message_tokens'] = 0 - self.context['answer_tokens'] = 0 - return problem_text + chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None start_index = len(history_chat_record) - 3 history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] for index in