refactor: 重构模型参数代码

This commit is contained in:
shaohuzhang1 2024-08-23 17:46:05 +08:00 committed by shaohuzhang1
parent 228f913d9b
commit 9e0ac81f1d
57 changed files with 743 additions and 581 deletions

View File

@ -73,9 +73,8 @@ class IChatStep(IBaseChatPipelineStep):
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度"))
max_tokens = serializers.IntegerField(required=False, allow_null=True,
error_messages=ErrMessage.integer("最大token数"))
model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.dict("模型参数设置"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -100,5 +99,5 @@ class IChatStep(IBaseChatPipelineStep):
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
no_references_setting=None, **kwargs):
no_references_setting=None, model_params_setting=None, **kwargs):
pass

View File

@ -110,8 +110,10 @@ class BaseChatStep(IChatStep):
stream: bool = True,
client_id=None, client_type=None,
no_references_setting=None,
model_params_setting=None,
**kwargs):
chat_model = get_model_instance_by_model_user_id(model_id, user_id, **kwargs) if model_id is not None else None
chat_model = get_model_instance_by_model_user_id(model_id, user_id,
**model_params_setting) if model_id is not None else None
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,

View File

@ -23,9 +23,8 @@ class ChatNodeSerializer(serializers.Serializer):
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度"))
max_tokens = serializers.IntegerField(required=False, allow_null=True,
error_messages=ErrMessage.integer("最大token数"))
model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置"))
class IChatNode(INode):
@ -39,5 +38,6 @@ class IChatNode(INode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
chat_record_id,
model_params_setting,
**kwargs) -> NodeResult:
pass

View File

@ -65,12 +65,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
class BaseChatNode(IChatNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting,
**kwargs) -> NodeResult:
kwargs = {k: v for k, v in kwargs.items() if k in ['temperature', 'max_tokens']}
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**kwargs
)
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)

View File

@ -23,6 +23,7 @@ class QuestionNodeSerializer(serializers.Serializer):
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置"))
class IQuestionNode(INode):
@ -35,5 +36,6 @@ class IQuestionNode(INode):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting,
**kwargs) -> NodeResult:
pass

View File

@ -65,8 +65,10 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
class BaseQuestionNode(IQuestionNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting,
**kwargs) -> NodeResult:
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)

View File

@ -10,6 +10,7 @@ import json
from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
from langchain_core.messages import AIMessage
from langchain_core.prompts import PromptTemplate
@ -17,6 +18,8 @@ from application.flow import tools
from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
from application.flow.step_node import get_node
from common.exception.app_exception import AppApiException
from setting.models import Model
from setting.models_provider import get_model_credential
class Edge:
@ -68,6 +71,7 @@ class Flow:
"""
校验工作流数据
"""
self.is_valid_model_params()
self.is_valid_start_node()
self.is_valid_base_node()
self.is_valid_work_flow()
@ -123,6 +127,19 @@ class Flow:
if len(start_node_list) > 1:
raise AppApiException(500, '开始节点只能有一个')
def is_valid_model_params(self):
node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')]
for node in node_list:
model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = node.properties.get('node_data', {}).get('model_params_setting')
model_params_setting_form = credential.get_model_params_setting_form(
model.model_name)
if model_params_setting is None:
model_params_setting = model_params_setting_form.get_default_form_data()
node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
model_params_setting_form.valid_form(model_params_setting)
def is_valid_base_node(self):
base_node_list = [node for node in self.nodes if node.id == 'base-node']
if len(base_node_list) == 0:

View File

@ -0,0 +1,18 @@
# Generated by Django 4.2.15 on 2024-08-23 14:17
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0010_alter_chatrecord_details'),
]
operations = [
migrations.AddField(
model_name='application',
name='model_params_setting',
field=models.JSONField(default={}, verbose_name='模型参数相关设置'),
),
]

View File

@ -48,6 +48,7 @@ class Application(AppModelMixin):
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default={})
problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
icon = models.CharField(max_length=256, verbose_name="应用icon", default="/ui/favicon.ico")
work_flow = models.JSONField(verbose_name="工作流数据", default=dict)

View File

