refactor: update model params

This commit is contained in:
wxg0103 2024-08-19 10:38:55 +08:00
parent 9d2747cc4b
commit 35f3ff779b
19 changed files with 568 additions and 276 deletions

View File

@ -23,6 +23,9 @@ class ChatNodeSerializer(serializers.Serializer):
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度"))
max_tokens = serializers.IntegerField(required=False, allow_null=True,
error_messages=ErrMessage.integer("最大token数"))
class IChatNode(INode):

View File

@ -66,7 +66,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
class BaseChatNode(IChatNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
**kwargs) -> NodeResult:
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
kwargs = {k: v for k, v in kwargs.items() if k in ['temperature', 'max_tokens']}
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**kwargs
)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)

View File

@ -547,6 +547,7 @@ class ApplicationSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("模型id"))
ai_node_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("AI节点id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -731,31 +732,61 @@ class ApplicationSerializer(serializers.Serializer):
def get_other_file_list(self):
temperature = None
max_tokens = None
application = Application.objects.filter(id=self.initial_data.get("application_id")).first()
application_id = self.initial_data.get("application_id")
ai_node_id = self.initial_data.get("ai_node_id")
model_id = self.initial_data.get("model_id")
application = Application.objects.filter(id=application_id).first()
if application:
setting_dict = application.model_setting
temperature = setting_dict.get("temperature")
max_tokens = setting_dict.get("max_tokens")
model = Model.objects.filter(id=self.initial_data.get("model_id")).first()
if application.type == 'SIMPLE':
setting_dict = application.model_setting
temperature = setting_dict.get("temperature")
max_tokens = setting_dict.get("max_tokens")
elif application.type == 'WORK_FLOW':
work_flow = application.work_flow
api_node = next((node for node in work_flow.get('nodes', []) if node.get('id') == ai_node_id), None)
if api_node:
node_data = api_node.get('properties', {}).get('node_data', {})
temperature = node_data.get("temperature")
max_tokens = node_data.get("max_tokens")
model = Model.objects.filter(id=model_id).first()
if model:
res = ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
model.model_name).get_other_fields(
model.model_name)
if temperature and res.get('temperature'):
if temperature is not None and 'temperature' in res:
res['temperature']['value'] = temperature
if max_tokens and res.get('max_tokens'):
if max_tokens is not None and 'max_tokens' in res:
res['max_tokens']['value'] = max_tokens
return res
def save_other_config(self, data):
application = Application.objects.filter(id=self.initial_data.get("application_id")).first()
if application:
if not application:
return
if application.type == 'SIMPLE':
setting_dict = application.model_setting
for key in ['max_tokens', 'temperature']:
if key in data:
setting_dict[key] = data[key]
application.model_setting = setting_dict
application.save()
elif application.type == 'WORK_FLOW':
work_flow = application.work_flow
ai_node_id = data.get("node_id")
for api_node in work_flow.get('nodes', []):
if api_node.get('id') == ai_node_id:
node_data = api_node.get('properties', {}).get('node_data', {})
for key in ['max_tokens', 'temperature']:
if key in data:
node_data[key] = data[key]
api_node['properties']['node_data'] = node_data
break
application.work_flow = work_flow
application.save()
class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:

View File

@ -20,6 +20,8 @@ urlpatterns = [
views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()),
path('application/<str:application_id>/model', views.Application.Model.as_view()),
path('application/<str:application_id>/model/<str:model_id>', views.Application.Model.Operate.as_view()),
path('application/<str:application_id>/model/<str:model_id>/<str:ai_node_id>',
views.Application.Model.Operate.as_view()),
path('application/<str:application_id>/other-config', views.Application.Model.OtherConfig.as_view()),
path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()),
path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),

View File

@ -193,10 +193,11 @@ class Application(APIView):
@swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
operation_id="获取应用参数设置其他字段",
tags=["应用/会话"])
def get(self, request: Request, application_id: str, model_id: str):
def get(self, request: Request, application_id: str, model_id: str, ai_node_id=None):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'model_id': model_id}).get_other_file_list())
data={'application_id': application_id, 'model_id': model_id,
'ai_node_id': ai_node_id}).get_other_file_list())
class OtherConfig(APIView):
authentication_classes = [TokenAuth]
@ -207,7 +208,8 @@ class Application(APIView):
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id}).save_other_config(request.data))
data={'application_id': application_id}).save_other_config(
request.data))
class Profile(APIView):
authentication_classes = [TokenAuth]

