From 35f3ff779baf4e0ccd9b4a84582bc58a93377432 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Mon, 19 Aug 2024 10:38:55 +0800 Subject: [PATCH] refactor: update model params --- .../ai_chat_step_node/i_chat_node.py | 3 + .../ai_chat_step_node/impl/base_chat_node.py | 6 +- .../serializers/application_serializers.py | 49 ++- apps/application/urls.py | 2 + apps/application/views/application_views.py | 8 +- .../constants/model_provider_constants.py | 2 + .../impl/vllm_model_provider/__init__.py | 1 + .../vllm_model_provider/credential/llm.py | 67 ++++ .../vllm_model_provider/icon/vllm_icon.svg | 22 ++ .../impl/vllm_model_provider/model/llm.py | 38 +++ .../vllm_model_provider.py | 59 ++++ .../wenxin_model_provider/credential/llm.py | 4 +- .../impl/xf_model_provider/model/llm.py | 4 +- .../impl/zhipu_model_provider/model/llm.py | 4 +- ui/src/api/application.ts | 195 ++++++------ ui/src/stores/modules/application.ts | 297 +++++++++--------- .../views/application/ApplicationSetting.vue | 2 +- .../component/AIModeParamSettingDialog.vue | 8 +- ui/src/workflow/nodes/ai-chat-node/index.vue | 73 ++++- 19 files changed, 568 insertions(+), 276 deletions(-) create mode 100644 apps/setting/models_provider/impl/vllm_model_provider/__init__.py create mode 100644 apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py create mode 100644 apps/setting/models_provider/impl/vllm_model_provider/icon/vllm_icon.svg create mode 100644 apps/setting/models_provider/impl/vllm_model_provider/model/llm.py create mode 100644 apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index 78cbc462c..60c70268d 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -23,6 +23,9 @@ class ChatNodeSerializer(serializers.Serializer): dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度")) + max_tokens = serializers.IntegerField(required=False, allow_null=True, + error_messages=ErrMessage.integer("最大token数")) class IChatNode(INode): diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index fdc5e99e5..d611df32e 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -66,7 +66,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor class BaseChatNode(IChatNode): def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, **kwargs) -> NodeResult: - chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) + kwargs = {k: v for k, v in kwargs.items() if k in ['temperature', 'max_tokens']} + + chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), + **kwargs + ) history_message = self.get_history_message(history_chat_record, dialogue_number) self.context['history_message'] = history_message question = self.generate_prompt_question(prompt) diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index de90d61a9..909c4a627 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -547,6 +547,7 @@ class ApplicationSerializer(serializers.Serializer): application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) model_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("模型id")) + ai_node_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("AI节点id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -731,31 +732,61 @@ class ApplicationSerializer(serializers.Serializer): def get_other_file_list(self): temperature = None max_tokens = None - application = Application.objects.filter(id=self.initial_data.get("application_id")).first() + application_id = self.initial_data.get("application_id") + ai_node_id = self.initial_data.get("ai_node_id") + model_id = self.initial_data.get("model_id") + + application = Application.objects.filter(id=application_id).first() if application: - setting_dict = application.model_setting - temperature = setting_dict.get("temperature") - max_tokens = setting_dict.get("max_tokens") - model = Model.objects.filter(id=self.initial_data.get("model_id")).first() + if application.type == 'SIMPLE': + setting_dict = application.model_setting + temperature = setting_dict.get("temperature") + max_tokens = setting_dict.get("max_tokens") + elif application.type == 'WORK_FLOW': + work_flow = application.work_flow + api_node = next((node for node in work_flow.get('nodes', []) if node.get('id') == ai_node_id), None) + if api_node: + node_data = api_node.get('properties', {}).get('node_data', {}) + temperature = node_data.get("temperature") + max_tokens = node_data.get("max_tokens") + + model = Model.objects.filter(id=model_id).first() if model: res = ModelProvideConstants[model.provider].value.get_model_credential(model.model_type, model.model_name).get_other_fields( model.model_name) - if temperature and res.get('temperature'): + if temperature is not None and 'temperature' in res: res['temperature']['value'] = temperature - if max_tokens and res.get('max_tokens'): + if max_tokens is not None and 'max_tokens' in res: res['max_tokens']['value'] = max_tokens return res def save_other_config(self, data): application = Application.objects.filter(id=self.initial_data.get("application_id")).first() - if application: + if not application: + return + + if application.type == 'SIMPLE': setting_dict = application.model_setting for key in ['max_tokens', 'temperature']: if key in data: setting_dict[key] = data[key] application.model_setting = setting_dict - application.save() + + elif application.type == 'WORK_FLOW': + work_flow = application.work_flow + ai_node_id = data.get("node_id") + for api_node in work_flow.get('nodes', []): + if api_node.get('id') == ai_node_id: + node_data = api_node.get('properties', {}).get('node_data', {}) + for key in ['max_tokens', 'temperature']: + if key in data: + node_data[key] = data[key] + api_node['properties']['node_data'] = node_data + break + application.work_flow = work_flow + + application.save() class ApplicationKeySerializerModel(serializers.ModelSerializer): class Meta: diff --git a/apps/application/urls.py b/apps/application/urls.py index 95ec19ef4..1409452de 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -20,6 +20,8 @@ urlpatterns = [ views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()), path('application//model', views.Application.Model.as_view()), path('application//model/', views.Application.Model.Operate.as_view()), + path('application//model//', + views.Application.Model.Operate.as_view()), path('application//other-config', views.Application.Model.OtherConfig.as_view()), path('application//hit_test', views.Application.HitTest.as_view()), path('application//api_key', views.Application.ApplicationKey.as_view()), diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 14aa0f1e8..a468d0b65 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -193,10 +193,11 @@ class Application(APIView): @swagger_auto_schema(operation_summary="获取应用参数设置其他字段", operation_id="获取应用参数设置其他字段", tags=["应用/会话"]) - def get(self, request: Request, application_id: str, model_id: str): + def get(self, request: Request, application_id: str, model_id: str, ai_node_id=None): return result.success( ApplicationSerializer.Operate( - data={'application_id': application_id, 'model_id': model_id}).get_other_file_list()) + data={'application_id': application_id, 'model_id': model_id, + 'ai_node_id': ai_node_id}).get_other_file_list()) class OtherConfig(APIView): authentication_classes = [TokenAuth] @@ -207,7 +208,8 @@ class Application(APIView): def put(self, request: Request, application_id: str): return result.success( ApplicationSerializer.Operate( - data={'application_id': application_id}).save_other_config(request.data)) + data={'application_id': application_id}).save_other_config( + request.data)) class Profile(APIView): authentication_classes = [TokenAuth] diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py index 6f691d276..13e43a278 100644 --- a/apps/setting/models_provider/constants/model_provider_constants.py +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -17,6 +17,7 @@ from setting.models_provider.impl.ollama_model_provider.ollama_model_provider im from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider from setting.models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider +from setting.models_provider.impl.vllm_model_provider.vllm_model_provider import VllmModelProvider from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \ VolcanicEngineModelProvider from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider @@ -42,3 +43,4 @@ class ModelProvideConstants(Enum): model_aws_bedrock_provider = BedrockModelProvider() model_local_provider = LocalModelProvider() model_xinference_provider = XinferenceModelProvider() + model_vllm_provider = VllmModelProvider() diff --git a/apps/setting/models_provider/impl/vllm_model_provider/__init__.py b/apps/setting/models_provider/impl/vllm_model_provider/__init__.py new file mode 100644 index 000000000..9bad5790a --- /dev/null +++ b/apps/setting/models_provider/impl/vllm_model_provider/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py new file mode 100644 index 000000000..87ef597c8 --- /dev/null +++ b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py @@ -0,0 +1,67 @@ +# coding=utf-8 + +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class VLLMModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base')) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = provider.get_model_info_by_name(model_list, model_name) + if len(exist) == 0: + raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型") + model = provider.get_model(model_type, model_name, model_credential) + try: + res = model.invoke([HumanMessage(content='你好')]) + print(res) + except Exception as e: + print(e) + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} + + def build_model(self, model_info: Dict[str, object]): + for key in ['api_key', 'model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + self.api_key = model_info.get('api_key') + return self + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.7, + 'min': 0.1, + 'max': 1, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 800, + 'min': 1, + 'max': 4096, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/vllm_model_provider/icon/vllm_icon.svg b/apps/setting/models_provider/impl/vllm_model_provider/icon/vllm_icon.svg new file mode 100644 index 000000000..776ae8ccb --- /dev/null +++ b/apps/setting/models_provider/impl/vllm_model_provider/icon/vllm_icon.svg @@ -0,0 +1,22 @@ + + + + + + + + + + diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py new file mode 100644 index 000000000..5c498f8e0 --- /dev/null +++ b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py @@ -0,0 +1,38 @@ +# coding=utf-8 + +from typing import List, Dict +from urllib.parse import urlparse, ParseResult +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +def get_base_url(url: str): + parse = urlparse(url) + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url + + +class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + vllm_chat_open_ai = VllmChatModel( + model=model_name, + openai_api_base=model_credential.get('api_base'), + openai_api_key=model_credential.get('api_key'), + **optional_params, + streaming=True, + stream_usage=True, + ) + return vllm_chat_open_ai diff --git a/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py b/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py new file mode 100644 index 000000000..c15162662 --- /dev/null +++ b/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py @@ -0,0 +1,59 @@ +# coding=utf-8 +import os +from urllib.parse import urlparse, ParseResult + +import requests + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ + ModelInfoManage +from setting.models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential +from setting.models_provider.impl.vllm_model_provider.model.llm import VllmChatModel +from smartdoc.conf import PROJECT_DIR + +v_llm_model_credential = VLLMModelCredential() +model_info_list = [ + ModelInfo('facebook/opt-125m', 'Facebook的125M参数模型', ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), + ModelInfo('BAAI/Aquila-7B', 'BAAI的7B参数模型', ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), + ModelInfo('BAAI/AquilaChat-7B', 'BAAI的13B参数模型', ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), + +] + +model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + ModelInfo( + 'facebook/opt-125m', + 'Facebook的125M参数模型', + ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel)) + .build()) + + +def get_base_url(url: str): + parse = urlparse(url) + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url + + +class VllmModelProvider(IModelProvider): + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_vllm_provider', name='Vllm', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'vllm_model_provider', 'icon', + 'vllm_icon.svg'))) + + @staticmethod + def get_base_model_list(api_base): + base_url = get_base_url(api_base) + base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') + r = requests.request(method="GET", url=f"{base_url}/models", timeout=5) + r.raise_for_status() + return r.json().get('data') + + @staticmethod + def get_model_info_by_name(model_list, model_name): + if model_list is None: + return [] + return [model for model in model_list if model.get('id') == model_name] diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py index f2b8b7b59..81ad4e404 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -66,9 +66,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential): 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' }, 'max_tokens': { - 'value': 2048, + 'value': 1024, 'min': 2, - 'max': 1024, + 'max': 2048, 'step': 1, 'label': '输出最大Tokens', 'precision': 0, diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py index 023fc259c..f8d55acb5 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py @@ -26,9 +26,9 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): optional_params = {} - if 'max_tokens' in model_kwargs: + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs: + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: optional_params['temperature'] = model_kwargs['temperature'] return XFChatSparkLLM( spark_app_id=model_credential.get('spark_app_id'), diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py index d66bf17bd..5722aead3 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py @@ -34,9 +34,9 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): optional_params = {} - if 'max_tokens' in model_kwargs: + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs: + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: optional_params['temperature'] = model_kwargs['temperature'] zhipuai_chat = ZhipuChatModel( diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index d7e8de3af..cb6858968 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -1,8 +1,8 @@ -import {Result} from '@/request/Result' -import {get, post, postStream, del, put} from '@/request/index' -import type {pageRequest} from '@/api/type/common' -import type {ApplicationFormType} from '@/api/type/application' -import {type Ref} from 'vue' +import { Result } from '@/request/Result' +import { get, post, postStream, del, put } from '@/request/index' +import type { pageRequest } from '@/api/type/common' +import type { ApplicationFormType } from '@/api/type/application' +import { type Ref } from 'vue' const prefix = '/application' @@ -11,7 +11,7 @@ const prefix = '/application' * @param 参数 */ const getAllAppilcation: () => Promise> = () => { - return get(`${prefix}`) + return get(`${prefix}`) } /** @@ -25,11 +25,11 @@ const getAllAppilcation: () => Promise> = () => { } */ const getApplication: ( - page: pageRequest, - param: any, - loading?: Ref + page: pageRequest, + param: any, + loading?: Ref ) => Promise> = (page, param, loading) => { - return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading) + return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading) } /** @@ -37,10 +37,10 @@ const getApplication: ( * @param 参数 */ const postApplication: ( - data: ApplicationFormType, - loading?: Ref + data: ApplicationFormType, + loading?: Ref ) => Promise> = (data, loading) => { - return post(`${prefix}`, data, undefined, loading) + return post(`${prefix}`, data, undefined, loading) } /** @@ -48,11 +48,11 @@ const postApplication: ( * @param 参数 */ const putApplication: ( - application_id: String, - data: ApplicationFormType, - loading?: Ref + application_id: String, + data: ApplicationFormType, + loading?: Ref ) => Promise> = (application_id, data, loading) => { - return put(`${prefix}/${application_id}`, data, undefined, loading) + return put(`${prefix}/${application_id}`, data, undefined, loading) } /** @@ -60,10 +60,10 @@ const putApplication: ( * @param 参数 application_id */ const delApplication: ( - application_id: String, - loading?: Ref + application_id: String, + loading?: Ref ) => Promise> = (application_id, loading) => { - return del(`${prefix}/${application_id}`, undefined, {}, loading) + return del(`${prefix}/${application_id}`, undefined, {}, loading) } /** @@ -71,10 +71,10 @@ const delApplication: ( * @param 参数 application_id */ const getApplicationDetail: ( - application_id: string, - loading?: Ref + application_id: string, + loading?: Ref ) => Promise> = (application_id, loading) => { - return get(`${prefix}/${application_id}`, undefined, loading) + return get(`${prefix}/${application_id}`, undefined, loading) } /** @@ -82,10 +82,10 @@ const getApplicationDetail: ( * @param 参数 application_id */ const getApplicationDataset: ( - application_id: string, - loading?: Ref + application_id: string, + loading?: Ref ) => Promise> = (application_id, loading) => { - return get(`${prefix}/${application_id}/list_dataset`, undefined, loading) + return get(`${prefix}/${application_id}/list_dataset`, undefined, loading) } /** @@ -93,10 +93,10 @@ const getApplicationDataset: ( * @param 参数 application_id */ const getAccessToken: (application_id: string, loading?: Ref) => Promise> = ( - application_id, - loading + application_id, + loading ) => { - return get(`${prefix}/${application_id}/access_token`, undefined, loading) + return get(`${prefix}/${application_id}/access_token`, undefined, loading) } /** @@ -107,11 +107,11 @@ const getAccessToken: (application_id: string, loading?: Ref) => Promis * } */ const putAccessToken: ( - application_id: string, - data: any, - loading?: Ref + application_id: string, + data: any, + loading?: Ref ) => Promise> = (application_id, data, loading) => { - return put(`${prefix}/${application_id}/access_token`, data, undefined, loading) + return put(`${prefix}/${application_id}/access_token`, data, undefined, loading) } /** @@ -122,10 +122,10 @@ const putAccessToken: ( } */ const postAppAuthentication: (access_token: string, loading?: Ref) => Promise = ( - access_token, - loading + access_token, + loading ) => { - return post(`${prefix}/authentication`, {access_token}, undefined, loading) + return post(`${prefix}/authentication`, { access_token }, undefined, loading) } /** @@ -136,7 +136,7 @@ const postAppAuthentication: (access_token: string, loading?: Ref) => P } */ const getAppProfile: (loading?: Ref) => Promise = (loading) => { - return get(`${prefix}/profile`, undefined, loading) + return get(`${prefix}/profile`, undefined, loading) } /** @@ -146,7 +146,7 @@ const getAppProfile: (loading?: Ref) => Promise = (loading) => { } */ const postChatOpen: (data: ApplicationFormType) => Promise> = (data) => { - return post(`${prefix}/chat/open`, data) + return post(`${prefix}/chat/open`, data) } /** @@ -156,7 +156,7 @@ const postChatOpen: (data: ApplicationFormType) => Promise> = (data) } */ const postWorkflowChatOpen: (data: ApplicationFormType) => Promise> = (data) => { - return post(`${prefix}/chat_workflow/open`, data) + return post(`${prefix}/chat_workflow/open`, data) } /** @@ -171,7 +171,7 @@ const postWorkflowChatOpen: (data: ApplicationFormType) => Promise> } */ const getChatOpen: (application_id: String) => Promise> = (application_id) => { - return get(`${prefix}/${application_id}/chat/open`) + return get(`${prefix}/${application_id}/chat/open`) } /** * 对话 @@ -180,7 +180,7 @@ const getChatOpen: (application_id: String) => Promise> = (applicati * data */ const postChatMessage: (chat_id: string, data: any) => Promise = (chat_id, data) => { - return postStream(`/api${prefix}/chat_message/${chat_id}`, data) + return postStream(`/api${prefix}/chat_message/${chat_id}`, data) } /** @@ -192,20 +192,20 @@ const postChatMessage: (chat_id: string, data: any) => Promise = (chat_id, } */ const putChatVote: ( - application_id: string, - chat_id: string, - chat_record_id: string, - vote_status: string, - loading?: Ref + application_id: string, + chat_id: string, + chat_record_id: string, + vote_status: string, + loading?: Ref ) => Promise = (application_id, chat_id, chat_record_id, vote_status, loading) => { - return put( - `${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`, - { - vote_status - }, - undefined, - loading - ) + return put( + `${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`, + { + vote_status + }, + undefined, + loading + ) } /** @@ -216,11 +216,11 @@ const putChatVote: ( * @returns */ const getApplicationHitTest: ( - application_id: string, - data: any, - loading?: Ref + application_id: string, + data: any, + loading?: Ref ) => Promise>> = (application_id, data, loading) => { - return get(`${prefix}/${application_id}/hit_test`, data, loading) + return get(`${prefix}/${application_id}/hit_test`, data, loading) } /** @@ -231,10 +231,10 @@ const getApplicationHitTest: ( * @returns */ const getApplicationModel: ( - application_id: string, - loading?: Ref + application_id: string, + loading?: Ref ) => Promise>> = (application_id, loading) => { - return get(`${prefix}/${application_id}/model`, loading) + return get(`${prefix}/${application_id}/model`, loading) } /** @@ -242,11 +242,11 @@ const getApplicationModel: ( * @param 参数 */ const putPublishApplication: ( - application_id: String, - data: ApplicationFormType, - loading?: Ref + application_id: String, + data: ApplicationFormType, + loading?: Ref ) => Promise> = (application_id, data, loading) => { - return put(`${prefix}/${application_id}/publish`, data, undefined, loading) + return put(`${prefix}/${application_id}/publish`, data, undefined, loading) } /** @@ -257,11 +257,18 @@ const putPublishApplication: ( * @returns */ const getModelOtherConfig: ( - application_id: string, - model_id: string, - loading?: Ref -) => Promise> = (application_id, model_id, loading) => { + application_id: string, + model_id: string, + ai_node_id?: string, + loading?: Ref +) => Promise> = (application_id, model_id, ai_node_id?, loading?) => { + if (ai_node_id) { + // 如果 ai_node_id 不为空,则调用带有 ai_node_id 的 API + return get(`${prefix}/${application_id}/model/${model_id}/${ai_node_id}`, undefined, loading) + } else { + // 如果 ai_node_id 为空,则调用不带 ai_node_id 的 API return get(`${prefix}/${application_id}/model/${model_id}`, undefined, loading) + } } /** @@ -270,33 +277,33 @@ const getModelOtherConfig: ( * */ const putModelOtherConfig: ( - application_id: string, - data: any, - loading?: Ref + application_id: string, + data: any, + loading?: Ref ) => Promise> = (application_id, data, loading) => { - return put(`${prefix}/${application_id}/other-config`, data, undefined, loading) + return put(`${prefix}/${application_id}/other-config`, data, undefined, loading) } export default { - getAllAppilcation, - getApplication, - postApplication, - putApplication, - postChatOpen, - getChatOpen, - postChatMessage, - delApplication, - getApplicationDetail, - getApplicationDataset, - getAccessToken, - putAccessToken, - postAppAuthentication, - getAppProfile, - putChatVote, - getApplicationHitTest, - getApplicationModel, - putPublishApplication, - postWorkflowChatOpen, - getModelOtherConfig, - putModelOtherConfig + getAllAppilcation, + getApplication, + postApplication, + putApplication, + postChatOpen, + getChatOpen, + postChatMessage, + delApplication, + getApplicationDetail, + getApplicationDataset, + getAccessToken, + putAccessToken, + postAppAuthentication, + getAppProfile, + putChatVote, + getApplicationHitTest, + getApplicationModel, + putPublishApplication, + postWorkflowChatOpen, + getModelOtherConfig, + putModelOtherConfig } diff --git a/ui/src/stores/modules/application.ts b/ui/src/stores/modules/application.ts index 936d4b64c..9cb706c53 100644 --- a/ui/src/stores/modules/application.ts +++ b/ui/src/stores/modules/application.ts @@ -1,163 +1,168 @@ -import {defineStore} from 'pinia' +import { defineStore } from 'pinia' import applicationApi from '@/api/application' import applicationXpackApi from '@/api/application-xpack' -import {type Ref, type UnwrapRef} from 'vue' +import { type Ref, type UnwrapRef } from 'vue' import useUserStore from './user' -import type {ApplicationFormType} from "@/api/type/application"; +import type { ApplicationFormType } from '@/api/type/application' const useApplicationStore = defineStore({ - id: 'application', - state: () => ({ - location: `${window.location.origin}/ui/chat/` - }), - actions: { - async asyncGetAllApplication() { - return new Promise((resolve, reject) => { - applicationApi - .getAllAppilcation() - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) - }, + id: 'application', + state: () => ({ + location: `${window.location.origin}/ui/chat/` + }), + actions: { + async asyncGetAllApplication() { + return new Promise((resolve, reject) => { + applicationApi + .getAllAppilcation() + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, - async asyncGetApplicationDetail(id: string, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .getApplicationDetail(id, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) - }, + async asyncGetApplicationDetail(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .getApplicationDetail(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, - async asyncGetApplicationDataset(id: string, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .getApplicationDataset(id, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) - }, + async asyncGetApplicationDataset(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .getApplicationDataset(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, - async asyncGetAccessToken(id: string, loading?: Ref) { - return new Promise((resolve, reject) => { - const user = useUserStore() - if (user.isEnterprise()) { - applicationXpackApi - .getAccessToken(id, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - } else { - applicationApi - .getAccessToken(id, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - } + async asyncGetAccessToken(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + const user = useUserStore() + if (user.isEnterprise()) { + applicationXpackApi + .getAccessToken(id, loading) + .then((data) => { + resolve(data) }) - }, + .catch((error) => { + reject(error) + }) + } else { + applicationApi + .getAccessToken(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + } + }) + }, - async asyncGetAppProfile(loading?: Ref) { - return new Promise((resolve, reject) => { - const user = useUserStore() - if (user.isEnterprise()) { - applicationXpackApi - .getAppXpackProfile(loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - } else { - applicationApi - .getAppProfile(loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - } + async asyncGetAppProfile(loading?: Ref) { + return new Promise((resolve, reject) => { + const user = useUserStore() + if (user.isEnterprise()) { + applicationXpackApi + .getAppXpackProfile(loading) + .then((data) => { + resolve(data) }) - }, + .catch((error) => { + reject(error) + }) + } else { + applicationApi + .getAppProfile(loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + } + }) + }, - async asyncAppAuthentication(token: string, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .postAppAuthentication(token, loading) - .then((res) => { - localStorage.setItem('accessToken', res.data) - sessionStorage.setItem('accessToken', res.data) - resolve(res) - }) - .catch((error) => { - reject(error) - }) - }) - }, - async refreshAccessToken(token: string) { - this.asyncAppAuthentication(token) - }, - // 修改应用 - async asyncPutApplication(id: string, data: any, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .putApplication(id, data, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) - }, - // 获取模型的 温度/max_token字段设置 - async asyncGetModelConfig(id: string, model_id: string, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .getModelOtherConfig(id, model_id, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) - }, - // 保存应用的 温度/max_token字段设置 - async asyncPostModelConfig(id: string, params: any, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .putModelOtherConfig(id, params, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) - }, + async asyncAppAuthentication(token: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .postAppAuthentication(token, loading) + .then((res) => { + localStorage.setItem('accessToken', res.data) + sessionStorage.setItem('accessToken', res.data) + resolve(res) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async refreshAccessToken(token: string) { + this.asyncAppAuthentication(token) + }, + // 修改应用 + async asyncPutApplication(id: string, data: any, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .putApplication(id, data, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + // 获取模型的 温度/max_token字段设置 + async asyncGetModelConfig( + id: string, + model_id: string, + ai_node_id?: string, + loading?: Ref + ) { + return new Promise((resolve, reject) => { + applicationApi + .getModelOtherConfig(id, model_id, ai_node_id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + // 保存应用的 温度/max_token字段设置 + async asyncPostModelConfig(id: string, params?: any, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .putModelOtherConfig(id, params, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) } + } }) export default useApplicationStore diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue index 7828f2a35..2e96a291d 100644 --- a/ui/src/views/application/ApplicationSetting.vue +++ b/ui/src/views/application/ApplicationSetting.vue @@ -455,7 +455,7 @@ const openAIParamSettingDialog = () => { MsgSuccess(t('请选择AI 模型')) return } - application.asyncGetModelConfig(id, model_id, loading).then((res: any) => { + application.asyncGetModelConfig(id, model_id, '', loading).then((res: any) => { AIModeParamSettingDialogRef.value?.open(res.data) }) } diff --git a/ui/src/views/application/component/AIModeParamSettingDialog.vue b/ui/src/views/application/component/AIModeParamSettingDialog.vue index 33d0ff182..0cf92ee76 100644 --- a/ui/src/views/application/component/AIModeParamSettingDialog.vue +++ b/ui/src/views/application/component/AIModeParamSettingDialog.vue @@ -60,7 +60,8 @@ const form = reactive
({}) const dialogVisible = ref(false) const loading = ref(false) const props = defineProps<{ - id: any + id: string + nodeId?: string }>() const resetForm = () => { // 清空 form 对象,等待新的数据 @@ -103,8 +104,11 @@ const submit = async () => { }, {} as Record ) + if (props.nodeId) { + data.node_id = props.nodeId + } application.asyncPostModelConfig(props.id, data, loading).then(() => { - emit('refresh', form) + emit('refresh', data) dialogVisible.value = false }) } diff --git a/ui/src/workflow/nodes/ai-chat-node/index.vue b/ui/src/workflow/nodes/ai-chat-node/index.vue index 92475e64d..8300a39d8 100644 --- a/ui/src/workflow/nodes/ai-chat-node/index.vue +++ b/ui/src/workflow/nodes/ai-chat-node/index.vue @@ -14,6 +14,7 @@ class="mb-24" label-width="auto" ref="aiChatNodeFormRef" + hide-required-asterisk > + 公用 - + + + {{ item.name }} (不可用) - + + + + + >通过调整提示词内容,可以引导大模型聊天方向,该提示词会被固定在上下文的开头,可以使用变量。 + @@ -163,10 +182,10 @@ - + @@ -177,6 +196,13 @@ @change="openCreateModel($event)" > + +