@ -44,6 +44,7 @@ from dataset.serializers.common_serializers import list_paragraph, get_embedding
from embedding.models import SearchMode
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider import get_model_credential
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
@ -93,6 +94,14 @@ class NoReferencesSetting(serializers.Serializer):
value = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
def valid_model_params_setting(model_id, model_params_setting):
if model_id is None:
return
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
credential.get_model_params_setting_form(model.model_name).valid_form(model_params_setting)
class DatasetSettingSerializer(serializers.Serializer):
top_n = serializers.FloatField(required=True, max_value=100, min_value=1,
error_messages=ErrMessage.float("引用分段数"))
@ -110,10 +119,6 @@ class DatasetSettingSerializer(serializers.Serializer):
class ModelSettingSerializer(serializers.Serializer):
prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
temperature = serializers.FloatField(required=False, allow_null=True,
error_messages=ErrMessage.char("温度"))
max_tokens = serializers.IntegerField(required=False, allow_null=True,
error_messages=ErrMessage.integer("最大token数"))
class ApplicationWorkflowSerializer(serializers.Serializer):
@ -185,6 +190,7 @@ class ApplicationSerializer(serializers.Serializer):
message="应用类型只支持SIMPLE|WORK_FLOW", code=500)
]
)
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.dict('模型参数'))
def is_valid(self, *, user_id=None, raise_exception=False):
super().is_valid(raise_exception=True)
@ -355,6 +361,8 @@ class ApplicationSerializer(serializers.Serializer):
error_messages=ErrMessage.boolean("问题补全"))
icon = serializers.CharField(required=False, allow_null=True, error_messages=ErrMessage.char("icon图标"))
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.dict('模型参数'))
class Create(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
@ -667,7 +675,7 @@ class ApplicationSerializer(serializers.Serializer):
ApplicationSerializer.Edit(data=instance).is_valid(
raise_exception=True)
application_id = self.data.get("application_id")
valid_model_params_setting(instance.get('model_id'), instance.get('model_params_setting'))
application = QuerySet(Application).get(id=application_id)
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
application.model_id = None

View File

@ -33,6 +33,7 @@ from common.util.field_message import ErrMessage
from common.util.split_model import flat_map
from dataset.models import Paragraph, Document
from setting.models import Model, Status
from setting.models_provider import get_model_credential
chat_cache = caches['chat_cache']
@ -63,6 +64,12 @@ class ChatInfo:
def to_base_pipeline_manage_params(self):
dataset_setting = self.application.dataset_setting
model_setting = self.application.model_setting
model_id = self.application.model.id if self.application.model is not None else None
model_params_setting = None
if model_id is not None:
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(model.model_name).get_default_form_data()
return {
'dataset_id_list': self.dataset_id_list,
'exclude_document_id_list': self.exclude_document_id_list,
@ -77,11 +84,11 @@ class ChatInfo:
'prompt': model_setting.get(
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
'chat_model': self.chat_model,
'model_id': self.application.model.id if self.application.model is not None else None,
'model_id': model_id,
'problem_optimization': self.application.problem_optimization,
'stream': True,
'temperature': model_setting.get('temperature') if 'temperature' in model_setting else None,
'max_tokens': model_setting.get('max_tokens') if 'max_tokens' in model_setting else None,
'model_params_setting': model_params_setting if self.application.model_params_setting is None or len(
self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
'search_mode': self.application.dataset_setting.get(
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
'no_references_setting': self.application.dataset_setting.get(
@ -90,7 +97,6 @@ class ChatInfo:
'value': '{question}',
},
'user_id': self.application.user_id
}
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,

View File

@ -33,17 +33,17 @@ from application.serializers.application_serializers import ModelDatasetAssociat
from application.serializers.chat_message_serializers import ChatInfo
from common.constants.permission_constants import RoleConstants
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.util.common import post
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id, \
get_embedding_model_id_by_dataset_id
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from embedding.task import embedding_by_paragraph
from setting.models import Model
from setting.models_provider import get_model_credential
from smartdoc.conf import PROJECT_DIR
chat_cache = caches['chat_cache']
@ -54,6 +54,17 @@ class WorkFlowSerializers(serializers.Serializer):
edges = serializers.ListSerializer(child=serializers.DictField(), error_messages=ErrMessage.uuid("连线"))
def valid_model_params_setting(model_id, model_params_setting):
if model_id is None:
return
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting_form = credential.get_model_params_setting_form(model.model_name)
if model_params_setting is None or len(model_params_setting.keys()) == 0:
model_params_setting = model_params_setting_form.get_default_form_data()
credential.get_model_params_setting_form(model.model_name).valid_form(model_params_setting)
class ChatSerializers(serializers.Serializer):
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
@ -253,6 +264,7 @@ class ChatSerializers(serializers.Serializer):
return chat_id
def open_simple(self, application):
valid_model_params_setting(application.model_id, application.model_params_setting)
application_id = self.data.get('application_id')
dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter(
@ -309,6 +321,8 @@ class ChatSerializers(serializers.Serializer):
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
# 模型相关设置
model_params_setting = serializers.JSONField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -332,10 +346,12 @@ class ChatSerializers(serializers.Serializer):
model_id = self.data.get('model_id')
dataset_id_list = self.data.get('dataset_id_list')
dialogue_number = 3 if self.data.get('multiple_rounds_dialogue', False) else 0
valid_model_params_setting(model_id, self.data.get('model_params_setting'))
application = Application(id=None, dialogue_number=dialogue_number, model_id=model_id,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),
problem_optimization=self.data.get('problem_optimization'),
model_params_setting=self.data.get('model_params_setting'),
user_id=user_id)
chat_cache.set(chat_id,
ChatInfo(chat_id, None, dataset_id_list,

View File

@ -19,10 +19,6 @@ urlpatterns = [
path('application/<str:application_id>/statistics/chat_record_aggregate_trend',
views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()),
path('application/<str:application_id>/model', views.Application.Model.as_view()),
path('application/<str:application_id>/model/<str:model_id>', views.Application.Model.Operate.as_view()),
path('application/<str:application_id>/model/<str:model_id>/<str:ai_node_id>',
views.Application.Model.Operate.as_view()),
path('application/<str:application_id>/other-config', views.Application.Model.OtherConfig.as_view()),
path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()),
path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),
path("application/<str:application_id>/api_key/<str:api_key_id>",

View File

@ -187,30 +187,6 @@ class Application(APIView):
data={'application_id': application_id,
'user_id': request.user.id}).list_model(request.query_params.get('model_type')))
class Operate(APIView):
authentication_classes = [TokenAuth]
@swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
operation_id="获取应用参数设置其他字段",
tags=["应用/会话"])
def get(self, request: Request, application_id: str, model_id: str, ai_node_id=None):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'model_id': model_id,
'ai_node_id': ai_node_id}).get_other_file_list())
class OtherConfig(APIView):
authentication_classes = [TokenAuth]
@swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
operation_id="获取应用参数设置其他字段",
tags=["应用/会话"])
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id}).save_other_config(
request.data))
class Profile(APIView):
authentication_classes = [TokenAuth]