View File

@ -17,6 +17,7 @@ from setting.models_provider.impl.ollama_model_provider.ollama_model_provider im
from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider
from setting.models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
from setting.models_provider.impl.vllm_model_provider.vllm_model_provider import VllmModelProvider
from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \
VolcanicEngineModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
@ -42,3 +43,4 @@ class ModelProvideConstants(Enum):
model_aws_bedrock_provider = BedrockModelProvider()
model_local_provider = LocalModelProvider()
model_xinference_provider = XinferenceModelProvider()
model_vllm_provider = VllmModelProvider()

View File

@ -0,0 +1 @@
# coding=utf-8

View File

@ -0,0 +1,67 @@
# coding=utf-8
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class VLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for models served through a vLLM OpenAI-compatible endpoint.

    Declares the two form fields (``api_base``/``api_key``) and implements the
    provider-side validation, encryption and UI-parameter hooks.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        # NOTE(review): `raise_exception` is accepted but never consulted — every
        # failure below raises unconditionally; confirm against sibling credentials.
        # Reject model types the provider does not advertise.
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        try:
            # Probe the vLLM server's model list to verify the API base is reachable.
            model_list = provider.get_base_model_list(model_credential.get('api_base'))
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
        exist = provider.get_model_info_by_name(model_list, model_name)
        if len(exist) == 0:
            raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型")
        model = provider.get_model(model_type, model_name, model_credential)
        try:
            # Smoke-test the credential with a one-message chat call.
            res = model.invoke([HumanMessage(content='你好')])
            print(res)  # NOTE(review): debug print left in — consider logging instead
        except Exception as e:
            # NOTE(review): invocation failure is swallowed and the credential is still
            # reported valid — presumably an intentional best-effort check; confirm.
            print(e)
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        # Mask the api_key before the credential is echoed back to the front end.
        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        # NOTE(review): requires a 'model' key although the form only declares
        # api_base/api_key — presumably injected by the caller; verify upstream.
        for key in ['api_key', 'model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} 字段为必填字段')
        self.api_key = model_info.get('api_key')
        return self

    # Form fields rendered in the credential dialog.
    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def get_other_fields(self, model_name):
        """Return slider metadata for the tunable generation parameters (temperature / max_tokens)."""
        return {
            'temperature': {
                'value': 0.7,
                'min': 0.1,
                'max': 1,
                'step': 0.01,
                'label': '温度',
                'precision': 2,
                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
            },
            'max_tokens': {
                'value': 800,
                'min': 1,
                'max': 4096,
                'step': 1,
                'label': '输出最大Tokens',
                'precision': 0,
                'tooltip': '指定模型可生成的最大token个数'
            }
        }

View File

@ -0,0 +1,22 @@
<?xml version="1.0" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 20010904//EN"
"http://www.w3.org/TR/2001/REC-SVG-20010904/DTD/svg10.dtd">
<svg version="1.0" xmlns="http://www.w3.org/2000/svg"
width="100%" height="100%" viewBox="0 0 3000.000000 860.000000"
preserveAspectRatio="xMidYMid meet">
<g transform="translate(0.000000,860.000000) scale(0.100000,-0.100000)"
fill="#000000" stroke="none">
<path d="M10110 4275 l0 -2905 858 0 c471 0 1299 -3 1840 -7 l982 -6 0 346 0
347 -1448 0 -1449 0 -6 428 c-4 235 -7 1389 -7 2565 l0 2137 -385 0 -385 0 0
-2905z"/>
<path d="M14550 4275 l0 -2905 858 0 c471 0 1299 -3 1840 -7 l982 -6 0 346 0
347 -1448 0 -1449 0 -6 428 c-4 235 -7 1389 -7 2565 l0 2137 -385 0 -385 0 0
-2905z"/>
<path d="M19000 4275 l0 -2905 235 -2 c129 -2 292 -3 363 -3 l127 0 5 2491 5
2491 796 -1691 795 -1691 234 0 233 0 799 1685 c439 926 803 1689 808 1694 7
7 10 -849 10 -2487 l0 -2497 385 0 385 0 0 2910 0 2910 -538 0 -537 0 -744
-1607 c-409 -885 -747 -1611 -750 -1615 -4 -4 -354 718 -779 1605 l-771 1612
-531 3 -530 2 0 -2905z"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.0 KiB

