diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index a2fc2af3b..958eb59b1 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -73,9 +73,8 @@ class IChatStep(IBaseChatPipelineStep): no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) - temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度")) - max_tokens = serializers.IntegerField(required=False, allow_null=True, - error_messages=ErrMessage.integer("最大token数")) + + model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.dict("模型参数设置")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -100,5 +99,5 @@ class IChatStep(IBaseChatPipelineStep): paragraph_list=None, manage: PipelineManage = None, padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, - no_references_setting=None, **kwargs): + no_references_setting=None, model_params_setting=None, **kwargs): pass diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index b68577409..a4bced004 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -110,8 +110,10 @@ class BaseChatStep(IChatStep): stream: bool = True, client_id=None, client_type=None, no_references_setting=None, + model_params_setting=None, **kwargs): - chat_model = get_model_instance_by_model_user_id(model_id, user_id, **kwargs) if model_id is not None else None + chat_model = get_model_instance_by_model_user_id(model_id, user_id, + 
**model_params_setting) if model_id is not None else None if stream: return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, paragraph_list, diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index 60c70268d..be89f35ef 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -23,9 +23,8 @@ class ChatNodeSerializer(serializers.Serializer): dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) - temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度")) - max_tokens = serializers.IntegerField(required=False, allow_null=True, - error_messages=ErrMessage.integer("最大token数")) + + model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.dict("模型参数相关设置")) class IChatNode(INode): @@ -39,5 +38,6 @@ class IChatNode(INode): def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + model_params_setting, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index d611df32e..6eac5004b 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -65,12 +65,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor class BaseChatNode(IChatNode): def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + model_params_setting, **kwargs) -> 
NodeResult: - kwargs = {k: v for k, v in kwargs.items() if k in ['temperature', 'max_tokens']} chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), - **kwargs - ) + **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py index 30790c7c6..8fda34a5d 100644 --- a/apps/application/flow/step_node/question_node/i_question_node.py +++ b/apps/application/flow/step_node/question_node/i_question_node.py @@ -23,6 +23,7 @@ class QuestionNodeSerializer(serializers.Serializer): dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.dict("模型参数相关设置")) class IQuestionNode(INode): @@ -35,5 +36,6 @@ class IQuestionNode(INode): return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + model_params_setting, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index 3826e59e5..de40d3a8c 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -65,8 +65,10 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor class BaseQuestionNode(IQuestionNode): def execute(self, model_id, system, prompt, 
dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + model_params_setting, **kwargs) -> NodeResult: - chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) + chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), + **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 7e38d2192..b48c2853e 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -10,6 +10,7 @@ import json from functools import reduce from typing import List, Dict +from django.db.models import QuerySet from langchain_core.messages import AIMessage from langchain_core.prompts import PromptTemplate @@ -17,6 +18,8 @@ from application.flow import tools from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult from application.flow.step_node import get_node from common.exception.app_exception import AppApiException +from setting.models import Model +from setting.models_provider import get_model_credential class Edge: @@ -68,6 +71,7 @@ class Flow: """ 校验工作流数据 """ + self.is_valid_model_params() self.is_valid_start_node() self.is_valid_base_node() self.is_valid_work_flow() @@ -123,6 +127,19 @@ class Flow: if len(start_node_list) > 1: raise AppApiException(500, '开始节点只能有一个') + def is_valid_model_params(self): + node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')] + for node in node_list: + model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = 
node.properties.get('node_data', {}).get('model_params_setting') + model_params_setting_form = credential.get_model_params_setting_form( + model.model_name) + if model_params_setting is None: + model_params_setting = model_params_setting_form.get_default_form_data() + node.properties.get('node_data', {})['model_params_setting'] = model_params_setting + model_params_setting_form.valid_form(model_params_setting) + def is_valid_base_node(self): base_node_list = [node for node in self.nodes if node.id == 'base-node'] if len(base_node_list) == 0: diff --git a/apps/application/migrations/0011_application_model_params_setting.py b/apps/application/migrations/0011_application_model_params_setting.py new file mode 100644 index 000000000..440b94df5 --- /dev/null +++ b/apps/application/migrations/0011_application_model_params_setting.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-08-23 14:17 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0010_alter_chatrecord_details'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='model_params_setting', + field=models.JSONField(default=dict, verbose_name='模型参数相关设置'), + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index bdd0a672e..837aadbad 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -48,6 +48,7 @@ class Application(AppModelMixin): model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True) dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict) model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict) + model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default=dict) problem_optimization = models.BooleanField(verbose_name="问题优化", default=False) icon = models.CharField(max_length=256, 
verbose_name="应用icon", default="/ui/favicon.ico") work_flow = models.JSONField(verbose_name="工作流数据", default=dict) diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index bef20bd64..47e8b95e4 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -44,6 +44,7 @@ from dataset.serializers.common_serializers import list_paragraph, get_embedding from embedding.models import SearchMode from setting.models import AuthOperate from setting.models.model_management import Model +from setting.models_provider import get_model_credential from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR @@ -93,6 +94,14 @@ class NoReferencesSetting(serializers.Serializer): value = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词")) +def valid_model_params_setting(model_id, model_params_setting): + if model_id is None: + return + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + credential.get_model_params_setting_form(model.model_name).valid_form(model_params_setting) + + class DatasetSettingSerializer(serializers.Serializer): top_n = serializers.FloatField(required=True, max_value=100, min_value=1, error_messages=ErrMessage.float("引用分段数")) @@ -110,10 +119,6 @@ class DatasetSettingSerializer(serializers.Serializer): class ModelSettingSerializer(serializers.Serializer): prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词")) - temperature = serializers.FloatField(required=False, allow_null=True, - error_messages=ErrMessage.char("温度")) - max_tokens = serializers.IntegerField(required=False, allow_null=True, - 
error_messages=ErrMessage.integer("最大token数")) class ApplicationWorkflowSerializer(serializers.Serializer): @@ -185,6 +190,7 @@ class ApplicationSerializer(serializers.Serializer): message="应用类型只支持SIMPLE|WORK_FLOW", code=500) ] ) + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.dict('模型参数')) def is_valid(self, *, user_id=None, raise_exception=False): super().is_valid(raise_exception=True) @@ -355,6 +361,8 @@ class ApplicationSerializer(serializers.Serializer): error_messages=ErrMessage.boolean("问题补全")) icon = serializers.CharField(required=False, allow_null=True, error_messages=ErrMessage.char("icon图标")) + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.dict('模型参数')) + class Create(serializers.Serializer): user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) @@ -667,7 +675,7 @@ class ApplicationSerializer(serializers.Serializer): ApplicationSerializer.Edit(data=instance).is_valid( raise_exception=True) application_id = self.data.get("application_id") - + valid_model_params_setting(instance.get('model_id'), instance.get('model_params_setting')) application = QuerySet(Application).get(id=application_id) if instance.get('model_id') is None or len(instance.get('model_id')) == 0: application.model_id = None diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index d526e3540..08c371158 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -33,6 +33,7 @@ from common.util.field_message import ErrMessage from common.util.split_model import flat_map from dataset.models import Paragraph, Document from setting.models import Model, Status +from setting.models_provider import get_model_credential chat_cache = caches['chat_cache'] @@ -63,6 +64,12 @@ class ChatInfo: def to_base_pipeline_manage_params(self): 
dataset_setting = self.application.dataset_setting model_setting = self.application.model_setting + model_id = self.application.model.id if self.application.model is not None else None + model_params_setting = None + if model_id is not None: + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = credential.get_model_params_setting_form(model.model_name).get_default_form_data() return { 'dataset_id_list': self.dataset_id_list, 'exclude_document_id_list': self.exclude_document_id_list, @@ -77,11 +84,11 @@ class ChatInfo: 'prompt': model_setting.get( 'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(), 'chat_model': self.chat_model, - 'model_id': self.application.model.id if self.application.model is not None else None, + 'model_id': model_id, 'problem_optimization': self.application.problem_optimization, 'stream': True, - 'temperature': model_setting.get('temperature') if 'temperature' in model_setting else None, - 'max_tokens': model_setting.get('max_tokens') if 'max_tokens' in model_setting else None, + 'model_params_setting': model_params_setting if self.application.model_params_setting is None or len( + self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting, 'search_mode': self.application.dataset_setting.get( 'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding', 'no_references_setting': self.application.dataset_setting.get( @@ -90,7 +97,6 @@ class ChatInfo: 'value': '{question}', }, 'user_id': self.application.user_id - } def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler, diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 061c8855b..aff1c32a9 100644 --- a/apps/application/serializers/chat_serializers.py +++ 
b/apps/application/serializers/chat_serializers.py @@ -33,17 +33,17 @@ from application.serializers.application_serializers import ModelDatasetAssociat from application.serializers.chat_message_serializers import ChatInfo from common.constants.permission_constants import RoleConstants from common.db.search import native_search, native_page_search, page_search, get_dynamics_model -from common.event import ListenerManagement from common.exception.app_exception import AppApiException from common.util.common import post from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping -from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id, \ - get_embedding_model_id_by_dataset_id +from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers from embedding.task import embedding_by_paragraph +from setting.models import Model +from setting.models_provider import get_model_credential from smartdoc.conf import PROJECT_DIR chat_cache = caches['chat_cache'] @@ -54,6 +54,17 @@ class WorkFlowSerializers(serializers.Serializer): edges = serializers.ListSerializer(child=serializers.DictField(), error_messages=ErrMessage.uuid("连线")) +def valid_model_params_setting(model_id, model_params_setting): + if model_id is None: + return + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting_form = credential.get_model_params_setting_form(model.model_name) + if model_params_setting is None or len(model_params_setting.keys()) == 0: + model_params_setting = model_params_setting_form.get_default_form_data() + 
credential.get_model_params_setting_form(model.model_name).valid_form(model_params_setting) + + class ChatSerializers(serializers.Serializer): class Operate(serializers.Serializer): chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) @@ -253,6 +264,7 @@ class ChatSerializers(serializers.Serializer): return chat_id def open_simple(self, application): + valid_model_params_setting(application.model_id, application.model_params_setting) application_id = self.data.get('application_id') dataset_id_list = [str(row.dataset_id) for row in QuerySet(ApplicationDatasetMapping).filter( @@ -309,6 +321,8 @@ class ChatSerializers(serializers.Serializer): model_setting = ModelSettingSerializer(required=True) # 问题补全 problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全")) + # 模型相关设置 + model_params_setting = serializers.JSONField(required=True) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -332,10 +346,12 @@ class ChatSerializers(serializers.Serializer): model_id = self.data.get('model_id') dataset_id_list = self.data.get('dataset_id_list') dialogue_number = 3 if self.data.get('multiple_rounds_dialogue', False) else 0 + valid_model_params_setting(model_id, self.data.get('model_params_setting')) application = Application(id=None, dialogue_number=dialogue_number, model_id=model_id, dataset_setting=self.data.get('dataset_setting'), model_setting=self.data.get('model_setting'), problem_optimization=self.data.get('problem_optimization'), + model_params_setting=self.data.get('model_params_setting'), user_id=user_id) chat_cache.set(chat_id, ChatInfo(chat_id, None, dataset_id_list, diff --git a/apps/application/urls.py b/apps/application/urls.py index 1409452de..335205d37 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -19,10 +19,6 @@ urlpatterns = [ path('application//statistics/chat_record_aggregate_trend', 
views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()), path('application//model', views.Application.Model.as_view()), - path('application//model/', views.Application.Model.Operate.as_view()), - path('application//model//', - views.Application.Model.Operate.as_view()), - path('application//other-config', views.Application.Model.OtherConfig.as_view()), path('application//hit_test', views.Application.HitTest.as_view()), path('application//api_key', views.Application.ApplicationKey.as_view()), path("application//api_key/", diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index a468d0b65..e6fe7191e 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -187,30 +187,6 @@ class Application(APIView): data={'application_id': application_id, 'user_id': request.user.id}).list_model(request.query_params.get('model_type'))) - class Operate(APIView): - authentication_classes = [TokenAuth] - - @swagger_auto_schema(operation_summary="获取应用参数设置其他字段", - operation_id="获取应用参数设置其他字段", - tags=["应用/会话"]) - def get(self, request: Request, application_id: str, model_id: str, ai_node_id=None): - return result.success( - ApplicationSerializer.Operate( - data={'application_id': application_id, 'model_id': model_id, - 'ai_node_id': ai_node_id}).get_other_file_list()) - - class OtherConfig(APIView): - authentication_classes = [TokenAuth] - - @swagger_auto_schema(operation_summary="获取应用参数设置其他字段", - operation_id="获取应用参数设置其他字段", - tags=["应用/会话"]) - def put(self, request: Request, application_id: str): - return result.success( - ApplicationSerializer.Operate( - data={'application_id': application_id}).save_other_config( - request.data)) - class Profile(APIView): authentication_classes = [TokenAuth] diff --git a/apps/common/forms/__init__.py b/apps/common/forms/__init__.py index cda6fe040..609542193 100644 --- a/apps/common/forms/__init__.py +++ b/apps/common/forms/__init__.py @@ 
-20,3 +20,5 @@ from .text_input_field import * from .radio_button_field import * from .table_checkbox import * from .radio_card_field import * +from .label import * +from .slider_field import * diff --git a/apps/common/forms/base_field.py b/apps/common/forms/base_field.py index d12ae77a7..dedd78d83 100644 --- a/apps/common/forms/base_field.py +++ b/apps/common/forms/base_field.py @@ -9,6 +9,9 @@ from enum import Enum from typing import List, Dict +from common.exception.app_exception import AppApiException +from common.forms.label.base_label import BaseLabel + class TriggerType(Enum): # 执行函数获取 OptionList数据 @@ -20,7 +23,7 @@ class TriggerType(Enum): class BaseField: def __init__(self, input_type: str, - label: str, + label: str or BaseLabel, required: bool = False, default_value: object = None, relation_show_field_dict: Dict = None, @@ -53,10 +56,16 @@ class BaseField: self.required = required self.trigger_type = trigger_type - def to_dict(self): + def is_valid(self, value): + field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label + if self.required and value is None: + raise AppApiException(500, + f"{field_label} 为必填参数") + + def to_dict(self, **kwargs): return { 'input_type': self.input_type, - 'label': self.label, + 'label': self.label.to_dict(**kwargs) if hasattr(self.label, 'to_dict') else self.label, 'required': self.required, 'default_value': self.default_value, 'relation_show_field_dict': self.relation_show_field_dict, @@ -64,6 +73,7 @@ class BaseField: 'trigger_type': self.trigger_type.value, 'attrs': self.attrs, 'props_info': self.props_info, + **kwargs } @@ -97,8 +107,8 @@ class BaseDefaultOptionField(BaseField): self.value_field = value_field self.option_list = option_list - def to_dict(self): - return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field, + def to_dict(self, **kwargs): + return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field, 
'option_list': self.option_list} @@ -141,6 +151,6 @@ class BaseExecField(BaseField): self.provider = provider self.method = method - def to_dict(self): - return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field, + def to_dict(self, **kwargs): + return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field, 'provider': self.provider, 'method': self.method} diff --git a/apps/common/forms/base_form.py b/apps/common/forms/base_form.py index 93984b8c6..5ef92c5c1 100644 --- a/apps/common/forms/base_form.py +++ b/apps/common/forms/base_form.py @@ -6,11 +6,25 @@ @date:2023/11/1 16:04 @desc: """ +from typing import Dict + from common.forms import BaseField class BaseForm: - def to_form_list(self): - return [{**self.__getattribute__(key).to_dict(), 'field': key} for key in + def to_form_list(self, **kwargs): + return [{**self.__getattribute__(key).to_dict(**kwargs), 'field': key} for key in list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField), [attr for attr in vars(self.__class__) if not attr.startswith("__")]))] + + def valid_form(self, form_data): + field_keys = list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField), + [attr for attr in vars(self.__class__) if not attr.startswith("__")])) + for field_key in field_keys: + self.__getattribute__(field_key).is_valid(form_data.get(field_key)) + + def get_default_form_data(self): + return {key: self.__getattribute__(key).default_value for key in + [attr for attr in vars(self.__class__) if not attr.startswith("__")] if + isinstance(self.__getattribute__(key), BaseField) and self.__getattribute__( + key).default_value is not None} diff --git a/apps/common/forms/label/__init__.py b/apps/common/forms/label/__init__.py new file mode 100644 index 000000000..81c1b3298 --- /dev/null +++ b/apps/common/forms/label/__init__.py @@ -0,0 +1,10 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + 
@date:2024/8/22 17:19 + @desc: +""" +from .base_label import * +from .tooltip_label import * diff --git a/apps/common/forms/label/base_label.py b/apps/common/forms/label/base_label.py new file mode 100644 index 000000000..59e4d3722 --- /dev/null +++ b/apps/common/forms/label/base_label.py @@ -0,0 +1,28 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_label.py + @date:2024/8/22 17:11 + @desc: +""" + + +class BaseLabel: + def __init__(self, + input_type: str, + label: str, + attrs=None, + props_info=None): + self.input_type = input_type + self.label = label + self.attrs = attrs + self.props_info = props_info + + def to_dict(self, **kwargs): + return { + 'input_type': self.input_type, + 'label': self.label, + 'attrs': {} if self.attrs is None else self.attrs, + 'props_info': {} if self.props_info is None else self.props_info, + } diff --git a/apps/common/forms/label/tooltip_label.py b/apps/common/forms/label/tooltip_label.py new file mode 100644 index 000000000..885345daf --- /dev/null +++ b/apps/common/forms/label/tooltip_label.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: tooltip_label.py + @date:2024/8/22 17:19 + @desc: +""" +from common.forms.label.base_label import BaseLabel + + +class TooltipLabel(BaseLabel): + def __init__(self, label, tooltip): + super().__init__('TooltipLabel', label, attrs={'tooltip': tooltip}, props_info={}) diff --git a/apps/common/forms/slider_field.py b/apps/common/forms/slider_field.py new file mode 100644 index 000000000..6bf3625d6 --- /dev/null +++ b/apps/common/forms/slider_field.py @@ -0,0 +1,58 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: slider_field.py + @date:2024/8/22 17:06 + @desc: +""" +from typing import Dict + +from common.exception.app_exception import AppApiException +from common.forms import BaseField, TriggerType, BaseLabel + + +class SliderField(BaseField): + """ + 滑块输入框 + """ + + def __init__(self, label: str or BaseLabel, + _min, + _max, + _step, 
+ precision, + required: bool = False, + default_value=None, + relation_show_field_dict: Dict = None, + attrs=None, props_info=None): + """ + @param label: 提示 + @param _min: 最小值 + @param _max: 最大值 + @param _step: 步长 + @param precision: 保留多少小数 + @param required: 是否必填 + @param default_value: 默认值 + @param relation_show_field_dict: + @param attrs: + @param props_info: + """ + _attrs = {'min': _min, 'max': _max, 'step': _step, + 'precision': precision, 'show-input-controls': False, 'show-input': True} + if attrs is not None: + _attrs.update(attrs) + super().__init__('Slider', label, required, default_value, relation_show_field_dict, + {}, + TriggerType.OPTION_LIST, _attrs, props_info) + + def is_valid(self, value): + super().is_valid(value) + field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label + if value is not None: + if value < self.attrs.get('min'): + raise AppApiException(500, + f"{field_label} 不能小于{self.attrs.get('min')}") + if value > self.attrs.get('max'): + raise AppApiException(500, + f"{field_label} 不能大于{self.attrs.get('max')}") diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index a6966247c..6b3313983 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -108,10 +108,10 @@ class BaseModelCredential(ABC): """ pass - def get_other_fields(self, model_name): + def get_model_params_setting_form(self, model_name): """ - 获取其他字段 - :return: + 模型参数设置表单 + :return: """ pass diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py index 08829834b..2392d3d04 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py @@ -1,11 +1,30 @@ import os import re from typing import 
Dict -from common.exception.app_exception import AppApiException -from common.forms import BaseForm -from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential + from langchain_core.messages import HumanMessage + from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential + + +class BedrockLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=1024, + _min=1, + _max=4096, + _step=1, + precision=0) class BedrockLLMModelCredential(BaseForm, BaseModelCredential): @@ -61,24 +80,5 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential): access_key_id = forms.TextInputField('Access Key ID', required=True) secret_access_key = forms.PasswordInputField('Secret Access Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.7, - 'min': 0.1, - 'max': 1, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 1024, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return BedrockLLMModelParams() diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py index b20c63368..1906d20c5 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py @@ -12,10 +12,27 
@@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class AzureLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=800, + _min=1, + _max=4096, + _step=1, + precision=0) + + class AzureLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, @@ -54,24 +71,5 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.7, - 'min': 0.1, - 'max': 1, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 800, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return AzureLLMModelParams() diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py index d4bea15ea..72f101c42 100644 --- a/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py @@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import 
AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class DeepSeekLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=800, + _min=1, + _max=4096, + _step=1, + precision=0) + + class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, @@ -47,24 +64,5 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.7, - 'min': 0.1, - 'max': 1, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 800, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return DeepSeekLLMModelParams() diff --git a/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py index 1c48f2d55..77e01df39 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py @@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from 
setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class GeminiLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=800, + _min=1, + _max=4096, + _step=1, + precision=0) + + class GeminiLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, @@ -48,24 +65,5 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.7, - 'min': 0.1, - 'max': 1.0, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 800, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return GeminiLLMModelParams() diff --git a/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py index 13245af36..ef2aeb122 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py @@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class KimiLLMModelParams(BaseForm): + temperature = 
forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.3, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=1024, + _min=1, + _max=4096, + _step=1, + precision=0) + + class KimiLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, @@ -48,24 +65,5 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.3, - 'min': 0.1, - 'max': 1.0, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 1024, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return KimiLLMModelParams() diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py index 136f7d2a9..14634b478 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py @@ -10,10 +10,27 @@ from typing import Dict from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class OllamaLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.3, + _min=0.1, + 
_max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=1024, + _min=1, + _max=4096, + _step=1, + precision=0) + + class OllamaLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, raise_exception=False): @@ -43,24 +60,5 @@ class OllamaLLMModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.3, - 'min': 0.1, - 'max': 1.0, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 1024, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return OllamaLLMModelParams() diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py index 42a713a3f..f7d244a53 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py @@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class OpenAILLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = 
forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=800, + _min=1, + _max=4096, + _step=1, + precision=0) + + class OpenAILLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, @@ -48,24 +65,5 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.7, - 'min': 0.1, - 'max': 1.0, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 800, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return OpenAILLMModelParams() diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py index d8c9b85e0..7ad068454 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py @@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class QwenModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=1.0, + _min=0.1, + _max=1.9, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, 
default_value=800, + _min=1, + _max=2048, + _step=1, + precision=0) + + class OpenAILLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, @@ -46,24 +63,5 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 1.0, - 'min': 0.1, - 'max': 1.9, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 800, - 'min': 1, - 'max': 2048, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return QwenModelParams() diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py index 06c892675..20b1bf824 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py @@ -1,11 +1,21 @@ # coding=utf-8 from langchain_core.messages import HumanMessage + from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class TencentLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.5, + _min=0.1, + _max=2.0, + _step=0.01, + precision=2) + + class TencentLLMModelCredential(BaseForm, BaseModelCredential): REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key'] @@ -46,15 +56,5 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential): hunyuan_secret_id = 
forms.PasswordInputField('SecretId', required=True) hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.5, - 'min': 0.1, - 'max': 2.0, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - } + def get_model_params_setting_form(self, model_name): + return TencentLLMModelParams() diff --git a/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py index 87ef597c8..7ea3cabe4 100644 --- a/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py @@ -6,10 +6,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class VLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=800, + _min=1, + _max=4096, + _step=1, + precision=0) + + class VLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, raise_exception=False): @@ -44,24 +61,5 @@ class VLLMModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.7, - 'min': 0.1, - 'max': 1, - 
'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 800, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return VLLMModelParams() diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py index 025c43409..f918b437d 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py @@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class VolcanicEngineLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.3, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=1024, + _min=1, + _max=4096, + _step=1, + precision=0) + + class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, @@ -49,24 +66,5 @@ class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential): access_key_id = forms.PasswordInputField('Access Key ID', required=True) secret_access_key = forms.PasswordInputField('Secret Access Key', required=True) - def get_other_fields(self): - return { - 'temperature': { - 'value': 0.3, - 'min': 0.1, - 'max': 1.0, - 'step': 
0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 1024, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return VolcanicEngineLLMModelParams() diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py index 81ad4e404..be294e81a 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class WenxinLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.95, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=1024, + _min=2, + _max=2048, + _step=1, + precision=0) + + class WenxinLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, raise_exception=False): @@ -54,24 +71,5 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential): secret_key = forms.PasswordInputField("Secret Key", required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.95, - 'min': 0.1, - 'max': 1.0, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': 
{ - 'value': 1024, - 'min': 2, - 'max': 2048, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return WenxinLLMModelParams() diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py index 706aca2d6..770aff27d 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py @@ -12,10 +12,44 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class XunFeiLLMModelGeneralParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.5, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=4096, + _min=1, + _max=4096, + _step=1, + precision=0) + + +class XunFeiLLMModelProParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.5, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=4096, + _min=1, + _max=8192, + _step=1, + precision=0) + + class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, @@ -50,27 +84,7 @@ class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): spark_api_key = 
forms.PasswordInputField("API Key", required=True) spark_api_secret = forms.PasswordInputField('API Secret', required=True) - def get_other_fields(self, model_name): - max_value = 8192 + def get_model_params_setting_form(self, model_name): if model_name == 'general' or model_name == 'pro-128k': - max_value = 4096 - return { - 'temperature': { - 'value': 0.5, - 'min': 0.1, - 'max': 1, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 4096, - 'min': 1, - 'max': max_value, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + return XunFeiLLMModelGeneralParams() + return XunFeiLLMModelProParams() diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py index a5425ec81..bb17e5c22 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py @@ -6,10 +6,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class XinferenceLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=800, + _min=1, + _max=4096, + _step=1, + precision=0) + + class XinferenceLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, 
raise_exception=False): @@ -40,24 +57,5 @@ class XinferenceLLMModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.7, - 'min': 0.1, - 'max': 1, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 800, - 'min': 1, - 'max': 4096, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } + def get_model_params_setting_form(self, model_name): + return XinferenceLLMModelParams() diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py index 398e032fc..b0fb4d16e 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py @@ -1,7 +1,8 @@ # coding=utf-8 -from typing import List, Dict +from typing import Dict from urllib.parse import urlparse, ParseResult + from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py index b2c16dcd7..aee7441f1 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py @@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, 
ValidCode +class ZhiPuLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.95, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=1024, + _min=1, + _max=4096, + _step=1, + precision=0) + + class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, @@ -46,25 +63,5 @@ class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) - def get_other_fields(self, model_name): - return { - 'temperature': { - 'value': 0.95, - 'min': 0.1, - 'max': 1.0, - 'step': 0.01, - 'label': '温度', - 'precision': 2, - 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' - }, - 'max_tokens': { - 'value': 1024, - 'min': 1, - 'max': 4095, - 'step': 1, - 'label': '输出最大Tokens', - 'precision': 0, - 'tooltip': '指定模型可生成的最大token个数' - } - } - + def get_model_params_setting_form(self, model_name): + return ZhiPuLLMModelParams() diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index eaab10f77..e032a0dcb 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -24,6 +24,7 @@ from common.util.field_message import ErrMessage from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt from dataset.models import DataSet from setting.models.model_management import Model, Status +from setting.models_provider import get_model, get_model_credential from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus from setting.models_provider.constants.model_provider_constants import ModelProvideConstants @@ -234,6 +235,14 @@ class ModelSerializer(serializers.Serializer): 'meta': 
model.meta } + def get_model_params(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + model_id = self.data.get('id') + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + return credential.get_model_params_setting_form(model.model_name).to_form_list() + def delete(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) diff --git a/apps/setting/urls.py b/apps/setting/urls.py index 865ce4088..557edcceb 100644 --- a/apps/setting/urls.py +++ b/apps/setting/urls.py @@ -17,6 +17,7 @@ urlpatterns = [ path('provider/model_form', views.Provide.ModelForm.as_view(), name="provider/model_form"), path('model', views.Model.as_view(), name='model'), + path('model//model_params_form', views.Model.ModelParamsForm.as_view(), name='model/model_params_form'), path('model/', views.Model.Operate.as_view(), name='model/operate'), path('model//pause_download', views.Model.PauseDownload.as_view(), name='model/operate'), path('model//meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'), diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index 9108aa15a..2c0ba3bca 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -81,6 +81,19 @@ class Model(APIView): return result.success( ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download()) + class ModelParamsForm(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取模型参数表单", + operation_id="获取模型参数表单", + manual_parameters=ProvideApi.ModelForm.get_request_params_api(), + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).get_model_params()) + class Operate(APIView): 
authentication_classes = [TokenAuth] diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index cb6858968..06e8d1313 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -249,41 +249,6 @@ const putPublishApplication: ( return put(`${prefix}/${application_id}/publish`, data, undefined, loading) } -/** - * 获取模型其他配置 - * @param application_id - * @param model_id - * @param loading - * @returns - */ -const getModelOtherConfig: ( - application_id: string, - model_id: string, - ai_node_id?: string, - loading?: Ref -) => Promise> = (application_id, model_id, ai_node_id?, loading?) => { - if (ai_node_id) { - // 如果 ai_node_id 不为空,则调用带有 ai_node_id 的 API - return get(`${prefix}/${application_id}/model/${model_id}/${ai_node_id}`, undefined, loading) - } else { - // 如果 ai_node_id 为空,则调用不带 ai_node_id 的 API - return get(`${prefix}/${application_id}/model/${model_id}`, undefined, loading) - } -} - -/** - * 保存其他配置信息 - * @param application_id - * - */ -const putModelOtherConfig: ( - application_id: string, - data: any, - loading?: Ref -) => Promise> = (application_id, data, loading) => { - return put(`${prefix}/${application_id}/other-config`, data, undefined, loading) -} - export default { getAllAppilcation, getApplication, @@ -303,7 +268,5 @@ export default { getApplicationHitTest, getApplicationModel, putPublishApplication, - postWorkflowChatOpen, - getModelOtherConfig, - putModelOtherConfig + postWorkflowChatOpen } diff --git a/ui/src/api/model.ts b/ui/src/api/model.ts index 15f6517ff..bcb31f9c4 100644 --- a/ui/src/api/model.ts +++ b/ui/src/api/model.ts @@ -51,6 +51,18 @@ const getModelCreateForm: ( return get(`${prefix_provider}/model_form`, { provider, model_type, model_name }, loading) } +/** + * 获取模型参数表单 + * @param model_id 模型id + * @param loading + * @returns + */ +const getModelParamsForm: ( + model_id: string, + loading?: Ref +) => Promise>> = (model_id, loading) => { + return get(`model/${model_id}/model_params_form`, {}, loading) +} /** 
* 获取模型类型列表 * @param provider 供应商 @@ -159,5 +171,6 @@ export default { deleteModel, getModelById, getModelMetaById, - pauseDownload + pauseDownload, + getModelParamsForm } diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index 4e4c4564e..944f9179d 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -13,6 +13,7 @@ interface ApplicationFormType { icon?: string | undefined type?: string work_flow?: any + model_params_setting?: any } interface chatType { id: string diff --git a/ui/src/components/dynamics-form/Demo.vue b/ui/src/components/dynamics-form/Demo.vue index 74a7b54bb..9fad393a4 100644 --- a/ui/src/components/dynamics-form/Demo.vue +++ b/ui/src/components/dynamics-form/Demo.vue @@ -1,6 +1,11 @@