View File

@ -20,3 +20,5 @@ from .text_input_field import *
from .radio_button_field import *
from .table_checkbox import *
from .radio_card_field import *
from .label import *
from .slider_field import *

View File

@ -9,6 +9,9 @@
from enum import Enum
from typing import List, Dict
from common.exception.app_exception import AppApiException
from common.forms.label.base_label import BaseLabel
class TriggerType(Enum):
# 执行函数获取 OptionList数据
@ -20,7 +23,7 @@ class TriggerType(Enum):
class BaseField:
def __init__(self,
input_type: str,
label: str,
label: str or BaseLabel,
required: bool = False,
default_value: object = None,
relation_show_field_dict: Dict = None,
@ -53,10 +56,16 @@ class BaseField:
self.required = required
self.trigger_type = trigger_type
def to_dict(self):
def is_valid(self, value):
field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
if self.required and value is None:
raise AppApiException(500,
f"{field_label} 为必填参数")
def to_dict(self, **kwargs):
return {
'input_type': self.input_type,
'label': self.label,
'label': self.label.to_dict(**kwargs) if hasattr(self.label, 'to_dict') else self.label,
'required': self.required,
'default_value': self.default_value,
'relation_show_field_dict': self.relation_show_field_dict,
@ -64,6 +73,7 @@ class BaseField:
'trigger_type': self.trigger_type.value,
'attrs': self.attrs,
'props_info': self.props_info,
**kwargs
}
@ -97,8 +107,8 @@ class BaseDefaultOptionField(BaseField):
self.value_field = value_field
self.option_list = option_list
def to_dict(self):
return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field,
def to_dict(self, **kwargs):
return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
'option_list': self.option_list}
@ -141,6 +151,6 @@ class BaseExecField(BaseField):
self.provider = provider
self.method = method
def to_dict(self):
return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field,
def to_dict(self, **kwargs):
return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
'provider': self.provider, 'method': self.method}

View File

@ -6,11 +6,25 @@
@date2023/11/1 16:04
@desc:
"""
from typing import Dict
from common.forms import BaseField
class BaseForm:
def to_form_list(self):
return [{**self.__getattribute__(key).to_dict(), 'field': key} for key in
def to_form_list(self, **kwargs):
return [{**self.__getattribute__(key).to_dict(**kwargs), 'field': key} for key in
list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
[attr for attr in vars(self.__class__) if not attr.startswith("__")]))]
def valid_form(self, form_data):
field_keys = list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
[attr for attr in vars(self.__class__) if not attr.startswith("__")]))
for field_key in field_keys:
self.__getattribute__(field_key).is_valid(form_data.get(field_key))
def get_default_form_data(self):
return {key: self.__getattribute__(key).default_value for key in
[attr for attr in vars(self.__class__) if not attr.startswith("__")] if
isinstance(self.__getattribute__(key), BaseField) and self.__getattribute__(
key).default_value is not None}

View File

@ -0,0 +1,10 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py.py
@date2024/8/22 17:19
@desc:
"""
from .base_label import *
from .tooltip_label import *