View File

@ -0,0 +1,38 @@
# coding=utf-8
from typing import List, Dict
from urllib.parse import urlparse, ParseResult
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
def get_base_url(url: str):
    """Normalize *url*: keep only scheme, netloc and path, then drop a single trailing slash."""
    parsed = urlparse(url)
    stripped = ParseResult(
        scheme=parsed.scheme,
        netloc=parsed.netloc,
        path=parsed.path,
        params='',
        query='',
        fragment='',
    ).geturl()
    if stripped.endswith("/"):
        return stripped[:-1]
    return stripped
class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
    """Chat model backed by a vLLM server exposing the OpenAI-compatible API."""

    @staticmethod
    def is_cache_model():
        # A fresh instance is built per request; never cache it.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a streaming VllmChatModel from the stored credential, forwarding
        only the generation parameters that were explicitly provided (non-None)."""
        optional_params = {
            key: model_kwargs[key]
            for key in ('max_tokens', 'temperature')
            if model_kwargs.get(key) is not None
        }
        return VllmChatModel(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            streaming=True,
            stream_usage=True,
            **optional_params,
        )

View File

@ -0,0 +1,59 @@
# coding=utf-8
import os
from urllib.parse import urlparse, ParseResult
import requests
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
ModelInfoManage
from setting.models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential
from setting.models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
from smartdoc.conf import PROJECT_DIR
# Shared credential-form instance reused by every vLLM model entry below.
v_llm_model_credential = VLLMModelCredential()

# Built-in models offered out of the box for the vLLM provider.
model_info_list = [
    ModelInfo('facebook/opt-125m', 'Facebook的125M参数模型', ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel),
    ModelInfo('BAAI/Aquila-7B', 'BAAI的7B参数模型', ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel),
    # NOTE(review): description says 13B but AquilaChat-7B is a 7B model — confirm and fix the label.
    ModelInfo('BAAI/AquilaChat-7B', 'BAAI的13B参数模型', ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel),
]

# Registry handed to the provider; facebook/opt-125m doubles as the default model.
model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
    ModelInfo(
        'facebook/opt-125m',
        'Facebook的125M参数模型',
        ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel))
                     .build())
def get_base_url(url: str):
    """Strip params/query/fragment from *url* and remove one trailing slash, if any."""
    parts = urlparse(url)
    cleaned = ParseResult(scheme=parts.scheme, netloc=parts.netloc, path=parts.path,
                          params='', query='', fragment='').geturl()
    return cleaned[:-1] if cleaned.endswith("/") else cleaned
class VllmModelProvider(IModelProvider):
    """Provider wiring for models hosted on a vLLM OpenAI-compatible server."""

    def get_model_info_manage(self):
        # Hand back the module-level registry built at import time.
        return model_info_manage

    def get_model_provide_info(self):
        icon_path = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
                                 'vllm_model_provider', 'icon', 'vllm_icon.svg')
        return ModelProvideInfo(provider='model_vllm_provider', name='Vllm',
                                icon=get_file_content(icon_path))

    @staticmethod
    def get_base_model_list(api_base):
        """Query ``<api_base>/v1/models`` and return the response's 'data' payload."""
        base_url = get_base_url(api_base)
        if not base_url.endswith('/v1'):
            base_url = base_url + '/v1'
        response = requests.request(method="GET", url=f"{base_url}/models", timeout=5)
        response.raise_for_status()
        return response.json().get('data')

    @staticmethod
    def get_model_info_by_name(model_list, model_name):
        """Return the entries of *model_list* whose 'id' equals *model_name* ([] when absent)."""
        if model_list is None:
            return []
        return [entry for entry in model_list if entry.get('id') == model_name]

View File

