From 210f99bf705d5f4457fc85d21527636dadddd210 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Wed, 14 Aug 2024 18:54:32 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../step/chat_step/i_chat_step.py | 3 + .../step/chat_step/impl/base_chat_step.py | 2 +- .../serializers/application_serializers.py | 35 +++ .../serializers/chat_message_serializers.py | 2 + apps/application/urls.py | 2 + apps/application/views/application_views.py | 22 ++ apps/setting/models_provider/__init__.py | 8 +- .../models_provider/base_model_provider.py | 7 + .../credential/llm.py | 13 + .../aws_bedrock_model_provider/model/llm.py | 55 +++- .../azure_model_provider/credential/llm.py | 22 ++ .../model/azure_chat_model.py | 94 ++++++- .../models_provider/impl/base_chat_open_ai.py | 2 +- .../deepseek_model_provider/credential/llm.py | 22 ++ .../impl/deepseek_model_provider/model/llm.py | 14 +- .../gemini_model_provider/credential/llm.py | 25 +- .../impl/gemini_model_provider/model/llm.py | 14 +- .../kimi_model_provider/credential/llm.py | 22 ++ .../impl/kimi_model_provider/model/llm.py | 16 +- .../ollama_model_provider/credential/llm.py | 22 ++ .../impl/ollama_model_provider/model/llm.py | 9 +- .../openai_model_provider/credential/llm.py | 22 ++ .../impl/openai_model_provider/model/llm.py | 23 +- .../qwen_model_provider/credential/llm.py | 13 + .../impl/qwen_model_provider/model/llm.py | 60 +++- .../tencent_model_provider/credential/llm.py | 15 +- .../tencent_model_provider/model/hunyuan.py | 5 + .../impl/tencent_model_provider/model/llm.py | 19 +- .../credential/llm.py | 22 ++ .../model/llm.py | 30 +- .../volcanic_engine_model_provider.py | 4 +- .../wenxin_model_provider/credential/llm.py | 13 + .../impl/wenxin_model_provider/model/llm.py | 59 +++- .../impl/xf_model_provider/credential/llm.py | 25 ++ .../impl/xf_model_provider/model/llm.py | 9 +- .../credential/llm.py | 22 ++ .../xinference_model_provider/model/llm.py | 5 +- .../zhipu_model_provider/credential/llm.py | 23 ++ .../impl/zhipu_model_provider/model/llm.py | 10 +- apps/setting/models_provider/tools.py | 4 +- ui/src/api/application.ts | 248 +++++++++------- ui/src/stores/modules/application.ts | 265 ++++++++++-------- .../views/application/ApplicationSetting.vue | 76 +++-- .../component/AIModeParamSettingDialog.vue | 145 +++++----- ui/src/views/authentication/component/CAS.vue | 67 +++-- .../views/authentication/component/LDAP.vue | 86 ++++-- .../views/authentication/component/OIDC.vue | 109 ++++--- ui/src/views/template/index.vue | 1 + 48 files changed, 1281 insertions(+), 510 deletions(-) diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index 55ba166cf..a2fc2af3b 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -73,6 +73,9 @@ class IChatStep(IBaseChatPipelineStep): no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("温度")) + max_tokens = serializers.IntegerField(required=False, allow_null=True, + error_messages=ErrMessage.integer("最大token数")) def is_valid(self, *, 
raise_exception=False): super().is_valid(raise_exception=True) diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 2667d34a2..b68577409 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -111,7 +111,7 @@ class BaseChatStep(IChatStep): client_id=None, client_type=None, no_references_setting=None, **kwargs): - chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None + chat_model = get_model_instance_by_model_user_id(model_id, user_id, **kwargs) if model_id is not None else None if stream: return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, paragraph_list, diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 040788a8f..de90d61a9 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -43,6 +43,7 @@ from dataset.serializers.common_serializers import list_paragraph, get_embedding from embedding.models import SearchMode from setting.models import AuthOperate from setting.models.model_management import Model +from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR from django.conf import settings @@ -109,6 +110,10 @@ class DatasetSettingSerializer(serializers.Serializer): class ModelSettingSerializer(serializers.Serializer): prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词")) + temperature = serializers.FloatField(required=False, allow_null=True, + error_messages=ErrMessage.char("温度")) + max_tokens = serializers.IntegerField(required=False, allow_null=True, + error_messages=ErrMessage.integer("最大token数")) class ApplicationWorkflowSerializer(serializers.Serializer): @@ -541,6 +546,7 @@ class ApplicationSerializer(serializers.Serializer): class Operate(serializers.Serializer): application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + model_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("模型id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -722,6 +728,35 @@ class ApplicationSerializer(serializers.Serializer): [self.data.get('user_id') if self.data.get('user_id') == str(application.user_id) else None, application.user_id, self.data.get('user_id')]) + def get_other_file_list(self): + temperature = None + max_tokens = None + application = Application.objects.filter(id=self.initial_data.get("application_id")).first() + if application: + setting_dict = application.model_setting + temperature = setting_dict.get("temperature") + max_tokens = setting_dict.get("max_tokens") + model = Model.objects.filter(id=self.initial_data.get("model_id")).first() + if model: + res = ModelProvideConstants[model.provider].value.get_model_credential(model.model_type, + model.model_name).get_other_fields( + model.model_name) + if temperature and res.get('temperature'): + res['temperature']['value'] = temperature + if max_tokens and res.get('max_tokens'): + res['max_tokens']['value'] = 
max_tokens + return res + + def save_other_config(self, data): + application = Application.objects.filter(id=self.initial_data.get("application_id")).first() + if application: + setting_dict = application.model_setting + for key in ['max_tokens', 'temperature']: + if key in data: + setting_dict[key] = data[key] + application.model_setting = setting_dict + application.save() + class ApplicationKeySerializerModel(serializers.ModelSerializer): class Meta: model = ApplicationApiKey diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index b19a8ed5b..d526e3540 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -80,6 +80,8 @@ class ChatInfo: 'model_id': self.application.model.id if self.application.model is not None else None, 'problem_optimization': self.application.problem_optimization, 'stream': True, + 'temperature': model_setting.get('temperature') if 'temperature' in model_setting else None, + 'max_tokens': model_setting.get('max_tokens') if 'max_tokens' in model_setting else None, 'search_mode': self.application.dataset_setting.get( 'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding', 'no_references_setting': self.application.dataset_setting.get( diff --git a/apps/application/urls.py b/apps/application/urls.py index 335205d37..95ec19ef4 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -19,6 +19,8 @@ urlpatterns = [ path('application//statistics/chat_record_aggregate_trend', views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()), path('application//model', views.Application.Model.as_view()), + path('application//model/', views.Application.Model.Operate.as_view()), + path('application//other-config', views.Application.Model.OtherConfig.as_view()), path('application//hit_test', views.Application.HitTest.as_view()), path('application//api_key', views.Application.ApplicationKey.as_view()), path("application//api_key/", diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index e6fe7191e..14aa0f1e8 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -187,6 +187,28 @@ class Application(APIView): data={'application_id': application_id, 'user_id': request.user.id}).list_model(request.query_params.get('model_type'))) + class Operate(APIView): + authentication_classes = [TokenAuth] + + @swagger_auto_schema(operation_summary="获取应用参数设置其他字段", + operation_id="获取应用参数设置其他字段", + tags=["应用/会话"]) + def get(self, request: Request, application_id: str, model_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, 'model_id': model_id}).get_other_file_list()) + + class OtherConfig(APIView): + authentication_classes = [TokenAuth] + + @swagger_auto_schema(operation_summary="获取应用参数设置其他字段", + operation_id="获取应用参数设置其他字段", + tags=["应用/会话"]) + def put(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id}).save_other_config(request.data)) + class Profile(APIView): authentication_classes = [TokenAuth] diff --git a/apps/setting/models_provider/__init__.py b/apps/setting/models_provider/__init__.py index 285e90f32..4197eb042 100644 --- a/apps/setting/models_provider/__init__.py +++ b/apps/setting/models_provider/__init__.py @@ -13,7 +13,7 @@ from 
common.util.rsa_util import rsa_long_decrypt from setting.models_provider.constants.model_provider_constants import ModelProvideConstants -def get_model_(provider, model_type, model_name, credential): +def get_model_(provider, model_type, model_name, credential, **kwargs): """ 获取模型实例 @param provider: 供应商 @@ -25,17 +25,17 @@ def get_model_(provider, model_type, model_name, credential): model = get_provider(provider).get_model(model_type, model_name, json.loads( rsa_long_decrypt(credential)), - streaming=True) + streaming=True, **kwargs) return model -def get_model(model): +def get_model(model, **kwargs): """ 获取模型实例 @param model: model 数据库Model实例对象 @return: 模型实例 """ - return get_model_(model.provider, model.model_type, model.model_name, model.credential) + return get_model_(model.provider, model.model_type, model.model_name, model.credential, **kwargs) def get_provider(provider): diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 022cf92e3..a6966247c 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -108,6 +108,13 @@ class BaseModelCredential(ABC): """ pass + def get_other_fields(self, model_name): + """ + 获取其他字段 + :return: + """ + pass + @staticmethod def encryption(message: str): """ diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py index 848250b1e..27bb88f28 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py @@ -60,3 +60,16 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential): region_name = forms.TextInputField('Region Name', required=True) access_key_id = forms.TextInputField('Access Key ID', required=True) secret_access_key = forms.PasswordInputField('Secret Access Key', required=True) + + def get_other_fields(self, model_name): + return { + 'max_tokens': { + 'value': 1024, + 'min': 1, + 'max': 8192, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py index 43c3c00a0..41546efb6 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py @@ -1,12 +1,18 @@ -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional, Iterator from langchain_community.chat_models import BedrockChat -from langchain_core.messages import BaseMessage, get_buffer_string -from common.config.tokenizer_manage_config import TokenizerManage +from langchain_community.chat_models.bedrock import ChatPromptAdapter +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage, get_buffer_string, AIMessageChunk +from langchain_core.outputs import ChatGenerationChunk from setting.models_provider.base_model_provider import MaxKBBaseModel class BedrockModel(MaxKBBaseModel, BedrockChat): + @staticmethod + def is_cache_model(): + return False + def __init__(self, model_id: str, region_name: str, credentials_profile_name: str, streaming: bool = False, **kwargs): super().__init__(model_id=model_id, region_name=region_name, @@ -15,21 +21,52 @@ 
class BedrockModel(MaxKBBaseModel, BedrockChat): @classmethod def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs) -> 'BedrockModel': + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return cls( model_id=model_name, region_name=model_credential['region_name'], credentials_profile_name=model_credential['credentials_profile_name'], streaming=model_kwargs.pop('streaming', False), - **model_kwargs + **optional_params ) - def _get_num_tokens(self, content: str) -> int: - """Helper method to count tokens in a string.""" - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(content)) - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: return sum(self._get_num_tokens(get_buffer_string([message])) for message in messages) def get_num_tokens(self, text: str) -> int: return self._get_num_tokens(text) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + provider = self._get_provider() + prompt, system, formatted_messages = None, None, None + + if provider == "anthropic": + system, formatted_messages = ChatPromptAdapter.format_messages( + provider, messages + ) + else: + prompt = ChatPromptAdapter.convert_messages_to_prompt( + provider=provider, messages=messages + ) + + for chunk in self._prepare_input_and_invoke_stream( + prompt=prompt, + system=system, + messages=formatted_messages, + stop=stop, + run_manager=run_manager, + **kwargs, + ): + delta = chunk.text + yield ChatGenerationChunk(message=AIMessageChunk(content=delta)) diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py index ddfb395a0..b20c63368 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py @@ -53,3 +53,25 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField("API Key (api_key)", required=True) deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.7, + 'min': 0.1, + 'max': 1, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 800, + 'min': 1, + 'max': 4096, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py index 45b388af6..9b93b437e 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py +++ b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py @@ -6,36 +6,102 @@ @date:2024/4/28 11:45 @desc: """ -from typing import List, Dict -from langchain_core.messages import BaseMessage, get_buffer_string +from typing import List, Dict, Optional, Any, Iterator, Type +from langchain_core.callbacks import 
CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk +from langchain_core.outputs import ChatGenerationChunk +from langchain_openai.chat_models.base import _convert_delta_to_message_chunk from langchain_openai import AzureChatOpenAI -from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI): + @staticmethod + def is_cache_model(): + return False + @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return AzureChatModel( azure_endpoint=model_credential.get('api_base'), openai_api_version=model_credential.get('api_version', '2024-02-15-preview'), deployment_name=model_credential.get('deployment_name'), openai_api_key=model_credential.get('api_key'), - openai_api_type="azure" + openai_api_type="azure", + **optional_params ) + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.__dict__.get('_last_generation_info') + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - try: - return super().get_num_tokens_from_messages(messages) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + return self.get_last_generation_info().get('prompt_tokens', 0) def get_num_tokens(self, text: str) -> int: - try: - return super().get_num_tokens(text) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) + return self.get_last_generation_info().get('completion_tokens', 0) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + kwargs["stream"] = True + kwargs["stream_options"] = {"include_usage": True} + payload = self._get_request_payload(messages, stop=stop, **kwargs) + default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk + if self.include_response_headers: + raw_response = self.client.with_raw_response.create(**payload) + response = raw_response.parse() + base_generation_info = {"headers": dict(raw_response.headers)} + else: + response = self.client.create(**payload) + base_generation_info = {} + with response: + is_first_chunk = True + for chunk in response: + if not isinstance(chunk, dict): + chunk = chunk.model_dump() + if len(chunk["choices"]) == 0: + if token_usage := chunk.get("usage"): + self.__dict__.setdefault('_last_generation_info', {}).update(token_usage) + logprobs = None + else: + continue + else: + choice = chunk["choices"][0] + if choice["delta"] is None: + continue + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + generation_info = {**base_generation_info} if is_first_chunk else {} + if finish_reason := choice.get("finish_reason"): + generation_info["finish_reason"] = finish_reason + if model_name := chunk.get("model"): + generation_info["model_name"] = model_name + if system_fingerprint := chunk.get("system_fingerprint"): + 
generation_info["system_fingerprint"] = system_fingerprint + + logprobs = choice.get("logprobs") + if logprobs: + generation_info["logprobs"] = logprobs + default_chunk_class = message_chunk.__class__ + generation_chunk = ChatGenerationChunk( + message=message_chunk, generation_info=generation_info or None + ) + if run_manager: + run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk, logprobs=logprobs + ) + is_first_chunk = False + yield generation_chunk diff --git a/apps/setting/models_provider/impl/base_chat_open_ai.py b/apps/setting/models_provider/impl/base_chat_open_ai.py index 1774c75b1..74afcf0f2 100644 --- a/apps/setting/models_provider/impl/base_chat_open_ai.py +++ b/apps/setting/models_provider/impl/base_chat_open_ai.py @@ -42,7 +42,7 @@ class BaseChatOpenAI(ChatOpenAI): for chunk in response: if not isinstance(chunk, dict): chunk = chunk.model_dump() - if len(chunk["choices"]) == 0: + if len(chunk["choices"]) == 0 or chunk["choices"][0]["finish_reason"] == "length" or chunk["choices"][0]["finish_reason"] == "stop": if token_usage := chunk.get("usage"): self.__dict__.setdefault('_last_generation_info', {}).update(token_usage) logprobs = None diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py index 29d9e344c..d4bea15ea 100644 --- a/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py @@ -46,3 +46,25 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} api_key = forms.PasswordInputField('API Key', required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.7, + 'min': 0.1, + 'max': 1, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 800, + 'min': 1, + 'max': 4096, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py index 8bc6f30d2..94c7d4899 100644 --- a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py @@ -13,12 +13,24 @@ from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def is_cache_model(): + return False + @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + deepseek_chat_open_ai = DeepSeekChatModel( model=model_name, openai_api_base='https://api.deepseek.com', - openai_api_key=model_credential.get('api_key') + openai_api_key=model_credential.get('api_key'), + **optional_params ) return deepseek_chat_open_ai diff --git a/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py index 
85ea1b23b..135e64fce 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py @@ -32,7 +32,8 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential): return False try: model = provider.get_model(model_type, model_name, model_credential) - model.invoke([HumanMessage(content='你好')]) + res = model.invoke([HumanMessage(content='你好')]) + print(res) except Exception as e: if isinstance(e, AppApiException): raise e @@ -46,3 +47,25 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} api_key = forms.PasswordInputField('API Key', required=True) + + def get_other_fields(self): + return { + 'temperature': { + 'value': 0.7, + 'min': 0.1, + 'max': 1.0, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 800, + 'min': 1, + 'max': 4096, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py index b3e0bbc96..cba5186a7 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py @@ -16,11 +16,23 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI): + + @staticmethod + def is_cache_model(): + return False + @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'temperature' in model_kwargs: + optional_params['temperature'] = model_kwargs['temperature'] + if 'max_tokens' in model_kwargs: + optional_params['max_output_tokens'] = model_kwargs['max_tokens'] + gemini_chat = GeminiChatModel( model=model_name, - google_api_key=model_credential.get('api_key') + google_api_key=model_credential.get('api_key'), + **optional_params ) return gemini_chat diff --git a/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py index 9e5835cca..13245af36 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py @@ -47,3 +47,25 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.3, + 'min': 0.1, + 'max': 1.0, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 1024, + 'min': 1, + 'max': 4096, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py index 652788cc5..2dd117fa5 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py @@ -8,20 +8,28 @@ """ from typing import List, Dict -from langchain_core.messages import BaseMessage, 
get_buffer_string - -from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def is_cache_model(): + return False + @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + kimi_chat_open_ai = KimiChatModel( openai_api_base=model_credential['api_base'], openai_api_key=model_credential['api_key'], model_name=model_name, + **optional_params ) return kimi_chat_open_ai - diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py index 030924ca0..136f7d2a9 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py @@ -42,3 +42,25 @@ class OllamaLLMModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.3, + 'min': 0.1, + 'max': 1.0, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 1024, + 'min': 1, + 'max': 4096, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py index 2a21c31b9..08e8ee9e0 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py @@ -30,5 +30,12 @@ class OllamaChatModel(MaxKBBaseModel, BaseChatOpenAI): api_base = model_credential.get('api_base', '') base_url = get_base_url(api_base) base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return OllamaChatModel(model=model_name, openai_api_base=base_url, - openai_api_key=model_credential.get('api_key')) + openai_api_key=model_credential.get('api_key'), + stream_usage=True, **optional_params) diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py index 60ff3b88f..42a713a3f 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py @@ -47,3 +47,25 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) + + def 
get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.7, + 'min': 0.1, + 'max': 1.0, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 800, + 'min': 1, + 'max': 4096, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index 9c3a9c116..f7e48da45 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -6,21 +6,32 @@ @date:2024/4/18 15:28 @desc: """ -from typing import List, Dict +from typing import List, Dict, Optional, Iterator, Any, Type +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk +from langchain_core.messages.ai import UsageMetadata +from langchain_core.outputs import ChatGenerationChunk +from langchain_openai import ChatOpenAI +from langchain_openai.chat_models.base import _convert_delta_to_message_chunk + +from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI -class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI): +class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] azure_chat_open_ai = OpenAIChatModel( model=model_name, openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), - streaming=model_kwargs.get('streaming', False), - max_tokens=model_kwargs.get('max_tokens', 5), - temperature=model_kwargs.get('temperature', 0.5), + **optional_params, + stream_usage=True ) return azure_chat_open_ai diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py index e31624c95..714a8eaa4 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py @@ -45,3 +45,16 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} api_key = forms.PasswordInputField('API Key', required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 1.0, + 'min': 0.1, + 'max': 1.9, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + } diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py index fa4b1109a..2bf7f6ac9 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py @@ -6,10 +6,13 @@ @date:2024/4/28 11:44 @desc: """ -from typing import List, 
Dict +from typing import List, Dict, Optional, Iterator, Any from langchain_community.chat_models import ChatTongyi +from langchain_community.llms.tongyi import generate_with_last_element_mark +from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.outputs import ChatGenerationChunk from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel @@ -18,16 +21,61 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel class QwenChatModel(MaxKBBaseModel, ChatTongyi): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] chat_tong_yi = QwenChatModel( model_name=model_name, - dashscope_api_key=model_credential.get('api_key') + dashscope_api_key=model_credential.get('api_key'), + **optional_params, ) return chat_tong_yi + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.__dict__.get('_last_generation_info') + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + return self.get_last_generation_info().get('input_tokens', 0) def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) + return self.get_last_generation_info().get('output_tokens', 0) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + params: Dict[str, Any] = self._invocation_params( + messages=messages, stop=stop, stream=True, **kwargs + ) + + for stream_resp, is_last_chunk in generate_with_last_element_mark( + self.stream_completion_with_retry(**params) + ): + choice = stream_resp["output"]["choices"][0] + message = choice["message"] + if ( + choice["finish_reason"] == "stop" + and message["content"] == "" + ): + token_usage = stream_resp["usage"] + self.__dict__.setdefault('_last_generation_info', {}).update(token_usage) + if ( + choice["finish_reason"] == "null" + and message["content"] == "" + and "tool_calls" not in message + ): + continue + + chunk = ChatGenerationChunk( + **self._chat_generation_from_qwen_resp( + stream_resp, is_chunk=True, is_last_chunk=is_last_chunk + ) + ) + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py index ad9ab3a82..06c892675 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py @@ -44,4 +44,17 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential): hunyuan_app_id = forms.TextInputField('APP ID', required=True) hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True) - hunyuan_secret_key = forms.PasswordInputField('SecretKey', 
required=True) \ No newline at end of file + hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.5, + 'min': 0.1, + 'max': 2.0, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + } diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py b/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py index 9af6f983c..219920125 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py @@ -236,6 +236,11 @@ class ChatHunyuan(BaseChatModel): choice["Delta"], default_chunk_class ) default_chunk_class = chunk.__class__ + # FinishReason === stop + if choice.get("FinishReason") == "stop": + self.__dict__.setdefault("_last_generation_info", {}).update( + response.get("Usage", {}) + ) cg_chunk = ChatGenerationChunk(message=chunk) if run_manager: run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py index 81116eb61..dd726b214 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py @@ -1,6 +1,6 @@ # coding=utf-8 -from typing import List, Dict +from typing import List, Dict, Optional, Any from langchain_core.messages import BaseMessage, get_buffer_string from common.config.tokenizer_manage_config import TokenizerManage @@ -15,12 +15,18 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan): hunyuan_secret_id = credentials.get('hunyuan_secret_id') hunyuan_secret_key = credentials.get('hunyuan_secret_key') + optional_params = {} + if 'temperature' in kwargs: + optional_params['temperature'] = kwargs['temperature'] + if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]): raise ValueError( "All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.") super().__init__(model=model_name, hunyuan_app_id=hunyuan_app_id, hunyuan_secret_id=hunyuan_secret_id, - hunyuan_secret_key=hunyuan_secret_key, streaming=streaming, **kwargs) + hunyuan_secret_key=hunyuan_secret_key, streaming=streaming, + temperature=optional_params.get('temperature', None) + ) @staticmethod def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object], @@ -28,10 +34,11 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan): streaming = model_kwargs.pop('streaming', False) return TencentModel(model_name=model_name, credentials=model_credential, streaming=streaming, **model_kwargs) + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.__dict__.get('_last_generation_info') + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages) + return self.get_last_generation_info().get('PromptTokens', 0) def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) + return self.get_last_generation_info().get('CompletionTokens', 0) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py 
b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py index 5647e2a08..025c43409 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py @@ -48,3 +48,25 @@ class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential): access_key_id = forms.PasswordInputField('Access Key ID', required=True) secret_access_key = forms.PasswordInputField('Secret Access Key', required=True) + + def get_other_fields(self): + return { + 'temperature': { + 'value': 0.3, + 'min': 0.1, + 'max': 1.0, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 1024, + 'min': 1, + 'max': 4096, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py index 3ab309863..e7ce56b6d 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py @@ -1,27 +1,21 @@ from typing import List, Dict -from langchain_community.chat_models import VolcEngineMaasChat -from langchain_core.messages import BaseMessage, get_buffer_string - -from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel -from langchain_openai import ChatOpenAI + +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI -class VolcanicEngineChatModel(MaxKBBaseModel, ChatOpenAI): +class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - volcanic_engine_chat = VolcanicEngineChatModel( + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return VolcanicEngineChatModel( model=model_name, - volc_engine_maas_ak=model_credential.get("access_key_id"), - volc_engine_maas_sk=model_credential.get("secret_access_key"), + openai_api_base=model_credential.get('api_base'), + openai_api_key=model_credential.get('api_key'), + **optional_params ) - return volcanic_engine_chat - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) - - def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index af7bc76b3..48802f6b8 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -14,7 +14,7 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro from 
setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel -from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel +from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel from smartdoc.conf import PROJECT_DIR @@ -24,7 +24,7 @@ model_info_list = [ ModelInfo('ep-xxxxxxxxxx-yyyy', '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用', ModelTypeConst.LLM, - volcanic_engine_llm_model_credential, OpenAIChatModel + volcanic_engine_llm_model_credential, VolcanicEngineChatModel ) ] diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py index 55166ade9..ccdc04d59 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -53,3 +53,16 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) secret_key = forms.PasswordInputField("Secret Key", required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.95, + 'min': 0.1, + 'max': 1.0, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + } diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py index bddd1f6db..159f0b470 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py @@ -6,28 +6,73 @@ @date:2023/11/10 17:45 @desc: """ -from typing import List, Dict +import uuid +from typing import List, Dict, Optional, Any, Iterator -from langchain.schema.messages import BaseMessage, get_buffer_string +from langchain.schema.messages import get_buffer_string from langchain_community.chat_models import QianfanChatEndpoint +from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.outputs import ChatGenerationChunk from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel +from langchain_core.messages import ( + AIMessageChunk, + BaseMessage, +) class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_output_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] return QianfanChatModel(model=model_name, qianfan_ak=model_credential.get('api_key'), qianfan_sk=model_credential.get('secret_key'), - streaming=model_kwargs.get('streaming', False)) + streaming=model_kwargs.get('streaming', False), + **optional_params) + + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return 
self.__dict__.get('_last_generation_info') def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + return self.get_last_generation_info().get('prompt_tokens', 0) def get_num_tokens(self, text: str) -> int: - tokenizer = TokenizerManage.get_tokenizer() - return len(tokenizer.encode(text)) + return self.get_last_generation_info().get('completion_tokens', 0) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + params = self._convert_prompt_msg_params(messages, **kwargs) + params["stop"] = stop + params["stream"] = True + for res in self.client.do(**params): + if res: + msg = _convert_dict_to_message(res) + additional_kwargs = msg.additional_kwargs.get("function_call", {}) + if msg.content == "": + token_usage = res.get("body").get("usage") + self.__dict__.setdefault('_last_generation_info', {}).update(token_usage) + chunk = ChatGenerationChunk( + text=res["result"], + message=AIMessageChunk( # type: ignore[call-arg] + content=msg.content, + role="assistant", + additional_kwargs=additional_kwargs, + ), + generation_info=msg.additional_kwargs, + ) + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py index 88d47266e..706aca2d6 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py @@ -49,3 +49,28 @@ class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): spark_app_id = forms.TextInputField('APP ID', required=True) spark_api_key = forms.PasswordInputField("API Key", required=True) spark_api_secret = forms.PasswordInputField('API Secret', required=True) + + def get_other_fields(self, model_name): + max_value = 8192 + if model_name == 'general' or model_name == 'pro-128k': + max_value = 4096 + return { + 'temperature': { + 'value': 0.5, + 'min': 0.1, + 'max': 1, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 4096, + 'min': 1, + 'max': max_value, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py index 1dccd29e3..023fc259c 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py @@ -25,14 +25,19 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs: + optional_params['temperature'] = model_kwargs['temperature'] return XFChatSparkLLM( spark_app_id=model_credential.get('spark_app_id'), spark_api_key=model_credential.get('spark_api_key'), spark_api_secret=model_credential.get('spark_api_secret'), spark_api_url=model_credential.get('spark_api_url'), spark_llm_domain=model_name, - 
temperature=model_kwargs.get('temperature', 0.5), - max_tokens=model_kwargs.get('max_tokens', 5), + streaming=model_kwargs.get('streaming', False), + **optional_params ) def get_last_generation_info(self) -> Optional[Dict[str, Any]]: diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py index d6442de32..a5425ec81 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py @@ -39,3 +39,25 @@ class XinferenceLLMModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.7, + 'min': 0.1, + 'max': 1, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 800, + 'min': 1, + 'max': 4096, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py index ed9e4e3c6..398e032fc 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py @@ -26,14 +26,13 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI): base_url = get_base_url(api_base) base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') optional_params = {} - if 'max_tokens' in model_kwargs: + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs: + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: optional_params['temperature'] = model_kwargs['temperature'] return XinferenceChatModel( model=model_name, openai_api_base=base_url, openai_api_key=model_credential.get('api_key'), - streaming=model_kwargs.get('streaming', False), **optional_params ) diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py index a88f65ce0..b2c16dcd7 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py @@ -45,3 +45,26 @@ class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} api_key = forms.PasswordInputField('API Key', required=True) + + def get_other_fields(self, model_name): + return { + 'temperature': { + 'value': 0.95, + 'min': 0.1, + 'max': 1.0, + 'step': 0.01, + 'label': '温度', + 'precision': 2, + 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定' + }, + 'max_tokens': { + 'value': 1024, + 'min': 1, + 'max': 4095, + 'step': 1, + 'label': '输出最大Tokens', + 'precision': 0, + 'tooltip': '指定模型可生成的最大token个数' + } + } + diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py index f12425ff4..d66bf17bd 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py +++ 
b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py @@ -33,11 +33,17 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs: + optional_params['temperature'] = model_kwargs['temperature'] + zhipuai_chat = ZhipuChatModel( - temperature=0.5, api_key=model_credential.get('api_key'), model=model_name, - max_tokens=model_kwargs.get('max_tokens', 5) + streaming=model_kwargs.get('streaming', False), + **optional_params ) return zhipuai_chat diff --git a/apps/setting/models_provider/tools.py b/apps/setting/models_provider/tools.py index 293b9f7ef..660604369 100644 --- a/apps/setting/models_provider/tools.py +++ b/apps/setting/models_provider/tools.py @@ -22,7 +22,7 @@ def get_model_by_id(_id, user_id): return model -def get_model_instance_by_model_user_id(model_id, user_id): +def get_model_instance_by_model_user_id(model_id, user_id, **kwargs): """ 获取模型实例,根据模型相关数据 @param model_id: 模型id @@ -30,4 +30,4 @@ def get_model_instance_by_model_user_id(model_id, user_id): @return: 模型实例 """ model = get_model_by_id(model_id, user_id) - return ModelManage.get_model(model_id, lambda _id: get_model(model)) + return ModelManage.get_model(model_id, lambda _id: get_model(model, **kwargs)) diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index 3aa7c29b1..d7e8de3af 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -1,8 +1,8 @@ -import { Result } from '@/request/Result' -import { get, post, postStream, del, put } from '@/request/index' -import type { pageRequest } from '@/api/type/common' -import type { ApplicationFormType } from '@/api/type/application' -import { type Ref } from 'vue' +import {Result} from '@/request/Result' +import {get, post, postStream, del, put} from '@/request/index' +import type {pageRequest} from '@/api/type/common' +import type {ApplicationFormType} from '@/api/type/application' +import {type Ref} from 'vue' const prefix = '/application' @@ -11,25 +11,25 @@ const prefix = '/application' * @param 参数 */ const getAllAppilcation: () => Promise> = () => { - return get(`${prefix}`) + return get(`${prefix}`) } /** * 获取分页应用 * page { - "current_page": "string", - "page_size": "string", - } + "current_page": "string", + "page_size": "string", + } * param { - "name": "string", - } + "name": "string", + } */ const getApplication: ( - page: pageRequest, - param: any, - loading?: Ref + page: pageRequest, + param: any, + loading?: Ref ) => Promise> = (page, param, loading) => { - return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading) + return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading) } /** @@ -37,10 +37,10 @@ const getApplication: ( * @param 参数 */ const postApplication: ( - data: ApplicationFormType, - loading?: Ref + data: ApplicationFormType, + loading?: Ref ) => Promise> = (data, loading) => { - return post(`${prefix}`, data, undefined, loading) + return post(`${prefix}`, data, undefined, loading) } /** @@ -48,11 +48,11 @@ const postApplication: ( * @param 参数 */ const putApplication: ( - application_id: String, - data: ApplicationFormType, - loading?: Ref + application_id: String, + data: ApplicationFormType, + loading?: Ref ) => Promise> = (application_id, data, loading) => { - return put(`${prefix}/${application_id}`, data, 
undefined, loading) + return put(`${prefix}/${application_id}`, data, undefined, loading) } /** @@ -60,10 +60,10 @@ const putApplication: ( * @param 参数 application_id */ const delApplication: ( - application_id: String, - loading?: Ref + application_id: String, + loading?: Ref ) => Promise> = (application_id, loading) => { - return del(`${prefix}/${application_id}`, undefined, {}, loading) + return del(`${prefix}/${application_id}`, undefined, {}, loading) } /** @@ -71,10 +71,10 @@ const delApplication: ( * @param 参数 application_id */ const getApplicationDetail: ( - application_id: string, - loading?: Ref + application_id: string, + loading?: Ref ) => Promise> = (application_id, loading) => { - return get(`${prefix}/${application_id}`, undefined, loading) + return get(`${prefix}/${application_id}`, undefined, loading) } /** @@ -82,10 +82,10 @@ const getApplicationDetail: ( * @param 参数 application_id */ const getApplicationDataset: ( - application_id: string, - loading?: Ref + application_id: string, + loading?: Ref ) => Promise> = (application_id, loading) => { - return get(`${prefix}/${application_id}/list_dataset`, undefined, loading) + return get(`${prefix}/${application_id}/list_dataset`, undefined, loading) } /** @@ -93,10 +93,10 @@ const getApplicationDataset: ( * @param 参数 application_id */ const getAccessToken: (application_id: string, loading?: Ref) => Promise> = ( - application_id, - loading + application_id, + loading ) => { - return get(`${prefix}/${application_id}/access_token`, undefined, loading) + return get(`${prefix}/${application_id}/access_token`, undefined, loading) } /** @@ -107,71 +107,71 @@ const getAccessToken: (application_id: string, loading?: Ref) => Promis * } */ const putAccessToken: ( - application_id: string, - data: any, - loading?: Ref + application_id: string, + data: any, + loading?: Ref ) => Promise> = (application_id, data, loading) => { - return put(`${prefix}/${application_id}/access_token`, data, undefined, loading) + return put(`${prefix}/${application_id}/access_token`, data, undefined, loading) } /** * 应用认证 - * @param 参数 + * @param 参数 { - "access_token": "string" -} + "access_token": "string" + } */ const postAppAuthentication: (access_token: string, loading?: Ref) => Promise = ( - access_token, - loading + access_token, + loading ) => { - return post(`${prefix}/authentication`, { access_token }, undefined, loading) + return post(`${prefix}/authentication`, {access_token}, undefined, loading) } /** * 对话获取应用相关信息 - * @param 参数 + * @param 参数 { - "access_token": "string" -} + "access_token": "string" + } */ const getAppProfile: (loading?: Ref) => Promise = (loading) => { - return get(`${prefix}/profile`, undefined, loading) + return get(`${prefix}/profile`, undefined, loading) } /** * 获得临时回话Id - * @param 参数 + * @param 参数 -} + } */ const postChatOpen: (data: ApplicationFormType) => Promise> = (data) => { - return post(`${prefix}/chat/open`, data) + return post(`${prefix}/chat/open`, data) } /** * 获得工作流临时回话Id - * @param 参数 + * @param 参数 -} + } */ const postWorkflowChatOpen: (data: ApplicationFormType) => Promise> = (data) => { - return post(`${prefix}/chat_workflow/open`, data) + return post(`${prefix}/chat_workflow/open`, data) } /** * 正式回话Id - * @param 参数 + * @param 参数 * { - "model_id": "string", - "multiple_rounds_dialogue": true, - "dataset_id_list": [ - "string" - ] -} + "model_id": "string", + "multiple_rounds_dialogue": true, + "dataset_id_list": [ + "string" + ] + } */ const getChatOpen: (application_id: String) => Promise> = (application_id) 
=> { - return get(`${prefix}/${application_id}/chat/open`) + return get(`${prefix}/${application_id}/chat/open`) } /** * 对话 @@ -180,32 +180,32 @@ const getChatOpen: (application_id: String) => Promise> = (applicati * data */ const postChatMessage: (chat_id: string, data: any) => Promise = (chat_id, data) => { - return postStream(`/api${prefix}/chat_message/${chat_id}`, data) + return postStream(`/api${prefix}/chat_message/${chat_id}`, data) } /** * 点赞、点踩 - * @param 参数 + * @param 参数 * application_id : string; chat_id : string; chat_record_id : string * { - "vote_status": "string", // -1 0 1 - } + "vote_status": "string", // -1 0 1 + } */ const putChatVote: ( - application_id: string, - chat_id: string, - chat_record_id: string, - vote_status: string, - loading?: Ref + application_id: string, + chat_id: string, + chat_record_id: string, + vote_status: string, + loading?: Ref ) => Promise = (application_id, chat_id, chat_record_id, vote_status, loading) => { - return put( - `${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`, - { - vote_status - }, - undefined, - loading - ) + return put( + `${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`, + { + vote_status + }, + undefined, + loading + ) } /** @@ -216,11 +216,11 @@ const putChatVote: ( * @returns */ const getApplicationHitTest: ( - application_id: string, - data: any, - loading?: Ref + application_id: string, + data: any, + loading?: Ref ) => Promise>> = (application_id, data, loading) => { - return get(`${prefix}/${application_id}/hit_test`, data, loading) + return get(`${prefix}/${application_id}/hit_test`, data, loading) } /** @@ -231,10 +231,10 @@ const getApplicationHitTest: ( * @returns */ const getApplicationModel: ( - application_id: string, - loading?: Ref + application_id: string, + loading?: Ref ) => Promise>> = (application_id, loading) => { - return get(`${prefix}/${application_id}/model`, loading) + return get(`${prefix}/${application_id}/model`, loading) } /** @@ -242,31 +242,61 @@ const getApplicationModel: ( * @param 参数 */ const putPublishApplication: ( - application_id: String, - data: ApplicationFormType, - loading?: Ref + application_id: String, + data: ApplicationFormType, + loading?: Ref ) => Promise> = (application_id, data, loading) => { - return put(`${prefix}/${application_id}/publish`, data, undefined, loading) + return put(`${prefix}/${application_id}/publish`, data, undefined, loading) +} + +/** + * 获取模型其他配置 + * @param application_id + * @param model_id + * @param loading + * @returns + */ +const getModelOtherConfig: ( + application_id: string, + model_id: string, + loading?: Ref +) => Promise> = (application_id, model_id, loading) => { + return get(`${prefix}/${application_id}/model/${model_id}`, undefined, loading) +} + +/** + * 保存其他配置信息 + * @param application_id + * + */ +const putModelOtherConfig: ( + application_id: string, + data: any, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return put(`${prefix}/${application_id}/other-config`, data, undefined, loading) } export default { - getAllAppilcation, - getApplication, - postApplication, - putApplication, - postChatOpen, - getChatOpen, - postChatMessage, - delApplication, - getApplicationDetail, - getApplicationDataset, - getAccessToken, - putAccessToken, - postAppAuthentication, - getAppProfile, - putChatVote, - getApplicationHitTest, - getApplicationModel, - putPublishApplication, - postWorkflowChatOpen + getAllAppilcation, + getApplication, + postApplication, + 
putApplication, + postChatOpen, + getChatOpen, + postChatMessage, + delApplication, + getApplicationDetail, + getApplicationDataset, + getAccessToken, + putAccessToken, + postAppAuthentication, + getAppProfile, + putChatVote, + getApplicationHitTest, + getApplicationModel, + putPublishApplication, + postWorkflowChatOpen, + getModelOtherConfig, + putModelOtherConfig } diff --git a/ui/src/stores/modules/application.ts b/ui/src/stores/modules/application.ts index 6f651ab6d..936d4b64c 100644 --- a/ui/src/stores/modules/application.ts +++ b/ui/src/stores/modules/application.ts @@ -1,136 +1,163 @@ -import { defineStore } from 'pinia' +import {defineStore} from 'pinia' import applicationApi from '@/api/application' import applicationXpackApi from '@/api/application-xpack' -import { type Ref } from 'vue' +import {type Ref, type UnwrapRef} from 'vue' import useUserStore from './user' +import type {ApplicationFormType} from "@/api/type/application"; const useApplicationStore = defineStore({ - id: 'application', - state: () => ({ - location: `${window.location.origin}/ui/chat/` - }), - actions: { - async asyncGetAllApplication() { - return new Promise((resolve, reject) => { - applicationApi - .getAllAppilcation() - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) - }, + id: 'application', + state: () => ({ + location: `${window.location.origin}/ui/chat/` + }), + actions: { + async asyncGetAllApplication() { + return new Promise((resolve, reject) => { + applicationApi + .getAllAppilcation() + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, - async asyncGetApplicationDetail(id: string, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .getApplicationDetail(id, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) - }, + async asyncGetApplicationDetail(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .getApplicationDetail(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, - async asyncGetApplicationDataset(id: string, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .getApplicationDataset(id, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) - }, + async asyncGetApplicationDataset(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .getApplicationDataset(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, - async asyncGetAccessToken(id: string, loading?: Ref) { - return new Promise((resolve, reject) => { - const user = useUserStore() - if (user.isEnterprise()) { - applicationXpackApi - .getAccessToken(id, loading) - .then((data) => { - resolve(data) + async asyncGetAccessToken(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + const user = useUserStore() + if (user.isEnterprise()) { + applicationXpackApi + .getAccessToken(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + } else { + applicationApi + .getAccessToken(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + } }) - .catch((error) => { - reject(error) - }) - } else { - applicationApi - .getAccessToken(id, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - } - }) - 
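The two API helpers added above (getModelOtherConfig / putModelOtherConfig) map to a GET on {prefix}/{application_id}/model/{model_id} and a PUT on {prefix}/{application_id}/other-config. The snippet below is a hypothetical smoke test of those endpoints, outside the patch; the base URL, the /api prefix, the ids and the token are assumptions about a locally running instance.

import requests

BASE = 'http://localhost:8080/api/application'   # assumed host and /api prefix
HEADERS = {'AUTHORIZATION': '<user-token>'}       # placeholder token
app_id, model_id = '<application-id>', '<model-id>'

# Fetch the temperature / max_tokens slider configuration for the selected model.
config = requests.get(f'{BASE}/{app_id}/model/{model_id}', headers=HEADERS).json()
print(config)

# Persist the values chosen in the parameter dialog.
payload = {'temperature': 0.3, 'max_tokens': 1024}
requests.put(f'{BASE}/{app_id}/other-config', json=payload, headers=HEADERS)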
}, + }, - async asyncGetAppProfile(loading?: Ref) { - return new Promise((resolve, reject) => { - const user = useUserStore() - if (user.isEnterprise()) { - applicationXpackApi - .getAppXpackProfile(loading) - .then((data) => { - resolve(data) + async asyncGetAppProfile(loading?: Ref) { + return new Promise((resolve, reject) => { + const user = useUserStore() + if (user.isEnterprise()) { + applicationXpackApi + .getAppXpackProfile(loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + } else { + applicationApi + .getAppProfile(loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + } }) - .catch((error) => { - reject(error) - }) - } else { - applicationApi - .getAppProfile(loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - } - }) - }, + }, - async asyncAppAuthentication(token: string, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .postAppAuthentication(token, loading) - .then((res) => { - localStorage.setItem('accessToken', res.data) - sessionStorage.setItem('accessToken', res.data) - resolve(res) - }) - .catch((error) => { - reject(error) - }) - }) - }, - async refreshAccessToken(token: string) { - this.asyncAppAuthentication(token) - }, - // 修改应用 - async asyncPutApplication(id: string, data: any, loading?: Ref) { - return new Promise((resolve, reject) => { - applicationApi - .putApplication(id, data, loading) - .then((data) => { - resolve(data) - }) - .catch((error) => { - reject(error) - }) - }) + async asyncAppAuthentication(token: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .postAppAuthentication(token, loading) + .then((res) => { + localStorage.setItem('accessToken', res.data) + sessionStorage.setItem('accessToken', res.data) + resolve(res) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async refreshAccessToken(token: string) { + this.asyncAppAuthentication(token) + }, + // 修改应用 + async asyncPutApplication(id: string, data: any, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .putApplication(id, data, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + // 获取模型的 温度/max_token字段设置 + async asyncGetModelConfig(id: string, model_id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .getModelOtherConfig(id, model_id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + // 保存应用的 温度/max_token字段设置 + async asyncPostModelConfig(id: string, params: any, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .putModelOtherConfig(id, params, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, } - } }) export default useApplicationStore diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue index a83a6fdbf..7828f2a35 100644 --- a/ui/src/views/application/ApplicationSetting.vue +++ b/ui/src/views/application/ApplicationSetting.vue @@ -67,9 +67,14 @@ 公用 - + + + - + + +
- {{ - $t('views.application.applicationForm.form.relatedKnowledgeBaseWhere') - }} + {{ $t('views.application.applicationForm.form.relatedKnowledgeBaseWhere') }} +
- + + + @@ -314,7 +327,7 @@ - + { } const openAIParamSettingDialog = () => { - AIModeParamSettingDialogRef.value?.open(applicationForm.value) + const model_id = applicationForm.value.model_id + if (!model_id) { + MsgSuccess(t('请选择AI 模型')) + return + } + application.asyncGetModelConfig(id, model_id, loading).then((res: any) => { + AIModeParamSettingDialogRef.value?.open(res.data) + }) } const openParamSettingDialog = () => { @@ -463,9 +484,11 @@ function removeDataset(id: any) { ) } } + function addDataset(val: Array) { applicationForm.value.dataset_id_list = val } + function openDatasetDialog() { AddDatasetDialogRef.value.open(applicationForm.value.dataset_id_list) } @@ -525,15 +548,18 @@ onMounted(() => { .relate-dataset-card { color: var(--app-text-color); } + .dialog-bg { border-radius: 8px; background: var(--dialog-bg-gradient-color); overflow: hidden; box-sizing: border-box; } + .scrollbar-height-left { height: calc(var(--app-main-height) - 64px); } + .scrollbar-height { height: calc(var(--app-main-height) - 160px); } diff --git a/ui/src/views/application/component/AIModeParamSettingDialog.vue b/ui/src/views/application/component/AIModeParamSettingDialog.vue index ea7852c49..33d0ff182 100644 --- a/ui/src/views/application/component/AIModeParamSettingDialog.vue +++ b/ui/src/views/application/component/AIModeParamSettingDialog.vue @@ -8,122 +8,135 @@ append-to-body > - + - - - - - + -