View File

@ -0,0 +1,28 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file base_label.py
@date2024/8/22 17:11
@desc:
"""
class BaseLabel:
def __init__(self,
input_type: str,
label: str,
attrs=None,
props_info=None):
self.input_type = input_type
self.label = label
self.attrs = attrs
self.props_info = props_info
def to_dict(self, **kwargs):
return {
'input_type': self.input_type,
'label': self.label,
'attrs': {} if self.attrs is None else self.attrs,
'props_info': {} if self.props_info is None else self.props_info,
}

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file tooltip_label.py
@date2024/8/22 17:19
@desc:
"""
from common.forms.label.base_label import BaseLabel
class TooltipLabel(BaseLabel):
def __init__(self, label, tooltip):
super().__init__('TooltipLabel', label, attrs={'tooltip': tooltip}, props_info={})

View File

@ -0,0 +1,58 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file slider_field.py
@date2024/8/22 17:06
@desc:
"""
from typing import Dict
from common.exception.app_exception import AppApiException
from common.forms import BaseField, TriggerType, BaseLabel
class SliderField(BaseField):
"""
滑块输入框
"""
def __init__(self, label: str or BaseLabel,
_min,
_max,
_step,
precision,
required: bool = False,
default_value=None,
relation_show_field_dict: Dict = None,
attrs=None, props_info=None):
"""
@param label: 提示
@param _min: 最小值
@param _max: 最大值
@param _step: 步长
@param precision: 保留多少小数
@param required: 是否必填
@param default_value: 默认值
@param relation_show_field_dict:
@param attrs:
@param props_info:
"""
_attrs = {'min': _min, 'max': _max, 'step': _step,
'precision': precision, 'show-input-controls': False, 'show-input': True}
if attrs is not None:
_attrs.update(attrs)
super().__init__('Slider', label, required, default_value, relation_show_field_dict,
{},
TriggerType.OPTION_LIST, _attrs, props_info)
def is_valid(self, value):
super().is_valid(value)
field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
if value is not None:
if value < self.attrs.get('min'):
raise AppApiException(500,
f"{field_label} 不能小于{self.attrs.get('min')}")
if value > self.attrs.get('max'):
raise AppApiException(500,
f"{field_label} 不能大于{self.attrs.get('max')}")

View File

@ -108,10 +108,10 @@ class BaseModelCredential(ABC):
"""
pass
def get_other_fields(self, model_name):
def get_model_params_setting_form(self, model_name):
"""
获取其他字段
:return:
模型参数设置表单
:return:
"""
pass

View File

@ -1,11 +1,30 @@
import os
import re
from typing import Dict
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential
class BedrockLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=1024,
_min=1,
_max=4096,
_step=1,
precision=0)
class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
@ -61,24 +80,5 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
access_key_id = forms.TextInputField('Access Key ID', required=True)
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return BedrockLLMModelParams()

View File

@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class AzureLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=800,
_min=1,
_max=4096,
_step=1,
precision=0)
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
@ -54,24 +71,5 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return AzureLLMModelParams()

View File

@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class DeepSeekLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=800,
_min=1,
_max=4096,
_step=1,
precision=0)
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
@ -47,24 +64,5 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return DeepSeekLLMModelParams()

View File

@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class GeminiLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=800,
_min=1,
_max=4096,
_step=1,
precision=0)
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
@ -48,24 +65,5 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return GeminiLLMModelParams()

View File

@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class KimiLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.3,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=1024,
_min=1,
_max=4096,
_step=1,
precision=0)
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
@ -48,24 +65,5 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.3,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return KimiLLMModelParams()

View File

@ -10,10 +10,27 @@ from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OllamaLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.3,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=1024,
_min=1,
_max=4096,
_step=1,
precision=0)
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
@ -43,24 +60,5 @@ class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.3,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return OllamaLLMModelParams()

View File

@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAILLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=800,
_min=1,
_max=4096,
_step=1,
precision=0)
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
@ -48,24 +65,5 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return OpenAILLMModelParams()

View File

@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class QwenModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=1.0,
_min=0.1,
_max=1.9,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=800,
_min=1,
_max=2048,
_step=1,
precision=0)
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
@ -46,24 +63,5 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 1.0,
'min': 0.1,
'max': 1.9,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 2048,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return QwenModelParams()

View File

@ -1,11 +1,21 @@
# coding=utf-8
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class TencentLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.5,
_min=0.1,
_max=2.0,
_step=0.01,
precision=2)
class TencentLLMModelCredential(BaseForm, BaseModelCredential):
REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key']
@ -46,15 +56,5 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential):
hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.5,
'min': 0.1,
'max': 2.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
}
def get_model_params_setting_form(self, model_name):
return TencentLLMModelParams()