@ -66,9 +66,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
},
'max_tokens': {
'value': 2048,
'value': 1024,
'min': 2,
'max': 1024,
'max': 2048,
'step': 1,
'label': '输出最大Tokens',
'precision': 0,

View File

@ -26,9 +26,9 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs:
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs:
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return XFChatSparkLLM(
spark_app_id=model_credential.get('spark_app_id'),

View File

@ -34,9 +34,9 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs:
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs:
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
zhipuai_chat = ZhipuChatModel(

View File

@ -1,8 +1,8 @@
import {Result} from '@/request/Result'
import {get, post, postStream, del, put} from '@/request/index'
import type {pageRequest} from '@/api/type/common'
import type {ApplicationFormType} from '@/api/type/application'
import {type Ref} from 'vue'
import { Result } from '@/request/Result'
import { get, post, postStream, del, put } from '@/request/index'
import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application'
import { type Ref } from 'vue'
const prefix = '/application'
@ -11,7 +11,7 @@ const prefix = '/application'
* @param
*/
const getAllAppilcation: () => Promise<Result<any[]>> = () => {
return get(`${prefix}`)
return get(`${prefix}`)
}
/**
@ -25,11 +25,11 @@ const getAllAppilcation: () => Promise<Result<any[]>> = () => {
}
*/
const getApplication: (
page: pageRequest,
param: any,
loading?: Ref<boolean>
page: pageRequest,
param: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (page, param, loading) => {
return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading)
return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading)
}
/**
@ -37,10 +37,10 @@ const getApplication: (
* @param
*/
const postApplication: (
data: ApplicationFormType,
loading?: Ref<boolean>
data: ApplicationFormType,
loading?: Ref<boolean>
) => Promise<Result<any>> = (data, loading) => {
return post(`${prefix}`, data, undefined, loading)
return post(`${prefix}`, data, undefined, loading)
}
/**
@ -48,11 +48,11 @@ const postApplication: (
* @param
*/
const putApplication: (
application_id: String,
data: ApplicationFormType,
loading?: Ref<boolean>
application_id: String,
data: ApplicationFormType,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
return put(`${prefix}/${application_id}`, data, undefined, loading)
return put(`${prefix}/${application_id}`, data, undefined, loading)
}
/**
@ -60,10 +60,10 @@ const putApplication: (
* @param application_id
*/
const delApplication: (
application_id: String,
loading?: Ref<boolean>
application_id: String,
loading?: Ref<boolean>
) => Promise<Result<boolean>> = (application_id, loading) => {
return del(`${prefix}/${application_id}`, undefined, {}, loading)
return del(`${prefix}/${application_id}`, undefined, {}, loading)
}
/**
@ -71,10 +71,10 @@ const delApplication: (
* @param application_id
*/
const getApplicationDetail: (
application_id: string,
loading?: Ref<boolean>
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, loading) => {
return get(`${prefix}/${application_id}`, undefined, loading)
return get(`${prefix}/${application_id}`, undefined, loading)
}
/**
@ -82,10 +82,10 @@ const getApplicationDetail: (
* @param application_id
*/
const getApplicationDataset: (
application_id: string,
loading?: Ref<boolean>
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/list_dataset`, undefined, loading)
return get(`${prefix}/${application_id}/list_dataset`, undefined, loading)
}
/**
@ -93,10 +93,10 @@ const getApplicationDataset: (
* @param application_id
*/
const getAccessToken: (application_id: string, loading?: Ref<boolean>) => Promise<Result<any>> = (
application_id,
loading
application_id,
loading
) => {
return get(`${prefix}/${application_id}/access_token`, undefined, loading)
return get(`${prefix}/${application_id}/access_token`, undefined, loading)
}
/**
@ -107,11 +107,11 @@ const getAccessToken: (application_id: string, loading?: Ref<boolean>) => Promis
* }
*/
const putAccessToken: (
application_id: string,
data: any,
loading?: Ref<boolean>
application_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
return put(`${prefix}/${application_id}/access_token`, data, undefined, loading)
return put(`${prefix}/${application_id}/access_token`, data, undefined, loading)
}
/**
@ -122,10 +122,10 @@ const putAccessToken: (
}
*/
const postAppAuthentication: (access_token: string, loading?: Ref<boolean>) => Promise<any> = (
access_token,
loading
access_token,
loading
) => {
return post(`${prefix}/authentication`, {access_token}, undefined, loading)
return post(`${prefix}/authentication`, { access_token }, undefined, loading)
}
/**
@ -136,7 +136,7 @@ const postAppAuthentication: (access_token: string, loading?: Ref<boolean>) => P
}
*/
const getAppProfile: (loading?: Ref<boolean>) => Promise<any> = (loading) => {
return get(`${prefix}/profile`, undefined, loading)
return get(`${prefix}/profile`, undefined, loading)
}
/**
@ -146,7 +146,7 @@ const getAppProfile: (loading?: Ref<boolean>) => Promise<any> = (loading) => {
}
*/
const postChatOpen: (data: ApplicationFormType) => Promise<Result<any>> = (data) => {
return post(`${prefix}/chat/open`, data)
return post(`${prefix}/chat/open`, data)
}
/**
@ -156,7 +156,7 @@ const postChatOpen: (data: ApplicationFormType) => Promise<Result<any>> = (data)
}
*/
const postWorkflowChatOpen: (data: ApplicationFormType) => Promise<Result<any>> = (data) => {
return post(`${prefix}/chat_workflow/open`, data)
return post(`${prefix}/chat_workflow/open`, data)
}
/**
@ -171,7 +171,7 @@ const postWorkflowChatOpen: (data: ApplicationFormType) => Promise<Result<any>>
}
*/
const getChatOpen: (application_id: String) => Promise<Result<any>> = (application_id) => {
return get(`${prefix}/${application_id}/chat/open`)
return get(`${prefix}/${application_id}/chat/open`)
}
/**
*
@ -180,7 +180,7 @@ const getChatOpen: (application_id: String) => Promise<Result<any>> = (applicati
* data
*/
const postChatMessage: (chat_id: string, data: any) => Promise<any> = (chat_id, data) => {
return postStream(`/api${prefix}/chat_message/${chat_id}`, data)
return postStream(`/api${prefix}/chat_message/${chat_id}`, data)
}
/**
@ -192,20 +192,20 @@ const postChatMessage: (chat_id: string, data: any) => Promise<any> = (chat_id,
}
*/
const putChatVote: (
application_id: string,
chat_id: string,
chat_record_id: string,
vote_status: string,
loading?: Ref<boolean>
application_id: string,
chat_id: string,
chat_record_id: string,
vote_status: string,
loading?: Ref<boolean>
) => Promise<any> = (application_id, chat_id, chat_record_id, vote_status, loading) => {
return put(
`${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`,
{
vote_status
},
undefined,
loading
)
return put(
`${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`,
{
vote_status
},
undefined,
loading
)
}
/**
@ -216,11 +216,11 @@ const putChatVote: (
* @returns
*/
const getApplicationHitTest: (
application_id: string,
data: any,
loading?: Ref<boolean>
application_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, data, loading) => {
return get(`${prefix}/${application_id}/hit_test`, data, loading)
return get(`${prefix}/${application_id}/hit_test`, data, loading)
}
/**
@ -231,10 +231,10 @@ const getApplicationHitTest: (
* @returns
*/
const getApplicationModel: (
application_id: string,
loading?: Ref<boolean>
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/model`, loading)
return get(`${prefix}/${application_id}/model`, loading)
}
/**
@ -242,11 +242,11 @@ const getApplicationModel: (
* @param
*/
const putPublishApplication: (
application_id: String,
data: ApplicationFormType,
loading?: Ref<boolean>
application_id: String,
data: ApplicationFormType,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
return put(`${prefix}/${application_id}/publish`, data, undefined, loading)
return put(`${prefix}/${application_id}/publish`, data, undefined, loading)
}
/**
@ -257,11 +257,18 @@ const putPublishApplication: (
* @returns
*/
const getModelOtherConfig: (
application_id: string,
model_id: string,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, model_id, loading) => {
application_id: string,
model_id: string,
ai_node_id?: string,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, model_id, ai_node_id?, loading?) => {
if (ai_node_id) {
// 如果 ai_node_id 不为空,则调用带有 ai_node_id 的 API
return get(`${prefix}/${application_id}/model/${model_id}/${ai_node_id}`, undefined, loading)
} else {
// 如果 ai_node_id 为空,则调用不带 ai_node_id 的 API
return get(`${prefix}/${application_id}/model/${model_id}`, undefined, loading)
}
}
/**
@ -270,33 +277,33 @@ const getModelOtherConfig: (
*
*/
const putModelOtherConfig: (
application_id: string,
data: any,
loading?: Ref<boolean>
application_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, data, loading) => {
return put(`${prefix}/${application_id}/other-config`, data, undefined, loading)
return put(`${prefix}/${application_id}/other-config`, data, undefined, loading)
}
export default {
getAllAppilcation,
getApplication,
postApplication,
putApplication,
postChatOpen,
getChatOpen,
postChatMessage,
delApplication,
getApplicationDetail,
getApplicationDataset,
getAccessToken,
putAccessToken,
postAppAuthentication,
getAppProfile,
putChatVote,
getApplicationHitTest,
getApplicationModel,
putPublishApplication,
postWorkflowChatOpen,
getModelOtherConfig,
putModelOtherConfig
getAllAppilcation,
getApplication,
postApplication,
putApplication,
postChatOpen,
getChatOpen,
postChatMessage,
delApplication,
getApplicationDetail,
getApplicationDataset,
getAccessToken,
putAccessToken,
postAppAuthentication,
getAppProfile,
putChatVote,
getApplicationHitTest,
getApplicationModel,
putPublishApplication,
postWorkflowChatOpen,
getModelOtherConfig,
putModelOtherConfig
}

View File

@ -1,163 +1,168 @@
import {defineStore} from 'pinia'
import { defineStore } from 'pinia'
import applicationApi from '@/api/application'
import applicationXpackApi from '@/api/application-xpack'
import {type Ref, type UnwrapRef} from 'vue'
import { type Ref, type UnwrapRef } from 'vue'
import useUserStore from './user'
import type {ApplicationFormType} from "@/api/type/application";
import type { ApplicationFormType } from '@/api/type/application'
const useApplicationStore = defineStore({
id: 'application',
state: () => ({
location: `${window.location.origin}/ui/chat/`
}),
actions: {
async asyncGetAllApplication() {
return new Promise((resolve, reject) => {
applicationApi
.getAllAppilcation()
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
id: 'application',
state: () => ({
location: `${window.location.origin}/ui/chat/`
}),
actions: {
async asyncGetAllApplication() {
return new Promise((resolve, reject) => {
applicationApi
.getAllAppilcation()
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetApplicationDetail(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getApplicationDetail(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetApplicationDetail(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getApplicationDetail(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetApplicationDataset(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getApplicationDataset(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetApplicationDataset(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getApplicationDataset(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncGetAccessToken(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
const user = useUserStore()
if (user.isEnterprise()) {
applicationXpackApi
.getAccessToken(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
} else {
applicationApi
.getAccessToken(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
}
async asyncGetAccessToken(id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
const user = useUserStore()
if (user.isEnterprise()) {
applicationXpackApi
.getAccessToken(id, loading)
.then((data) => {
resolve(data)
})
},
.catch((error) => {
reject(error)
})
} else {
applicationApi
.getAccessToken(id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
}
})
},
async asyncGetAppProfile(loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
const user = useUserStore()
if (user.isEnterprise()) {
applicationXpackApi
.getAppXpackProfile(loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
} else {
applicationApi
.getAppProfile(loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
}
async asyncGetAppProfile(loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
const user = useUserStore()
if (user.isEnterprise()) {
applicationXpackApi
.getAppXpackProfile(loading)
.then((data) => {
resolve(data)
})
},
.catch((error) => {
reject(error)
})
} else {
applicationApi
.getAppProfile(loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
}
})
},
async asyncAppAuthentication(token: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.postAppAuthentication(token, loading)
.then((res) => {
localStorage.setItem('accessToken', res.data)
sessionStorage.setItem('accessToken', res.data)
resolve(res)
})
.catch((error) => {
reject(error)
})
})
},
async refreshAccessToken(token: string) {
this.asyncAppAuthentication(token)
},
// 修改应用
async asyncPutApplication(id: string, data: any, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.putApplication(id, data, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
// 获取模型的 温度/max_token字段设置
async asyncGetModelConfig(id: string, model_id: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.getModelOtherConfig(id, model_id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
// 保存应用的 温度/max_token字段设置
async asyncPostModelConfig(id: string, params: any, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.putModelOtherConfig(id, params, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
async asyncAppAuthentication(token: string, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.postAppAuthentication(token, loading)
.then((res) => {
localStorage.setItem('accessToken', res.data)
sessionStorage.setItem('accessToken', res.data)
resolve(res)
})
.catch((error) => {
reject(error)
})
})
},
async refreshAccessToken(token: string) {
this.asyncAppAuthentication(token)
},
// 修改应用
async asyncPutApplication(id: string, data: any, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.putApplication(id, data, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
// 获取模型的 温度/max_token字段设置
async asyncGetModelConfig(
id: string,
model_id: string,
ai_node_id?: string,
loading?: Ref<boolean>
) {
return new Promise((resolve, reject) => {
applicationApi
.getModelOtherConfig(id, model_id, ai_node_id, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
},
// 保存应用的 温度/max_token字段设置
async asyncPostModelConfig(id: string, params?: any, loading?: Ref<boolean>) {
return new Promise((resolve, reject) => {
applicationApi
.putModelOtherConfig(id, params, loading)
.then((data) => {
resolve(data)
})
.catch((error) => {
reject(error)
})
})
}
}
})
export default useApplicationStore

View File

@ -455,7 +455,7 @@ const openAIParamSettingDialog = () => {
MsgSuccess(t('请选择AI 模型'))
return
}
application.asyncGetModelConfig(id, model_id, loading).then((res: any) => {
application.asyncGetModelConfig(id, model_id, '', loading).then((res: any) => {
AIModeParamSettingDialogRef.value?.open(res.data)
})
}

View File

@ -60,7 +60,8 @@ const form = reactive<Form>({})
const dialogVisible = ref(false)
const loading = ref(false)
const props = defineProps<{
id: any
id: string
nodeId?: string
}>()
const resetForm = () => {
// form
@ -103,8 +104,11 @@ const submit = async () => {
},
{} as Record<string, any>
)
if (props.nodeId) {
data.node_id = props.nodeId
}
application.asyncPostModelConfig(props.id, data, loading).then(() => {
emit('refresh', form)
emit('refresh', data)
dialogVisible.value = false
})
}

View File

@ -14,6 +14,7 @@
class="mb-24"
label-width="auto"
ref="aiChatNodeFormRef"
hide-required-asterisk
>
<el-form-item
label="AI 模型"
@ -24,6 +25,21 @@
trigger: 'change'
}"
>
<template #label>
<div class="flex-between">
<div>
<span>AI 模型<span class="danger">*</span></span>
</div>
<el-button
type="primary"
link
@click="openAIParamSettingDialog(chat_data.model_id)"
@refresh="refreshParam"
>
{{ $t('views.application.applicationForm.form.paramSetting') }}
</el-button>
</div>
</template>
<el-select
@wheel="wheel"
@keydown="isKeyDown = true"
@ -57,9 +73,9 @@
>公用
</el-tag>
</div>
<el-icon class="check-icon" v-if="item.id === chat_data.model_id"
><Check
/></el-icon>
<el-icon class="check-icon" v-if="item.id === chat_data.model_id">
<Check />
</el-icon>
</el-option>
<!-- 不可用 -->
<el-option
@ -78,21 +94,24 @@
<span>{{ item.name }}</span>
<span class="danger">不可用</span>
</div>
<el-icon class="check-icon" v-if="item.id === chat_data.model_id"
><Check
/></el-icon>
<el-icon class="check-icon" v-if="item.id === chat_data.model_id">
<Check />
</el-icon>
</el-option>
</el-option-group>
<template #footer>
<div class="w-full text-left cursor" @click="openCreateModel()">
<el-button type="primary" link>
<el-icon class="mr-4"><Plus /></el-icon>
<el-icon class="mr-4">
<Plus />
</el-icon>
添加模型
</el-button>
</div>
</template>
</el-select>
</el-form-item>
<el-form-item label="角色设定">
<el-input
v-model="chat_data.system"
@ -109,8 +128,8 @@
</div>
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
<template #content
>通过调整提示词内容可以引导大模型聊天方向该提示词会被固定在上下文的开头可以使用变量</template
>
>通过调整提示词内容可以引导大模型聊天方向该提示词会被固定在上下文的开头可以使用变量
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
@ -163,10 +182,10 @@
</el-card>
<!-- 回复内容弹出层 -->
<el-dialog v-model="dialogVisible" title="提示词" append-to-body>
<MdEditor v-model="cloneContent" :preview="false" :toolbars="[]" :footers="[]"> </MdEditor>
<MdEditor v-model="cloneContent" :preview="false" :toolbars="[]" :footers="[]"></MdEditor>
<template #footer>
<div class="dialog-footer mt-24">
<el-button type="primary" @click="submitDialog"> 确认 </el-button>
<el-button type="primary" @click="submitDialog"> 确认</el-button>
</div>
</template>
</el-dialog>
@ -177,6 +196,13 @@
@change="openCreateModel($event)"
></CreateModelDialog>
<SelectProviderDialog ref="selectProviderRef" @change="openCreateModel($event)" />
<AIModeParamSettingDialog
ref="AIModeParamSettingDialogRef"
:id="id"
:node-id="props.nodeModel.id"
@refresh="refreshParam"
/>
</NodeContainer>
</template>
<script setup lang="ts">
@ -192,8 +218,9 @@ import useStore from '@/stores'
import { relatedObject } from '@/utils/utils'
import type { Provider } from '@/api/type/model'
import { isLastNode } from '@/workflow/common/data'
import AIModeParamSettingDialog from '@/views/application/component/AIModeParamSettingDialog.vue'
const { model } = useStore()
const { model, application } = useStore()
const isKeyDown = ref(false)
const wheel = (e: any) => {
if (isKeyDown.value) {
@ -206,14 +233,17 @@ const wheel = (e: any) => {
const dialogVisible = ref(false)
const cloneContent = ref('')
const footers: any = [null, '=', 0]
// Open the prompt-editing dialog, seeding the editor with a copy of the
// current prompt so edits are only committed on explicit confirmation.
function openDialog() {
  cloneContent.value = chat_data.value.prompt
  dialogVisible.value = true
}
// Commit the edited prompt back onto the workflow node's data and close
// the dialog.
function submitDialog() {
  set(props.nodeModel.properties.node_data, 'prompt', cloneContent.value)
  dialogVisible.value = false
}
const {
params: { id }
} = app.config.globalProperties.$route as any
@ -228,7 +258,9 @@ const form = {
system: '',
prompt: defaultPrompt,
dialogue_number: 1,
is_result: false
is_result: false,
temperature: null,
max_tokens: null
}
const chat_data = computed({
@ -252,7 +284,7 @@ const selectProviderRef = ref<InstanceType<typeof SelectProviderDialog>>()
const modelOptions = ref<any>(null)
const providerOptions = ref<Array<Provider>>([])
const AIModeParamSettingDialogRef = ref<InstanceType<typeof AIModeParamSettingDialog>>()
const validate = () => {
return aiChatNodeFormRef.value?.validate().catch((err) => {
return Promise.reject({ node: props.nodeModel, errMessage: err })
@ -285,6 +317,19 @@ const openCreateModel = (provider?: Provider) => {
}
}
// Load the current temperature/max_tokens settings for the selected model,
// scoped to this workflow node, then open the param-setting dialog with them.
// NOTE(review): silently does nothing when `modelId` is falsy — the
// application-form variant shows a "please select an AI model" message
// instead; confirm this difference is intentional.
const openAIParamSettingDialog = (modelId: string) => {
  if (modelId) {
    application.asyncGetModelConfig(id, modelId, props.nodeModel.id).then((res: any) => {
      AIModeParamSettingDialogRef.value?.open(res.data)
    })
  }
}
// Sync the values saved in the param-setting dialog back into this node's
// chat data so the workflow definition carries the updated settings.
function refreshParam(data: any) {
  chat_data.value.temperature = data.temperature
  chat_data.value.max_tokens = data.max_tokens
}
onMounted(() => {
getProvider()
getModel()