From b627b6638ca0055d20b1c31225864bbb6f685132 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 29 Aug 2024 14:06:36 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E6=B5=81=E7=BC=96=E6=8E=92=E5=8E=86=E5=8F=B2=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=8F=82=E6=95=B0=E5=BF=85=E5=A1=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../step_node/ai_chat_step_node/i_chat_node.py | 2 +- .../ai_chat_step_node/impl/base_chat_node.py | 14 +++++++++++++- .../step_node/question_node/i_question_node.py | 2 +- .../question_node/impl/base_question_node.py | 13 +++++++++++++ 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index be89f35ef..15117c0b7 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -24,7 +24,7 @@ class ChatNodeSerializer(serializers.Serializer): is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) - model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置")) + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置")) class IChatNode(INode): diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 3581b5cfd..f8ec8a011 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -10,11 +10,14 @@ import time from functools import reduce from typing import List, Dict +from django.db.models import QuerySet from langchain.schema import HumanMessage, SystemMessage from langchain_core.messages import BaseMessage from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode +from setting.models import Model +from setting.models_provider import get_model_credential from setting.models_provider.tools import get_model_instance_by_model_user_id @@ -61,11 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor _write_context(node_variable, workflow_variable, node, workflow, answer) +def get_default_model_params_setting(model_id): + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = credential.get_model_params_setting_form( + model.model_name).get_default_form_data() + return model_params_setting + + class BaseChatNode(IChatNode): def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting, **kwargs) -> NodeResult: - + if model_params_setting is None: + model_params_setting = get_default_model_params_setting(model_id) chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number) diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py index 8fda34a5d..9b8c12562 100644 --- a/apps/application/flow/step_node/question_node/i_question_node.py +++ b/apps/application/flow/step_node/question_node/i_question_node.py @@ -23,7 +23,7 @@ class QuestionNodeSerializer(serializers.Serializer): dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) - model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置")) + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置")) class IQuestionNode(INode): diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index 397a0bfc1..33e1c0fe4 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -10,11 +10,14 @@ import time from functools import reduce from typing import List, Dict +from django.db.models import QuerySet from langchain.schema import HumanMessage, SystemMessage from langchain_core.messages import BaseMessage from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.question_node.i_question_node import IQuestionNode +from setting.models import Model +from setting.models_provider import get_model_credential from setting.models_provider.tools import get_model_instance_by_model_user_id @@ -61,10 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor _write_context(node_variable, workflow_variable, node, workflow, answer) +def get_default_model_params_setting(model_id): + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = credential.get_model_params_setting_form( + model.model_name).get_default_form_data() + return model_params_setting + + class BaseQuestionNode(IQuestionNode): def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting, **kwargs) -> NodeResult: + if model_params_setting is None: + model_params_setting = get_default_model_params_setting(model_id) chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number)