View File

@ -6,10 +6,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class VLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=800,
_min=1,
_max=4096,
_step=1,
precision=0)
class VLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
@ -44,24 +61,5 @@ class VLLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return VLLMModelParams()

View File

@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class VolcanicEngineLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.3,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=1024,
_min=1,
_max=4096,
_step=1,
precision=0)
class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
@ -49,24 +66,5 @@ class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
access_key_id = forms.PasswordInputField('Access Key ID', required=True)
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
def get_other_fields(self):
return {
'temperature': {
'value': 0.3,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return VolcanicEngineLLMModelParams()

View File

@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class WenxinLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.95,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=1024,
_min=2,
_max=2048,
_step=1,
precision=0)
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
@ -54,24 +71,5 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
secret_key = forms.PasswordInputField("Secret Key", required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.95,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 2,
'max': 2048,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return WenxinLLMModelParams()

View File

@ -12,10 +12,44 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class XunFeiLLMModelGeneralParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.5,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=4096,
_min=1,
_max=4096,
_step=1,
precision=0)
class XunFeiLLMModelProParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.5,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=4096,
_min=1,
_max=8192,
_step=1,
precision=0)
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
@ -50,27 +84,7 @@ class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
spark_api_key = forms.PasswordInputField("API Key", required=True)
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
def get_other_fields(self, model_name):
max_value = 8192
def get_model_params_setting_form(self, model_name):
if model_name == 'general' or model_name == 'pro-128k':
max_value = 4096
return {
'temperature': {
'value': 0.5,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 4096,
'min': 1,
'max': max_value,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
return XunFeiLLMModelGeneralParams()
return XunFeiLLMModelProParams()

View File

@ -6,10 +6,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class XinferenceLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=800,
_min=1,
_max=4096,
_step=1,
precision=0)
class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
@ -40,24 +57,5 @@ class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 800,
'min': 1,
'max': 4096,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return XinferenceLLMModelParams()

View File

@ -1,7 +1,8 @@
# coding=utf-8
from typing import List, Dict
from typing import Dict
from urllib.parse import urlparse, ParseResult
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI

View File

@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class ZhiPuLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.95,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=1024,
_min=1,
_max=4096,
_step=1,
precision=0)
class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
@ -46,25 +63,5 @@ class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.95,
'min': 0.1,
'max': 1.0,
'step': 0.01,
'label': '温度',
'precision': 2,
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 4095,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,
'tooltip': '指定模型可生成的最大token个数'
}
}
def get_model_params_setting_form(self, model_name):
return ZhiPuLLMModelParams()

View File

@ -24,6 +24,7 @@ from common.util.field_message import ErrMessage
from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt
from dataset.models import DataSet
from setting.models.model_management import Model, Status
from setting.models_provider import get_model, get_model_credential
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@ -234,6 +235,14 @@ class ModelSerializer(serializers.Serializer):
'meta': model.meta
}
def get_model_params(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
model_id = self.data.get('id')
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
return credential.get_model_params_setting_form(model.model_name).to_form_list()
def delete(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)

View File

@ -17,6 +17,7 @@ urlpatterns = [
path('provider/model_form', views.Provide.ModelForm.as_view(),
name="provider/model_form"),
path('model', views.Model.as_view(), name='model'),
path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(), name='model/model_params_form'),
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),

View File

@ -81,6 +81,19 @@ class Model(APIView):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download())
class ModelParamsForm(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取模型参数表单",
operation_id="获取模型参数表单",
manual_parameters=ProvideApi.ModelForm.get_request_params_api(),
tags=["模型"])
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).get_model_params())
class Operate(APIView):
authentication_classes = [TokenAuth]

View File

