Mirror of https://github.com/1Panel-dev/MaxKB.git (synced 2025-12-26 01:33:05 +00:00)

Commit 9e0ac81f1d (parent 228f913d9b)
refactor: restructure the model parameter code
@@ -73,9 +73,8 @@ class IChatStep(IBaseChatPipelineStep):
    no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))

    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
    temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度"))
    max_tokens = serializers.IntegerField(required=False, allow_null=True,
                                          error_messages=ErrMessage.integer("最大token数"))

    model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.dict("模型参数设置"))

    def is_valid(self, *, raise_exception=False):
        super().is_valid(raise_exception=True)

@@ -100,5 +99,5 @@ class IChatStep(IBaseChatPipelineStep):
                paragraph_list=None,
                manage: PipelineManage = None,
                padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
                no_references_setting=None, **kwargs):
                no_references_setting=None, model_params_setting=None, **kwargs):
        pass

@@ -110,8 +110,10 @@ class BaseChatStep(IChatStep):
                stream: bool = True,
                client_id=None, client_type=None,
                no_references_setting=None,
                model_params_setting=None,
                **kwargs):
        chat_model = get_model_instance_by_model_user_id(model_id, user_id, **kwargs) if model_id is not None else None
        chat_model = get_model_instance_by_model_user_id(model_id, user_id,
                                                         **model_params_setting) if model_id is not None else None
        if stream:
            return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
@@ -23,9 +23,8 @@ class ChatNodeSerializer(serializers.Serializer):
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
    temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度"))
    max_tokens = serializers.IntegerField(required=False, allow_null=True,
                                          error_messages=ErrMessage.integer("最大token数"))

    model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置"))


class IChatNode(INode):

@@ -39,5 +38,6 @@ class IChatNode(INode):

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
                chat_record_id,
                model_params_setting,
                **kwargs) -> NodeResult:
        pass

@@ -65,12 +65,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor

class BaseChatNode(IChatNode):
    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                model_params_setting,
                **kwargs) -> NodeResult:
        kwargs = {k: v for k, v in kwargs.items() if k in ['temperature', 'max_tokens']}

        chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
                                                         **kwargs
                                                         )
                                                         **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
@@ -23,6 +23,7 @@ class QuestionNodeSerializer(serializers.Serializer):
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
    model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置"))


class IQuestionNode(INode):

@@ -35,5 +36,6 @@ class IQuestionNode(INode):
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                model_params_setting,
                **kwargs) -> NodeResult:
        pass

@@ -65,8 +65,10 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor

class BaseQuestionNode(IQuestionNode):
    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                model_params_setting,
                **kwargs) -> NodeResult:
        chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
        chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
                                                         **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
@@ -10,6 +10,7 @@ import json
from functools import reduce
from typing import List, Dict

from django.db.models import QuerySet
from langchain_core.messages import AIMessage
from langchain_core.prompts import PromptTemplate

@@ -17,6 +18,8 @@ from application.flow import tools
from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
from application.flow.step_node import get_node
from common.exception.app_exception import AppApiException
from setting.models import Model
from setting.models_provider import get_model_credential


class Edge:

@@ -68,6 +71,7 @@ class Flow:
        """
        校验工作流数据
        """
        self.is_valid_model_params()
        self.is_valid_start_node()
        self.is_valid_base_node()
        self.is_valid_work_flow()

@@ -123,6 +127,19 @@ class Flow:
        if len(start_node_list) > 1:
            raise AppApiException(500, '开始节点只能有一个')

    def is_valid_model_params(self):
        node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')]
        for node in node_list:
            model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first()
            credential = get_model_credential(model.provider, model.model_type, model.model_name)
            model_params_setting = node.properties.get('node_data', {}).get('model_params_setting')
            model_params_setting_form = credential.get_model_params_setting_form(
                model.model_name)
            if model_params_setting is None:
                model_params_setting = model_params_setting_form.get_default_form_data()
                node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
            model_params_setting_form.valid_form(model_params_setting)

    def is_valid_base_node(self):
        base_node_list = [node for node in self.nodes if node.id == 'base-node']
        if len(base_node_list) == 0:
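For orientation, this is roughly the node shape that is_valid_model_params() walks for every ai-chat-node and question-node. Only the key names come from the code above; the values below are made up for illustration:

# Illustrative sketch, not part of the commit.
example_node_properties = {
    'node_data': {
        'model_id': '<uuid of a setting.models.Model row>',  # resolved via QuerySet(Model)
        'model_params_setting': {'temperature': 0.7, 'max_tokens': 800},  # when None, defaults are filled in
    }
}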
@@ -0,0 +1,18 @@
# Generated by Django 4.2.15 on 2024-08-23 14:17

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('application', '0010_alter_chatrecord_details'),
    ]

    operations = [
        migrations.AddField(
            model_name='application',
            name='model_params_setting',
            field=models.JSONField(default={}, verbose_name='模型参数相关设置'),
        ),
    ]

@@ -48,6 +48,7 @@ class Application(AppModelMixin):
    model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
    dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
    model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
    model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default={})
    problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
    icon = models.CharField(max_length=256, verbose_name="应用icon", default="/ui/favicon.ico")
    work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
@@ -44,6 +44,7 @@ from dataset.serializers.common_serializers import list_paragraph, get_embedding
from embedding.models import SearchMode
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider import get_model_credential
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR

@@ -93,6 +94,14 @@ class NoReferencesSetting(serializers.Serializer):
    value = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))


def valid_model_params_setting(model_id, model_params_setting):
    if model_id is None:
        return
    model = QuerySet(Model).filter(id=model_id).first()
    credential = get_model_credential(model.provider, model.model_type, model.model_name)
    credential.get_model_params_setting_form(model.model_name).valid_form(model_params_setting)


class DatasetSettingSerializer(serializers.Serializer):
    top_n = serializers.FloatField(required=True, max_value=100, min_value=1,
                                   error_messages=ErrMessage.float("引用分段数"))

@@ -110,10 +119,6 @@ class DatasetSettingSerializer(serializers.Serializer):

class ModelSettingSerializer(serializers.Serializer):
    prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
    temperature = serializers.FloatField(required=False, allow_null=True,
                                         error_messages=ErrMessage.char("温度"))
    max_tokens = serializers.IntegerField(required=False, allow_null=True,
                                          error_messages=ErrMessage.integer("最大token数"))


class ApplicationWorkflowSerializer(serializers.Serializer):

@@ -185,6 +190,7 @@ class ApplicationSerializer(serializers.Serializer):
                message="应用类型只支持SIMPLE|WORK_FLOW", code=500)
        ]
    )
    model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.dict('模型参数'))

    def is_valid(self, *, user_id=None, raise_exception=False):
        super().is_valid(raise_exception=True)

@@ -355,6 +361,8 @@ class ApplicationSerializer(serializers.Serializer):
                                                     error_messages=ErrMessage.boolean("问题补全"))
        icon = serializers.CharField(required=False, allow_null=True, error_messages=ErrMessage.char("icon图标"))

        model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.dict('模型参数'))

    class Create(serializers.Serializer):
        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))

@@ -667,7 +675,7 @@ class ApplicationSerializer(serializers.Serializer):
            ApplicationSerializer.Edit(data=instance).is_valid(
                raise_exception=True)
            application_id = self.data.get("application_id")

            valid_model_params_setting(instance.get('model_id'), instance.get('model_params_setting'))
            application = QuerySet(Application).get(id=application_id)
            if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
                application.model_id = None
@@ -33,6 +33,7 @@ from common.util.field_message import ErrMessage
from common.util.split_model import flat_map
from dataset.models import Paragraph, Document
from setting.models import Model, Status
from setting.models_provider import get_model_credential

chat_cache = caches['chat_cache']

@@ -63,6 +64,12 @@ class ChatInfo:
    def to_base_pipeline_manage_params(self):
        dataset_setting = self.application.dataset_setting
        model_setting = self.application.model_setting
        model_id = self.application.model.id if self.application.model is not None else None
        model_params_setting = None
        if model_id is not None:
            model = QuerySet(Model).filter(id=model_id).first()
            credential = get_model_credential(model.provider, model.model_type, model.model_name)
            model_params_setting = credential.get_model_params_setting_form(model.model_name).get_default_form_data()
        return {
            'dataset_id_list': self.dataset_id_list,
            'exclude_document_id_list': self.exclude_document_id_list,

@@ -77,11 +84,11 @@ class ChatInfo:
            'prompt': model_setting.get(
                'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
            'chat_model': self.chat_model,
            'model_id': self.application.model.id if self.application.model is not None else None,
            'model_id': model_id,
            'problem_optimization': self.application.problem_optimization,
            'stream': True,
            'temperature': model_setting.get('temperature') if 'temperature' in model_setting else None,
            'max_tokens': model_setting.get('max_tokens') if 'max_tokens' in model_setting else None,
            'model_params_setting': model_params_setting if self.application.model_params_setting is None or len(
                self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
            'search_mode': self.application.dataset_setting.get(
                'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
            'no_references_setting': self.application.dataset_setting.get(

@@ -90,7 +97,6 @@ class ChatInfo:
                'value': '{question}',
            },
            'user_id': self.application.user_id

        }

    def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
@@ -33,17 +33,17 @@ from application.serializers.application_serializers import ModelDatasetAssociat
from application.serializers.chat_message_serializers import ChatInfo
from common.constants.permission_constants import RoleConstants
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.util.common import post
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id, \
    get_embedding_model_id_by_dataset_id
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from embedding.task import embedding_by_paragraph
from setting.models import Model
from setting.models_provider import get_model_credential
from smartdoc.conf import PROJECT_DIR

chat_cache = caches['chat_cache']

@@ -54,6 +54,17 @@ class WorkFlowSerializers(serializers.Serializer):
    edges = serializers.ListSerializer(child=serializers.DictField(), error_messages=ErrMessage.uuid("连线"))


def valid_model_params_setting(model_id, model_params_setting):
    if model_id is None:
        return
    model = QuerySet(Model).filter(id=model_id).first()
    credential = get_model_credential(model.provider, model.model_type, model.model_name)
    model_params_setting_form = credential.get_model_params_setting_form(model.model_name)
    if model_params_setting is None or len(model_params_setting.keys()) == 0:
        model_params_setting = model_params_setting_form.get_default_form_data()
    credential.get_model_params_setting_form(model.model_name).valid_form(model_params_setting)


class ChatSerializers(serializers.Serializer):
    class Operate(serializers.Serializer):
        chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

@@ -253,6 +264,7 @@ class ChatSerializers(serializers.Serializer):
        return chat_id

    def open_simple(self, application):
        valid_model_params_setting(application.model_id, application.model_params_setting)
        application_id = self.data.get('application_id')
        dataset_id_list = [str(row.dataset_id) for row in
                           QuerySet(ApplicationDatasetMapping).filter(

@@ -309,6 +321,8 @@ class ChatSerializers(serializers.Serializer):
        model_setting = ModelSettingSerializer(required=True)
        # 问题补全
        problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
        # 模型相关设置
        model_params_setting = serializers.JSONField(required=True)

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)

@@ -332,10 +346,12 @@ class ChatSerializers(serializers.Serializer):
        model_id = self.data.get('model_id')
        dataset_id_list = self.data.get('dataset_id_list')
        dialogue_number = 3 if self.data.get('multiple_rounds_dialogue', False) else 0
        valid_model_params_setting(model_id, self.data.get('model_params_setting'))
        application = Application(id=None, dialogue_number=dialogue_number, model_id=model_id,
                                  dataset_setting=self.data.get('dataset_setting'),
                                  model_setting=self.data.get('model_setting'),
                                  problem_optimization=self.data.get('problem_optimization'),
                                  model_params_setting=self.data.get('model_params_setting'),
                                  user_id=user_id)
        chat_cache.set(chat_id,
                       ChatInfo(chat_id, None, dataset_id_list,
@@ -19,10 +19,6 @@ urlpatterns = [
    path('application/<str:application_id>/statistics/chat_record_aggregate_trend',
         views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()),
    path('application/<str:application_id>/model', views.Application.Model.as_view()),
    path('application/<str:application_id>/model/<str:model_id>', views.Application.Model.Operate.as_view()),
    path('application/<str:application_id>/model/<str:model_id>/<str:ai_node_id>',
         views.Application.Model.Operate.as_view()),
    path('application/<str:application_id>/other-config', views.Application.Model.OtherConfig.as_view()),
    path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()),
    path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),
    path("application/<str:application_id>/api_key/<str:api_key_id>",

@@ -187,30 +187,6 @@ class Application(APIView):
                data={'application_id': application_id,
                      'user_id': request.user.id}).list_model(request.query_params.get('model_type')))

    class Operate(APIView):
        authentication_classes = [TokenAuth]

        @swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
                             operation_id="获取应用参数设置其他字段",
                             tags=["应用/会话"])
        def get(self, request: Request, application_id: str, model_id: str, ai_node_id=None):
            return result.success(
                ApplicationSerializer.Operate(
                    data={'application_id': application_id, 'model_id': model_id,
                          'ai_node_id': ai_node_id}).get_other_file_list())

    class OtherConfig(APIView):
        authentication_classes = [TokenAuth]

        @swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
                             operation_id="获取应用参数设置其他字段",
                             tags=["应用/会话"])
        def put(self, request: Request, application_id: str):
            return result.success(
                ApplicationSerializer.Operate(
                    data={'application_id': application_id}).save_other_config(
                    request.data))

    class Profile(APIView):
        authentication_classes = [TokenAuth]
@@ -20,3 +20,5 @@ from .text_input_field import *
from .radio_button_field import *
from .table_checkbox import *
from .radio_card_field import *
from .label import *
from .slider_field import *
@@ -9,6 +9,9 @@
from enum import Enum
from typing import List, Dict

from common.exception.app_exception import AppApiException
from common.forms.label.base_label import BaseLabel


class TriggerType(Enum):
    # 执行函数获取 OptionList数据

@@ -20,7 +23,7 @@ class TriggerType(Enum):
class BaseField:
    def __init__(self,
                 input_type: str,
                 label: str,
                 label: str or BaseLabel,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,

@@ -53,10 +56,16 @@ class BaseField:
        self.required = required
        self.trigger_type = trigger_type

    def to_dict(self):
    def is_valid(self, value):
        field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
        if self.required and value is None:
            raise AppApiException(500,
                                  f"{field_label} 为必填参数")

    def to_dict(self, **kwargs):
        return {
            'input_type': self.input_type,
            'label': self.label,
            'label': self.label.to_dict(**kwargs) if hasattr(self.label, 'to_dict') else self.label,
            'required': self.required,
            'default_value': self.default_value,
            'relation_show_field_dict': self.relation_show_field_dict,

@@ -64,6 +73,7 @@ class BaseField:
            'trigger_type': self.trigger_type.value,
            'attrs': self.attrs,
            'props_info': self.props_info,
            **kwargs
        }


@@ -97,8 +107,8 @@ class BaseDefaultOptionField(BaseField):
        self.value_field = value_field
        self.option_list = option_list

    def to_dict(self):
        return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field,
    def to_dict(self, **kwargs):
        return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
                'option_list': self.option_list}


@@ -141,6 +151,6 @@ class BaseExecField(BaseField):
        self.provider = provider
        self.method = method

    def to_dict(self):
        return {**super().to_dict(), 'text_field': self.text_field, 'value_field': self.value_field,
    def to_dict(self, **kwargs):
        return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
                'provider': self.provider, 'method': self.method}
@@ -6,11 +6,25 @@
@date:2023/11/1 16:04
@desc:
"""
from typing import Dict

from common.forms import BaseField


class BaseForm:
    def to_form_list(self):
        return [{**self.__getattribute__(key).to_dict(), 'field': key} for key in
    def to_form_list(self, **kwargs):
        return [{**self.__getattribute__(key).to_dict(**kwargs), 'field': key} for key in
                list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
                            [attr for attr in vars(self.__class__) if not attr.startswith("__")]))]

    def valid_form(self, form_data):
        field_keys = list(filter(lambda key: isinstance(self.__getattribute__(key), BaseField),
                                 [attr for attr in vars(self.__class__) if not attr.startswith("__")]))
        for field_key in field_keys:
            self.__getattribute__(field_key).is_valid(form_data.get(field_key))

    def get_default_form_data(self):
        return {key: self.__getattribute__(key).default_value for key in
                [attr for attr in vars(self.__class__) if not attr.startswith("__")] if
                isinstance(self.__getattribute__(key), BaseField) and self.__getattribute__(
                    key).default_value is not None}
@@ -0,0 +1,10 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py.py
@date:2024/8/22 17:19
@desc:
"""
from .base_label import *
from .tooltip_label import *

@@ -0,0 +1,28 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: base_label.py
@date:2024/8/22 17:11
@desc:
"""


class BaseLabel:
    def __init__(self,
                 input_type: str,
                 label: str,
                 attrs=None,
                 props_info=None):
        self.input_type = input_type
        self.label = label
        self.attrs = attrs
        self.props_info = props_info

    def to_dict(self, **kwargs):
        return {
            'input_type': self.input_type,
            'label': self.label,
            'attrs': {} if self.attrs is None else self.attrs,
            'props_info': {} if self.props_info is None else self.props_info,
        }

@@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: tooltip_label.py
@date:2024/8/22 17:19
@desc:
"""
from common.forms.label.base_label import BaseLabel


class TooltipLabel(BaseLabel):
    def __init__(self, label, tooltip):
        super().__init__('TooltipLabel', label, attrs={'tooltip': tooltip}, props_info={})
@@ -0,0 +1,58 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: slider_field.py
@date:2024/8/22 17:06
@desc:
"""
from typing import Dict

from common.exception.app_exception import AppApiException
from common.forms import BaseField, TriggerType, BaseLabel


class SliderField(BaseField):
    """
    滑块输入框
    """

    def __init__(self, label: str or BaseLabel,
                 _min,
                 _max,
                 _step,
                 precision,
                 required: bool = False,
                 default_value=None,
                 relation_show_field_dict: Dict = None,
                 attrs=None, props_info=None):
        """
        @param label: 提示
        @param _min: 最小值
        @param _max: 最大值
        @param _step: 步长
        @param precision: 保留多少小数
        @param required: 是否必填
        @param default_value: 默认值
        @param relation_show_field_dict:
        @param attrs:
        @param props_info:
        """
        _attrs = {'min': _min, 'max': _max, 'step': _step,
                  'precision': precision, 'show-input-controls': False, 'show-input': True}
        if attrs is not None:
            _attrs.update(attrs)
        super().__init__('Slider', label, required, default_value, relation_show_field_dict,
                         {},
                         TriggerType.OPTION_LIST, _attrs, props_info)

    def is_valid(self, value):
        super().is_valid(value)
        field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
        if value is not None:
            if value < self.attrs.get('min'):
                raise AppApiException(500,
                                      f"{field_label} 不能小于{self.attrs.get('min')}")
            if value > self.attrs.get('max'):
                raise AppApiException(500,
                                      f"{field_label} 不能大于{self.attrs.get('max')}")
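Taken together, BaseForm, TooltipLabel, and SliderField let a provider declare its tunable parameters declaratively. A minimal usage sketch, assuming the classes introduced above; the ExampleModelParams form and its values are made up for illustration and are not part of this commit:

# Illustrative only: a made-up params form exercising the classes added above.
from common import forms
from common.forms import BaseForm, TooltipLabel


class ExampleModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机'),
                                    required=True, default_value=0.7,
                                    _min=0.1, _max=1.0, _step=0.01, precision=2)


form = ExampleModelParams()
form.get_default_form_data()           # -> {'temperature': 0.7}
form.valid_form({'temperature': 0.5})  # passes; values outside [0.1, 1.0] raise AppApiException
form.to_form_list()                    # serialized field list consumed by the front-end slider control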
@@ -108,10 +108,10 @@ class BaseModelCredential(ABC):
        """
        pass

    def get_other_fields(self, model_name):
    def get_model_params_setting_form(self, model_name):
        """
        获取其他字段
        :return:
        模型参数设置表单
        :return:
        """
        pass
@@ -1,11 +1,30 @@
import os
import re
from typing import Dict
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential


class BedrockLLMModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True, default_value=1024,
        _min=1,
        _max=4096,
        _step=1,
        precision=0)


class BedrockLLMModelCredential(BaseForm, BaseModelCredential):

@@ -61,24 +80,5 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.7,
                'min': 0.1,
                'max': 1,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 1024,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }
    def get_model_params_setting_form(self, model_name):
        return BedrockLLMModelParams()
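The remaining provider credentials below follow the same pattern: the hard-coded get_other_fields dict is replaced by a small BaseForm subclass returned from get_model_params_setting_form. A rough sketch of how such a form is consumed at chat time, pieced together from the code earlier in this diff; model, stored_setting, model_id, and user_id are placeholders, not verbatim excerpts:

# Sketch only: intended flow under the assumptions named above.
credential = get_model_credential(model.provider, model.model_type, model.model_name)
params_form = credential.get_model_params_setting_form(model.model_name)

model_params_setting = stored_setting or params_form.get_default_form_data()  # stored value, else defaults
params_form.valid_form(model_params_setting)  # required/range checks via SliderField.is_valid

# The validated dict is finally expanded into the chat model factory:
chat_model = get_model_instance_by_model_user_id(model_id, user_id, **model_params_setting)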
@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class AzureLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.7,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
|
|
@ -54,24 +71,5 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
|
||||
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.7,
|
||||
'min': 0.1,
|
||||
'max': 1,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 800,
|
||||
'min': 1,
|
||||
'max': 4096,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return AzureLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class DeepSeekLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.7,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
|
|
@ -47,24 +64,5 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.7,
|
||||
'min': 0.1,
|
||||
'max': 1,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 800,
|
||||
'min': 1,
|
||||
'max': 4096,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return DeepSeekLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class GeminiLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.7,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
|
|
@ -48,24 +65,5 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.7,
|
||||
'min': 0.1,
|
||||
'max': 1.0,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 800,
|
||||
'min': 1,
|
||||
'max': 4096,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return GeminiLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class KimiLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.3,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=1024,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class KimiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
|
|
@ -48,24 +65,5 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.3,
|
||||
'min': 0.1,
|
||||
'max': 1.0,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 1024,
|
||||
'min': 1,
|
||||
'max': 4096,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return KimiLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -10,10 +10,27 @@ from typing import Dict
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OllamaLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.3,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=1024,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
|
|
@ -43,24 +60,5 @@ class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.3,
|
||||
'min': 0.1,
|
||||
'max': 1.0,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 1024,
|
||||
'min': 1,
|
||||
'max': 4096,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return OllamaLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OpenAILLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.7,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
|
|
@ -48,24 +65,5 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
|||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.7,
|
||||
'min': 0.1,
|
||||
'max': 1.0,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 800,
|
||||
'min': 1,
|
||||
'max': 4096,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return OpenAILLMModelParams()
|
||||
|
|
|
|||
|
|
@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class QwenModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=1.0,
|
||||
_min=0.1,
|
||||
_max=1.9,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=2048,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
|
|
@ -46,24 +63,5 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
|||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 1.0,
|
||||
'min': 0.1,
|
||||
'max': 1.9,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 800,
|
||||
'min': 1,
|
||||
'max': 2048,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return QwenModelParams()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,21 @@
|
|||
# coding=utf-8
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class TencentLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.5,
|
||||
_min=0.1,
|
||||
_max=2.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
|
||||
class TencentLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key']
|
||||
|
||||
|
|
@ -46,15 +56,5 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
|
||||
hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.5,
|
||||
'min': 0.1,
|
||||
'max': 2.0,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return TencentLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -6,10 +6,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class VLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.7,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class VLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
|
|
@ -44,24 +61,5 @@ class VLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.7,
|
||||
'min': 0.1,
|
||||
'max': 1,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 800,
|
||||
'min': 1,
|
||||
'max': 4096,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return VLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class VolcanicEngineLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.3,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=1024,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
|
|
@ -49,24 +66,5 @@ class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
access_key_id = forms.PasswordInputField('Access Key ID', required=True)
|
||||
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
|
||||
|
||||
def get_other_fields(self):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.3,
|
||||
'min': 0.1,
|
||||
'max': 1.0,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 1024,
|
||||
'min': 1,
|
||||
'max': 4096,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return VolcanicEngineLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class WenxinLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.95,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=1024,
|
||||
_min=2,
|
||||
_max=2048,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
|
|
@ -54,24 +71,5 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
|
||||
secret_key = forms.PasswordInputField("Secret Key", required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.95,
|
||||
'min': 0.1,
|
||||
'max': 1.0,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 1024,
|
||||
'min': 2,
|
||||
'max': 2048,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return WenxinLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -12,10 +12,44 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class XunFeiLLMModelGeneralParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.5,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=4096,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class XunFeiLLMModelProParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.5,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=4096,
|
||||
_min=1,
|
||||
_max=8192,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
|
|
@ -50,27 +84,7 @@ class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
||||
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
max_value = 8192
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
if model_name == 'general' or model_name == 'pro-128k':
|
||||
max_value = 4096
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.5,
|
||||
'min': 0.1,
|
||||
'max': 1,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 4096,
|
||||
'min': 1,
|
||||
'max': max_value,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
return XunFeiLLMModelGeneralParams()
|
||||
return XunFeiLLMModelProParams()
|
||||
|
|
|
|||
|
|
@ -6,10 +6,27 @@ from langchain_core.messages import HumanMessage
|
|||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class XinferenceLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=0.7,
|
||||
_min=0.1,
|
||||
_max=1.0,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=4096,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
|
|
@ -40,24 +57,5 @@ class XinferenceLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_other_fields(self, model_name):
|
||||
return {
|
||||
'temperature': {
|
||||
'value': 0.7,
|
||||
'min': 0.1,
|
||||
'max': 1,
|
||||
'step': 0.01,
|
||||
'label': '温度',
|
||||
'precision': 2,
|
||||
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
|
||||
},
|
||||
'max_tokens': {
|
||||
'value': 800,
|
||||
'min': 1,
|
||||
'max': 4096,
|
||||
'step': 1,
|
||||
'label': '输出最大Tokens',
|
||||
'precision': 0,
|
||||
'tooltip': '指定模型可生成的最大token个数'
|
||||
}
|
||||
}
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return XinferenceLLMModelParams()
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,8 @@
# coding=utf-8

from typing import List, Dict
from typing import Dict
from urllib.parse import urlparse, ParseResult

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


@@ -12,10 +12,27 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class ZhiPuLLMModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=0.95,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True, default_value=1024,
        _min=1,
        _max=4096,
        _step=1,
        precision=0)


class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,


@@ -46,25 +63,5 @@ class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self, model_name):
        return {
            'temperature': {
                'value': 0.95,
                'min': 0.1,
                'max': 1.0,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 1024,
                'min': 1,
                'max': 4095,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }

    def get_model_params_setting_form(self, model_name):
        return ZhiPuLLMModelParams()


@@ -24,6 +24,7 @@ from common.util.field_message import ErrMessage
from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt
from dataset.models import DataSet
from setting.models.model_management import Model, Status
from setting.models_provider import get_model, get_model_credential
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@@ -234,6 +235,14 @@ class ModelSerializer(serializers.Serializer):
                'meta': model.meta
            }

        def get_model_params(self, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
            model_id = self.data.get('id')
            model = QuerySet(Model).filter(id=model_id).first()
            credential = get_model_credential(model.provider, model.model_type, model.model_name)
            return credential.get_model_params_setting_form(model.model_name).to_form_list()

        def delete(self, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
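What to_form_list() yields for these forms is not shown in this diff (it lives in common.forms), so the shape below is an inference from how the frontend consumes the response — the Array<FormField> return type of getModelParamsForm and the damo_data sample that appear later in this commit. Treat it as a sketch of the likely payload, not the authoritative serialization.

# Assumed shape of credential.get_model_params_setting_form(...).to_form_list(),
# inferred from the frontend FormField usage in this commit — not taken from common.forms.
assumed_form_list = [
    {
        'field': 'temperature',
        'input_type': 'Slider',
        'label': {'input_type': 'TooltipLabel', 'label': '温度',
                  'attrs': {'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'}},
        'required': True,
        'default_value': 0.95,
        'attrs': {'min': 0.1, 'max': 1.0, 'step': 0.01, 'precision': 2}
    },
    {
        'field': 'max_tokens',
        'input_type': 'Slider',
        'label': {'input_type': 'TooltipLabel', 'label': '输出最大Tokens',
                  'attrs': {'tooltip': '指定模型可生成的最大token个数'}},
        'required': True,
        'default_value': 1024,
        'attrs': {'min': 1, 'max': 4096, 'step': 1, 'precision': 0}
    }
]

# When no saved values exist, the dialog keeps only field -> default_value pairs,
# which is what ends up stored as model_params_setting:
model_params_setting = {item['field']: item['default_value'] for item in assumed_form_list}
# -> {'temperature': 0.95, 'max_tokens': 1024}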
@@ -17,6 +17,7 @@ urlpatterns = [
    path('provider/model_form', views.Provide.ModelForm.as_view(),
         name="provider/model_form"),
    path('model', views.Model.as_view(), name='model'),
    path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(), name='model/model_params_form'),
    path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
    path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
    path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),


@@ -81,6 +81,19 @@ class Model(APIView):
        return result.success(
            ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download())

    class ModelParamsForm(APIView):
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取模型参数表单",
                             operation_id="获取模型参数表单",
                             manual_parameters=ProvideApi.ModelForm.get_request_params_api(),
                             tags=["模型"])
        @has_permissions(PermissionConstants.MODEL_READ)
        def get(self, request: Request, model_id: str):
            return result.success(
                ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).get_model_params())

    class Operate(APIView):
        authentication_classes = [TokenAuth]
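Together with the new route, this nested view exposes the parameter form over HTTP at model/<model_id>/model_params_form. A quick manual check could look like the sketch below; the host, API prefix and auth header name are assumptions about a typical deployment, since none of them appear in this diff.

# Hedged example: base URL, API prefix and the AUTHORIZATION header are assumed, not
# taken from this commit; only the model_params_form path segment comes from urls.py above.
import requests

BASE_URL = 'http://localhost:8080/api'        # assumed deployment address and prefix
TOKEN = '<token returned by the login API>'   # TokenAuth requires an authenticated user
MODEL_ID = '<uuid of an existing model>'

resp = requests.get(
    f'{BASE_URL}/model/{MODEL_ID}/model_params_form',
    headers={'AUTHORIZATION': TOKEN},
)
resp.raise_for_status()

# result.success(...) wraps the payload, so the form field list is under "data".
for field in resp.json()['data']:
    print(field['field'], field.get('default_value'))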
@@ -249,41 +249,6 @@ const putPublishApplication: (
  return put(`${prefix}/${application_id}/publish`, data, undefined, loading)
}

/**
 * 获取模型其他配置
 * @param application_id
 * @param model_id
 * @param loading
 * @returns
 */
const getModelOtherConfig: (
  application_id: string,
  model_id: string,
  ai_node_id?: string,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, model_id, ai_node_id?, loading?) => {
  if (ai_node_id) {
    // 如果 ai_node_id 不为空,则调用带有 ai_node_id 的 API
    return get(`${prefix}/${application_id}/model/${model_id}/${ai_node_id}`, undefined, loading)
  } else {
    // 如果 ai_node_id 为空,则调用不带 ai_node_id 的 API
    return get(`${prefix}/${application_id}/model/${model_id}`, undefined, loading)
  }
}

/**
 * 保存其他配置信息
 * @param application_id
 *
 */
const putModelOtherConfig: (
  application_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
  return put(`${prefix}/${application_id}/other-config`, data, undefined, loading)
}

export default {
  getAllAppilcation,
  getApplication,


@@ -303,7 +268,5 @@ export default {
  getApplicationHitTest,
  getApplicationModel,
  putPublishApplication,
  postWorkflowChatOpen,
  getModelOtherConfig,
  putModelOtherConfig
  postWorkflowChatOpen
}


@@ -51,6 +51,18 @@ const getModelCreateForm: (
  return get(`${prefix_provider}/model_form`, { provider, model_type, model_name }, loading)
}

/**
 * 获取模型参数表单
 * @param model_id 模型id
 * @param loading
 * @returns
 */
const getModelParamsForm: (
  model_id: string,
  loading?: Ref<boolean>
) => Promise<Result<Array<FormField>>> = (model_id, loading) => {
  return get(`model/${model_id}/model_params_form`, {}, loading)
}
/**
 * 获取模型类型列表
 * @param provider 供应商


@@ -159,5 +171,6 @@ export default {
  deleteModel,
  getModelById,
  getModelMetaById,
  pauseDownload
  pauseDownload,
  getModelParamsForm
}


@@ -13,6 +13,7 @@ interface ApplicationFormType {
  icon?: string | undefined
  type?: string
  work_flow?: any
  model_params_setting?: any
}
interface chatType {
  id: string
@@ -1,6 +1,11 @@
<template>
  <div style="width: 1024px">
    <DynamicsForm v-model="form_data" :render_data="damo_data" ref="dynamicsFormRef">
    <DynamicsForm
      v-model="form_data"
      :model="form_data"
      :render_data="damo_data"
      ref="dynamicsFormRef"
    >
      <template #default="scope">
        <el-form-item label="其他字段">
          <el-input v-model="scope.form_value['zha']" /> </el-form-item


@@ -14,6 +19,7 @@ import type { FormField } from '@/components/dynamics-form/type'
import DynamicsForm from '@/components/dynamics-form/index.vue'
import { ref } from 'vue'
import type { Dict } from '@/api/type/common'

const damo_data: Array<FormField> = [
  { field: 'name', input_type: 'PasswordInput', label: '用戶名', required: false },
  {


@@ -29,6 +35,20 @@ const damo_data: Array<FormField> = [
      { field: 'name3', input_type: 'TextInput', label: '用戶名3' }
    ]
  },
  {
    field: 'maxkb_tokens',
    input_type: 'Slider',
    default_value: 1,
    attrs: {
      min: 0,
      max: 10,
      step: 1,
      precision: 1,
      'show-input-controls': false,
      'show-input': true
    },
    label: { label: '温度', attrs: { tooltip: 'sss' }, input_type: 'TooltipLabel' }
  },
  {
    field: 'object_card_field',
    input_type: 'ObjectCard',
@@ -2,13 +2,18 @@
  <el-form-item
    v-loading="loading"
    :style="formItemStyle"
    :label="formfield.label"
    :prop="formfield.field"
    :key="formfield.field"
    :rules="rules"
  >
    <template #label v-if="formfield.label">
      <FormItemLabel :form-field="formfield"></FormItemLabel>
      <FormItemLabel v-if="isString(formfield.label)" :form-field="formfield"></FormItemLabel>
      <component
        v-else
        :is="formfield.label.input_type"
        :label="formfield.label.label"
        v-bind="label_attrs"
      ></component>
    </template>
    <component
      ref="componentFormRef"


@@ -58,6 +63,9 @@ const emit = defineEmits(['change'])

const loading = ref<boolean>(false)

const isString = (value: any) => {
  return typeof value === 'string'
}
const itemValue = computed({
  get: () => {
    return props.modelValue


@@ -72,7 +80,13 @@ const itemValue = computed({
  }
})
const componentFormRef = ref<any>()

const label_attrs = computed(() => {
  return props.formfield.label &&
    typeof props.formfield.label !== 'string' &&
    props.formfield.label.attrs
    ? props.formfield.label.attrs
    : {}
})
const props_info = computed(() => {
  return props.formfield.props_info ? props.formfield.props_info : {}
})


@@ -87,7 +101,11 @@ const formItemStyle = computed(() => {
 * 表单错误Msg
 */
const errMsg = computed(() => {
  return props_info.value.err_msg ? props_info.value.err_msg : props.formfield.label + '不能为空'
  return props_info.value.err_msg
    ? props_info.value.err_msg
    : isString(props.formfield.label)
      ? props.formfield.label
      : props.formfield.label.label + '不能为空'
})

/**
@@ -33,7 +33,7 @@
import type { Dict } from '@/api/type/common'
import FormItem from '@/components/dynamics-form/FormItem.vue'
import type { FormField } from '@/components/dynamics-form/type'
import { ref, onMounted, watch, type Ref } from 'vue'
import { ref, onBeforeMount, watch, type Ref } from 'vue'
import type { FormInstance } from 'element-plus'
import triggerApi from '@/api/provider'
import type Result from '@/request/Result'


@@ -152,7 +152,7 @@ const initDefaultData = (formField: FormField) => {
  }
}

onMounted(() => {
onBeforeMount(() => {
  render(props.render_data, {})
})


@@ -8,6 +8,7 @@
    require-asterisk-position="right"
    ref="ceFormRef"
    v-model="_data[index]"
    :model="_data[index]"
    :other-params="other"
    :render_data="render_data()"
    v-bind="attr"


@@ -16,6 +16,7 @@
    require-asterisk-position="right"
    ref="ceFormRef"
    v-model="_data[index]"
    :model="_data[index]"
    :other-params="other"
    :render_data="render_data()"
    v-bind="attr"
@@ -0,0 +1,40 @@
<template>
  <div class="flex align-center" style="display: inline-flex">
    <div class="flex-between mr-4">
      <span>{{ label }}</span>
    </div>
    <el-tooltip effect="dark" placement="right">
      <template #content>{{ tooltip }}</template>
      <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
    </el-tooltip>
  </div>
</template>
<script setup lang="ts">
defineProps<{
  label: string
  tooltip: string
}>()
</script>
<style lang="scss" scope>
.aiMode-param-dialog {
  padding: 8px 8px 24px 8px;

  .el-dialog__header {
    padding: 16px 16px 0 16px;
  }

  .el-dialog__body {
    padding: 16px !important;
  }

  .dialog-max-height {
    height: 550px;
  }

  .custom-slider {
    .el-input-number.is-without-controls .el-input__wrapper {
      padding: 0 !important;
    }
  }
}
</style>


@@ -0,0 +1,11 @@
<template>
  <el-slider v-bind="$attrs" show-input :show-input-controls="false" class="custom-slider" />
</template>
<script setup lang="ts"></script>
<style lang="scss" scoped>
.custom-slider {
  .el-input-number.is-without-controls .el-input__wrapper {
    padding: 0 !important;
  }
}
</style>
@@ -117,7 +117,7 @@ interface FormField {
  /**
   * 提示
   */
  label?: string
  label?: string | any
  /**
   * 是否 必填
   */


@@ -130,37 +130,6 @@ const useApplicationStore = defineStore({
          reject(error)
        })
      })
    },
    // 获取模型的 温度/max_token字段设置
    async asyncGetModelConfig(
      id: string,
      model_id: string,
      ai_node_id?: string,
      loading?: Ref<boolean>
    ) {
      return new Promise((resolve, reject) => {
        applicationApi
          .getModelOtherConfig(id, model_id, ai_node_id, loading)
          .then((data) => {
            resolve(data)
          })
          .catch((error) => {
            reject(error)
          })
      })
    },
    // 保存应用的 温度/max_token字段设置
    async asyncPostModelConfig(id: string, params?: any, loading?: Ref<boolean>) {
      return new Promise((resolve, reject) => {
        applicationApi
          .putModelOtherConfig(id, params, loading)
          .then((data) => {
            resolve(data)
          })
          .catch((error) => {
            reject(error)
          })
      })
    }
  }
})
@@ -328,7 +328,7 @@
        </el-col>
      </el-row>

      <AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" :id="id" @refresh="refreshForm" />
      <AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshForm" />
      <ParamSettingDialog ref="ParamSettingDialogRef" @refresh="refreshParam" />
      <AddDatasetDialog
        ref="AddDatasetDialogRef"


@@ -348,8 +348,8 @@
  </LayoutContainer>
</template>
<script setup lang="ts">
import { reactive, ref, watch, onMounted } from 'vue'
import { useRouter, useRoute } from 'vue-router'
import { reactive, ref, onMounted } from 'vue'
import { useRoute } from 'vue-router'
import { groupBy } from 'lodash'
import AIModeParamSettingDialog from './component/AIModeParamSettingDialog.vue'
import ParamSettingDialog from './component/ParamSettingDialog.vue'


@@ -408,6 +408,7 @@ const applicationForm = ref<ApplicationFormType>({
  model_setting: {
    prompt: defaultPrompt
  },
  model_params_setting: {},
  problem_optimization: false,
  type: 'SIMPLE'
})


@@ -456,9 +457,7 @@ const openAIParamSettingDialog = () => {
    MsgSuccess(t('请选择AI 模型'))
    return
  }
  application.asyncGetModelConfig(id, model_id, '', loading).then((res: any) => {
    AIModeParamSettingDialogRef.value?.open(res.data)
  })
  AIModeParamSettingDialogRef.value?.open(model_id, applicationForm.value.model_params_setting)
}

const openParamSettingDialog = () => {


@@ -470,11 +469,7 @@ function refreshParam(data: any) {
}

function refreshForm(data: any) {
  // data是一个对象 把data的值赋值给model_setting 但是是合并
  applicationForm.value.model_setting = {
    ...applicationForm.value.model_setting,
    ...data
  }
  applicationForm.value.model_params_setting = data
}

const openCreateModel = (provider?: Provider) => {
@@ -7,31 +7,16 @@
    style="width: 550px"
    append-to-body
  >
    <el-form label-position="top" ref="paramFormRef" :model="form">
      <el-form-item v-for="(item, key) in form" :key="key">
        <template #label>
          <div class="flex align-center">
            <div class="flex-between mr-4">
              <span>{{ item.label }}</span>
            </div>
            <el-tooltip effect="dark" placement="right">
              <template #content>{{ item.tooltip }}</template>
              <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
            </el-tooltip>
          </div>
        </template>
        <el-slider
          v-model="item.value"
          show-input
          :show-input-controls="false"
          :min="item.min"
          :max="item.max"
          :precision="item.precision || 0"
          :step="item.step || 1"
          class="custom-slider"
        />
      </el-form-item>
    </el-form>
    <DynamicsForm
      v-model="form_data"
      :model="form_data"
      label-position="top"
      require-asterisk-position="right"
      :render_data="model_form_field"
      ref="dynamicsFormRef"
    >
    </DynamicsForm>

    <template #footer>
      <span class="dialog-footer p-16">
        <el-button @click.prevent="dialogVisible = false">


@@ -46,82 +31,40 @@
</template>

<script setup lang="ts">
import { ref, reactive, watch } from 'vue'
import { cloneDeep, set } from 'lodash'
import { ref } from 'vue'

import type { FormInstance } from 'element-plus'
import useStore from '@/stores'

const { application } = useStore()

import type { FormField } from '@/components/dynamics-form/type'
import modelAPi from '@/api/model'
import DynamicsForm from '@/components/dynamics-form/index.vue'
import { keys } from 'lodash'
const model_form_field = ref<Array<FormField>>([])
const emit = defineEmits(['refresh'])

const paramFormRef = ref<FormInstance>()
const form = reactive<Form>({})
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
const form_data = ref<any>({})
const dialogVisible = ref(false)
const loading = ref(false)
const props = defineProps<{
  id: string
  nodeId?: string
}>()
const resetForm = () => {
  // 清空 form 对象,等待新的数据
  Object.keys(form).forEach((key) => delete form[key])
}

interface Form {
  [key: string]: FormField
}

interface FormField {
  value: any
  min?: number
  max?: number
  step?: number
  label?: string
  precision?: number
  tooltip?: string
}

const open = (data: any) => {
  const newData = cloneDeep(data)
  Object.keys(form).forEach((key) => {
    delete form[key]
  })
  Object.keys(newData).forEach((key) => {
    set(form, key, newData[key])
const open = (model_id: string, model_setting_data?: any) => {
  modelAPi.getModelParamsForm(model_id, loading).then((ok) => {
    model_form_field.value = ok.data
    model_setting_data =
      model_setting_data && keys(model_setting_data).length > 0
        ? model_setting_data
        : ok.data
            .map((item) => ({ [item.field]: item.default_value }))
            .reduce((x, y) => ({ ...x, ...y }), {})
    // 渲染动态表单
    dynamicsFormRef.value?.render(model_form_field.value, model_setting_data)
  })
  dialogVisible.value = true
}

const submit = async () => {
  if (paramFormRef.value) {
    await paramFormRef.value.validate((valid, fields) => {
      if (valid) {
        const data = Object.keys(form).reduce(
          (acc, key) => {
            acc[key] = form[key].value
            return acc
          },
          {} as Record<string, any>
        )
        if (props.nodeId) {
          data.node_id = props.nodeId
        }
        application.asyncPostModelConfig(props.id, data, loading).then(() => {
          emit('refresh', data)
          dialogVisible.value = false
        })
      }
    })
  }
  emit('refresh', form_data.value)
  dialogVisible.value = false
}

watch(dialogVisible, (bool) => {
  if (!bool) {
    resetForm()
  }
})

defineExpose({ open })
</script>
@@ -196,13 +196,7 @@
      @change="openCreateModel($event)"
    ></CreateModelDialog>
    <SelectProviderDialog ref="selectProviderRef" @change="openCreateModel($event)" />

    <AIModeParamSettingDialog
      ref="AIModeParamSettingDialogRef"
      :id="id"
      :node-id="props.nodeModel.id"
      @refresh="refreshParam"
    />
    <AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
  </NodeContainer>
</template>
<script setup lang="ts">


@@ -319,15 +313,12 @@ const openCreateModel = (provider?: Provider) => {

const openAIParamSettingDialog = (modelId: string) => {
  if (modelId) {
    application.asyncGetModelConfig(id, modelId, props.nodeModel.id).then((res: any) => {
      AIModeParamSettingDialogRef.value?.open(res.data)
    })
    AIModeParamSettingDialogRef.value?.open(modelId)
  }
}

function refreshParam(data: any) {
  chat_data.value.temperature = data.temperature
  chat_data.value.max_tokens = data.max_tokens
  set(props.nodeModel.properties.node_data, 'model_params_setting', data)
}

onMounted(() => {


@@ -178,6 +178,7 @@
      @change="openCreateModel($event)"
    ></CreateModelDialog>
    <SelectProviderDialog ref="selectProviderRef" @change="openCreateModel($event)" />
    <AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
  </NodeContainer>
</template>
<script setup lang="ts">


@@ -186,6 +187,7 @@ import { app } from '@/main'
import NodeContainer from '@/workflow/common/NodeContainer.vue'
import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue'
import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue'
import AIModeParamSettingDialog from '@/views/application/component/AIModeParamSettingDialog.vue'
import type { FormInstance } from 'element-plus'
import { ref, computed, onMounted } from 'vue'
import applicationApi from '@/api/application'


@@ -229,6 +231,9 @@ const form = {
  dialogue_number: 1,
  is_result: false
}
function refreshParam(data: any) {
  set(props.nodeModel.properties.node_data, 'model_params_setting', data)
}

const form_data = computed({
  get: () => {