fix: 修复工作流编排历史数据模型参数必填

This commit is contained in:
shaohuzhang1 2024-08-29 14:06:36 +08:00 committed by shaohuzhang1
parent c2622e4a5d
commit b627b6638c
4 changed files with 28 additions and 3 deletions

View File

@ -24,7 +24,7 @@ class ChatNodeSerializer(serializers.Serializer):
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置"))
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置"))
class IChatNode(INode):

View File

@ -10,11 +10,14 @@ import time
from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from setting.models import Model
from setting.models_provider import get_model_credential
from setting.models_provider.tools import get_model_instance_by_model_user_id
@ -61,11 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
_write_context(node_variable, workflow_variable, node, workflow, answer)
def get_default_model_params_setting(model_id):
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(
model.model_name).get_default_form_data()
return model_params_setting
class BaseChatNode(IChatNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting,
**kwargs) -> NodeResult:
if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)

View File

@ -23,7 +23,7 @@ class QuestionNodeSerializer(serializers.Serializer):
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置"))
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置"))
class IQuestionNode(INode):

View File

@ -10,11 +10,14 @@ import time
from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode
from setting.models import Model
from setting.models_provider import get_model_credential
from setting.models_provider.tools import get_model_instance_by_model_user_id
@ -61,10 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
_write_context(node_variable, workflow_variable, node, workflow, answer)
def get_default_model_params_setting(model_id):
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(
model.model_name).get_default_form_data()
return model_params_setting
class BaseQuestionNode(IQuestionNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting,
**kwargs) -> NodeResult:
if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)