@ -249,41 +249,6 @@ const putPublishApplication: (
return put(`${prefix}/${application_id}/publish`, data, undefined, loading)
}
/**
*
* @param application_id
* @param model_id
* @param loading
* @returns
*/
const getModelOtherConfig: (
application_id: string,
model_id: string,
ai_node_id?: string,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, model_id, ai_node_id?, loading?) => {
if (ai_node_id) {
// 如果 ai_node_id 不为空,则调用带有 ai_node_id 的 API
return get(`${prefix}/${application_id}/model/${model_id}/${ai_node_id}`, undefined, loading)
} else {
// 如果 ai_node_id 为空,则调用不带 ai_node_id 的 API
return get(`${prefix}/${application_id}/model/${model_id}`, undefined, loading)
}
}
/**
*
* @param application_id
*
*/
const putModelOtherConfig: (
application_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
return put(`${prefix}/${application_id}/other-config`, data, undefined, loading)
}
export default {
getAllAppilcation,
getApplication,
@ -303,7 +268,5 @@ export default {
getApplicationHitTest,
getApplicationModel,
putPublishApplication,
postWorkflowChatOpen,
getModelOtherConfig,
putModelOtherConfig
postWorkflowChatOpen
}

View File

@ -51,6 +51,18 @@ const getModelCreateForm: (
return get(`${prefix_provider}/model_form`, { provider, model_type, model_name }, loading)
}
/**
*
* @param model_id id
* @param loading
* @returns
*/
const getModelParamsForm: (
model_id: string,
loading?: Ref<boolean>
) => Promise<Result<Array<FormField>>> = (model_id, loading) => {
return get(`model/${model_id}/model_params_form`, {}, loading)
}
/**
*
* @param provider
@ -159,5 +171,6 @@ export default {
deleteModel,
getModelById,
getModelMetaById,
pauseDownload
pauseDownload,
getModelParamsForm
}

View File

@ -13,6 +13,7 @@ interface ApplicationFormType {
icon?: string | undefined
type?: string
work_flow?: any
model_params_setting?: any
}
interface chatType {
id: string

View File

@ -1,6 +1,11 @@
<template>
<div style="width: 1024px">
<DynamicsForm v-model="form_data" :render_data="damo_data" ref="dynamicsFormRef">
<DynamicsForm
v-model="form_data"
:model="form_data"
:render_data="damo_data"
ref="dynamicsFormRef"
>
<template #default="scope">
<el-form-item label="其他字段">
<el-input v-model="scope.form_value['zha']" /> </el-form-item
@ -14,6 +19,7 @@ import type { FormField } from '@/components/dynamics-form/type'
import DynamicsForm from '@/components/dynamics-form/index.vue'
import { ref } from 'vue'
import type { Dict } from '@/api/type/common'
const damo_data: Array<FormField> = [
{ field: 'name', input_type: 'PasswordInput', label: '用戶名', required: false },
{
@ -29,6 +35,20 @@ const damo_data: Array<FormField> = [
{ field: 'name3', input_type: 'TextInput', label: '用戶名3' }
]
},
{
field: 'maxkb_tokens',
input_type: 'Slider',
default_value: 1,
attrs: {
min: 0,
max: 10,
step: 1,
precision: 1,
'show-input-controls': false,
'show-input': true
},
label: { label: '温度', attrs: { tooltip: 'sss' }, input_type: 'TooltipLabel' }
},
{
field: 'object_card_field',
input_type: 'ObjectCard',

View File

@ -2,13 +2,18 @@
<el-form-item
v-loading="loading"
:style="formItemStyle"
:label="formfield.label"
:prop="formfield.field"
:key="formfield.field"
:rules="rules"
>
<template #label v-if="formfield.label">
<FormItemLabel :form-field="formfield"></FormItemLabel>
<FormItemLabel v-if="isString(formfield.label)" :form-field="formfield"></FormItemLabel>
<component
v-else
:is="formfield.label.input_type"
:label="formfield.label.label"
v-bind="label_attrs"
></component>
</template>
<component
ref="componentFormRef"
@ -58,6 +63,9 @@ const emit = defineEmits(['change'])
const loading = ref<boolean>(false)
const isString = (value: any) => {
return typeof value === 'string'
}
const itemValue = computed({
get: () => {
return props.modelValue
@ -72,7 +80,13 @@ const itemValue = computed({
}
})
const componentFormRef = ref<any>()
const label_attrs = computed(() => {
return props.formfield.label &&
typeof props.formfield.label !== 'string' &&
props.formfield.label.attrs
? props.formfield.label.attrs
: {}
})
const props_info = computed(() => {
return props.formfield.props_info ? props.formfield.props_info : {}
})
@ -87,7 +101,11 @@ const formItemStyle = computed(() => {
* 表单错误Msg
*/
const errMsg = computed(() => {
return props_info.value.err_msg ? props_info.value.err_msg : props.formfield.label + '不能为空'
return props_info.value.err_msg
? props_info.value.err_msg
: isString(props.formfield.label)
? props.formfield.label
: props.formfield.label.label + '不能为空'
})
/**

View File

@ -33,7 +33,7 @@
import type { Dict } from '@/api/type/common'
import FormItem from '@/components/dynamics-form/FormItem.vue'
import type { FormField } from '@/components/dynamics-form/type'
import { ref, onMounted, watch, type Ref } from 'vue'
import { ref, onBeforeMount, watch, type Ref } from 'vue'
import type { FormInstance } from 'element-plus'
import triggerApi from '@/api/provider'
import type Result from '@/request/Result'
@ -152,7 +152,7 @@ const initDefaultData = (formField: FormField) => {
}
}
onMounted(() => {
onBeforeMount(() => {
render(props.render_data, {})
})

View File

@ -8,6 +8,7 @@
require-asterisk-position="right"
ref="ceFormRef"
v-model="_data[index]"
:model="_data[index]"
:other-params="other"
:render_data="render_data()"
v-bind="attr"

View File

@ -16,6 +16,7 @@
require-asterisk-position="right"
ref="ceFormRef"
v-model="_data[index]"
:model="_data[index]"
:other-params="other"
:render_data="render_data()"
v-bind="attr"

View File

@ -0,0 +1,40 @@
<template>
<div class="flex align-center" style="display: inline-flex">
<div class="flex-between mr-4">
<span>{{ label }}</span>
</div>
<el-tooltip effect="dark" placement="right">
<template #content>{{ tooltip }}</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<script setup lang="ts">
defineProps<{
label: string
tooltip: string
}>()
</script>
<style lang="scss" scope>
.aiMode-param-dialog {
padding: 8px 8px 24px 8px;
.el-dialog__header {
padding: 16px 16px 0 16px;
}
.el-dialog__body {
padding: 16px !important;
}
.dialog-max-height {
height: 550px;
}
.custom-slider {
.el-input-number.is-without-controls .el-input__wrapper {
padding: 0 !important;
}
}
}
</style>

View File

@ -0,0 +1,11 @@
<template>
<el-slider v-bind="$attrs" show-input :show-input-controls="false" class="custom-slider" />
</template>
<script setup lang="ts"></script>
<style lang="scss" scoped>
.custom-slider {
.el-input-number.is-without-controls .el-input__wrapper {
padding: 0 !important;
}
}
</style>

View File

@ -117,7 +117,7 @@ interface FormField {
/**
*
*/
label?: string
label?: string | any
/**
*
*/

View File

@ -130,37 +130,6 @@ const useApplicationStore = defineStore({
reject(error)
})
})
},
// 获取模型的 温度/max_token字段设置
async asyncGetModelConfig(
id: string,
model_id: string,
ai_node_id?: string,
loading?: Ref<boolean>
) {
return new Promise((resolve, reject) => {
applicationApi
.getModelOtherConfig(id, model_id, ai_node_id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
// 保存应用的 温度/max_token字段设置
async asyncPostModelConfig(id: string, params?: any, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.putModelOtherConfig(id, params, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
}
}
})

View File

@ -328,7 +328,7 @@
</el-col>
</el-row>
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" :id="id" @refresh="refreshForm" />
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshForm" />
<ParamSettingDialog ref="ParamSettingDialogRef" @refresh="refreshParam" />
<AddDatasetDialog
ref="AddDatasetDialogRef"
@ -348,8 +348,8 @@
</LayoutContainer>
</template>
<script setup lang="ts">
import { reactive, ref, watch, onMounted } from 'vue'
import { useRouter, useRoute } from 'vue-router'
import { reactive, ref, onMounted } from 'vue'
import { useRoute } from 'vue-router'
import { groupBy } from 'lodash'
import AIModeParamSettingDialog from './component/AIModeParamSettingDialog.vue'
import ParamSettingDialog from './component/ParamSettingDialog.vue'
@ -408,6 +408,7 @@ const applicationForm = ref<ApplicationFormType>({
model_setting: {
prompt: defaultPrompt
},
model_params_setting: {},
problem_optimization: false,
type: 'SIMPLE'
})
@ -456,9 +457,7 @@ const openAIParamSettingDialog = () => {
MsgSuccess(t('请选择AI 模型'))
return
}
application.asyncGetModelConfig(id, model_id, '', loading).then((res: any) => {
AIModeParamSettingDialogRef.value?.open(res.data)
})
AIModeParamSettingDialogRef.value?.open(model_id, applicationForm.value.model_params_setting)
}
const openParamSettingDialog = () => {
@ -470,11 +469,7 @@ function refreshParam(data: any) {
}
function refreshForm(data: any) {
// data datamodel_setting
applicationForm.value.model_setting = {
...applicationForm.value.model_setting,
...data
}
applicationForm.value.model_params_setting = data
}
const openCreateModel = (provider?: Provider) => {

View File

@ -7,31 +7,16 @@
style="width: 550px"
append-to-body
>
<el-form label-position="top" ref="paramFormRef" :model="form">
<el-form-item v-for="(item, key) in form" :key="key">
<template #label>
<div class="flex align-center">
<div class="flex-between mr-4">
<span>{{ item.label }}</span>
</div>
<el-tooltip effect="dark" placement="right">
<template #content>{{ item.tooltip }}</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-slider
v-model="item.value"
show-input
:show-input-controls="false"
:min="item.min"
:max="item.max"
:precision="item.precision || 0"
:step="item.step || 1"
class="custom-slider"
/>
</el-form-item>
</el-form>
<DynamicsForm
v-model="form_data"
:model="form_data"
label-position="top"
require-asterisk-position="right"
:render_data="model_form_field"
ref="dynamicsFormRef"
>
</DynamicsForm>
<template #footer>
<span class="dialog-footer p-16">
<el-button @click.prevent="dialogVisible = false">
@ -46,82 +31,40 @@
</template>
<script setup lang="ts">
import { ref, reactive, watch } from 'vue'
import { cloneDeep, set } from 'lodash'
import { ref } from 'vue'
import type { FormInstance } from 'element-plus'
import useStore from '@/stores'
const { application } = useStore()
import type { FormField } from '@/components/dynamics-form/type'
import modelAPi from '@/api/model'
import DynamicsForm from '@/components/dynamics-form/index.vue'
import { keys } from 'lodash'
const model_form_field = ref<Array<FormField>>([])
const emit = defineEmits(['refresh'])
const paramFormRef = ref<FormInstance>()
const form = reactive<Form>({})
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
const form_data = ref<any>({})
const dialogVisible = ref(false)
const loading = ref(false)
const props = defineProps<{
id: string
nodeId?: string
}>()
const resetForm = () => {
// form
Object.keys(form).forEach((key) => delete form[key])
}
interface Form {
[key: string]: FormField
}
interface FormField {
value: any
min?: number
max?: number
step?: number
label?: string
precision?: number
tooltip?: string
}
const open = (data: any) => {
const newData = cloneDeep(data)
Object.keys(form).forEach((key) => {
delete form[key]
})
Object.keys(newData).forEach((key) => {
set(form, key, newData[key])
const open = (model_id: string, model_setting_data?: any) => {
modelAPi.getModelParamsForm(model_id, loading).then((ok) => {
model_form_field.value = ok.data
model_setting_data =
model_setting_data && keys(model_setting_data).length > 0
? model_setting_data
: ok.data
.map((item) => ({ [item.field]: item.default_value }))
.reduce((x, y) => ({ ...x, ...y }), {})
//
dynamicsFormRef.value?.render(model_form_field.value, model_setting_data)
})
dialogVisible.value = true
}
const submit = async () => {
if (paramFormRef.value) {
await paramFormRef.value.validate((valid, fields) => {
if (valid) {
const data = Object.keys(form).reduce(
(acc, key) => {
acc[key] = form[key].value
return acc
},
{} as Record<string, any>
)
if (props.nodeId) {
data.node_id = props.nodeId
}
application.asyncPostModelConfig(props.id, data, loading).then(() => {
emit('refresh', data)
dialogVisible.value = false
})
}
})
}
emit('refresh', form_data.value)
dialogVisible.value = false
}
watch(dialogVisible, (bool) => {
if (!bool) {
resetForm()
}
})
defineExpose({ open })
</script>

View File

@ -196,13 +196,7 @@
@change="openCreateModel($event)"
></CreateModelDialog>
<SelectProviderDialog ref="selectProviderRef" @change="openCreateModel($event)" />
<AIModeParamSettingDialog
ref="AIModeParamSettingDialogRef"
:id="id"
:node-id="props.nodeModel.id"
@refresh="refreshParam"
/>
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
</NodeContainer>
</template>
<script setup lang="ts">
@ -319,15 +313,12 @@ const openCreateModel = (provider?: Provider) => {
const openAIParamSettingDialog = (modelId: string) => {
if (modelId) {
application.asyncGetModelConfig(id, modelId, props.nodeModel.id).then((res: any) => {
AIModeParamSettingDialogRef.value?.open(res.data)
})
AIModeParamSettingDialogRef.value?.open(modelId)
}
}
function refreshParam(data: any) {
chat_data.value.temperature = data.temperature
chat_data.value.max_tokens = data.max_tokens
set(props.nodeModel.properties.node_data, 'model_params_setting', data)
}
onMounted(() => {

View File

@ -178,6 +178,7 @@
@change="openCreateModel($event)"
></CreateModelDialog>
<SelectProviderDialog ref="selectProviderRef" @change="openCreateModel($event)" />
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
</NodeContainer>
</template>
<script setup lang="ts">
@ -186,6 +187,7 @@ import { app } from '@/main'
import NodeContainer from '@/workflow/common/NodeContainer.vue'
import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue'
import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue'
import AIModeParamSettingDialog from '@/views/application/component/AIModeParamSettingDialog.vue'
import type { FormInstance } from 'element-plus'
import { ref, computed, onMounted } from 'vue'
import applicationApi from '@/api/application'
@ -229,6 +231,9 @@ const form = {
dialogue_number: 1,
is_result: false
}
function refreshParam(data: any) {
set(props.nodeModel.properties.node_data, 'model_params_setting', data)
}
const form_data = computed({
get: () => {