fix: 修复建议应用问题优化无效

This commit is contained in:
shaohuzhang1 2024-09-25 16:19:03 +08:00 committed by shaohuzhang1
parent e16e827028
commit 6e366c70b5
2 changed files with 8 additions and 8 deletions

View File

@ -28,7 +28,8 @@ class IResetProblemStep(IBaseChatPipelineStep):
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list("历史对答"))
# 大语言模型
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("模型id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
error_messages=ErrMessage.char("问题补全提示词"))
@ -48,7 +49,8 @@ class IResetProblemStep(IBaseChatPipelineStep):
manage.context['answer_tokens'] = manage.context['answer_tokens'] + self.context.get('answer_tokens')
@abstractmethod
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
problem_optimization_prompt=None,
user_id=None,
**kwargs):
pass

View File

@ -8,25 +8,23 @@
"""
from typing import List
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
from application.models import ChatRecord
from common.util.split_model import flat_map
from setting.models_provider.tools import get_model_instance_by_model_user_id
prompt = (
'()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中')
class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
problem_optimization_prompt=None,
user_id=None,
**kwargs) -> str:
if chat_model is None:
self.context['message_tokens'] = 0
self.context['answer_tokens'] = 0
return problem_text
chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
start_index = len(history_chat_record) - 3
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in