From 36efb58a055cedd205cde269f4945108adf39026 Mon Sep 17 00:00:00 2001
From: shaohuzhang1
Date: Fri, 25 Oct 2024 16:40:22 +0800
Subject: [PATCH] feat: support custom input parameters for large language
 models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../models_provider/base_model_provider.py    |  8 +++++++
 .../aws_bedrock_model_provider/model/llm.py   |  7 +-----
 .../model/azure_chat_model.py                 |  6 +----
 .../impl/deepseek_model_provider/model/llm.py |  6 +----
 .../impl/gemini_model_provider/model/llm.py   |  6 +----
 .../impl/kimi_model_provider/model/llm.py     |  6 +----
 .../impl/ollama_model_provider/model/llm.py   |  6 +----
 .../impl/openai_model_provider/model/llm.py   |  6 +----
 .../impl/qwen_model_provider/model/llm.py     |  6 +----
 .../impl/tencent_model_provider/model/llm.py  |  4 +---
 .../impl/vllm_model_provider/model/llm.py     |  6 +----
 .../model/llm.py                              |  6 +----
 .../impl/wenxin_model_provider/model/llm.py   |  6 +----
 .../impl/xf_model_provider/model/llm.py       | 11 +++------
 .../xinference_model_provider/model/llm.py    | 12 +++++-----
 .../impl/zhipu_model_provider/model/llm.py    | 24 +++++++++----------
 16 files changed, 40 insertions(+), 86 deletions(-)

diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py
index 9171e036c..c4722c9f5 100644
--- a/apps/setting/models_provider/base_model_provider.py
+++ b/apps/setting/models_provider/base_model_provider.py
@@ -93,6 +93,14 @@ class MaxKBBaseModel(ABC):
     def is_cache_model():
         return True
 
+    @staticmethod
+    def filter_optional_params(model_kwargs):
+        optional_params = {}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params[key] = value
+        return optional_params
+
 
 class BaseModelCredential(ABC):
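[Reviewer note, not part of the patch] The new helper drops the three MaxKB-internal
keys and passes every remaining key through untouched, which is what lets arbitrary
custom parameters reach the model constructors. Note that, unlike the per-provider
code it replaces, it does not skip None values. A minimal sketch of the behavior,
with hypothetical values:

    model_kwargs = {
        'model_id': 'abc123',   # internal -> dropped
        'use_local': False,     # internal -> dropped
        'streaming': True,      # internal -> dropped
        'temperature': 0.3,     # kept
        'max_tokens': 800,      # kept
        'top_p': 0.9,           # kept: arbitrary custom keys now pass through
    }
    assert MaxKBBaseModel.filter_optional_params(model_kwargs) == {
        'temperature': 0.3, 'max_tokens': 800, 'top_p': 0.9
    }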
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
index bfaf9f17b..950cd2b3f 100644
--- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
@@ -40,12 +40,7 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
     @classmethod
     def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                      **model_kwargs) -> 'BedrockModel':
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            keyword = get_max_tokens_keyword(model_name)
-            optional_params[keyword] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         return cls(
             model_id=model_name,
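[Reviewer note, not part of the patch] The deleted Bedrock branch translated
'max_tokens' into a model-specific keyword via get_max_tokens_keyword(model_name);
the generic filter now forwards 'max_tokens' verbatim. The same applies to the
Gemini and Qianfan hunks below, which previously renamed 'max_tokens' to
'max_output_tokens'. If those mappings still matter, here is a sketch of how one
could be reapplied on top of the shared filter (get_max_tokens_keyword is the
helper from the pre-patch code):

    optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
    if 'max_tokens' in optional_params:
        # rename the generic key to the keyword this Bedrock model expects
        keyword = get_max_tokens_keyword(model_name)
        optional_params[keyword] = optional_params.pop('max_tokens')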
diff --git a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
index 110d72da8..0996c3289 100644
--- a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
+++ b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
@@ -26,11 +26,7 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         return AzureChatModel(
             azure_endpoint=model_credential.get('api_base'),
diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py
index 94c7d4899..ac8dff4bd 100644
--- a/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/deepseek_model_provider/model/llm.py
@@ -20,11 +20,7 @@ class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         deepseek_chat_open_ai = DeepSeekChatModel(
             model=model_name,
diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py
index e9174d3a6..68d5e1128 100644
--- a/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py
@@ -30,11 +30,7 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'temperature' in model_kwargs:
-            optional_params['temperature'] = model_kwargs['temperature']
-        if 'max_tokens' in model_kwargs:
-            optional_params['max_output_tokens'] = model_kwargs['max_tokens']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         gemini_chat = GeminiChatModel(
             model=model_name,
diff --git a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py
index 2dd117fa5..c5f7b62b6 100644
--- a/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py
@@ -20,11 +20,7 @@ class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         kimi_chat_open_ai = KimiChatModel(
             openai_api_base=model_credential['api_base'],
diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
index abb5fcb80..7c98f7e5c 100644
--- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
@@ -34,11 +34,7 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
         api_base = model_credential.get('api_base', '')
         base_url = get_base_url(api_base)
         base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return OllamaChatModel(model=model_name, openai_api_base=base_url,
                                openai_api_key=model_credential.get('api_key'),
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
index a78ecc0c0..c5b5694e2 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
@@ -30,11 +30,7 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         azure_chat_open_ai = OpenAIChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
index 8ac8347aa..1336cb05b 100644
--- a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
@@ -27,11 +27,7 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         chat_tong_yi = QwenChatModel(
             model_name=model_name,
             dashscope_api_key=model_credential.get('api_key'),
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
index cfe673b50..17023f32e 100644
--- a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
@@ -18,9 +18,7 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
         hunyuan_secret_id = credentials.get('hunyuan_secret_id')
         hunyuan_secret_key = credentials.get('hunyuan_secret_key')
 
-        optional_params = {}
-        if 'temperature' in kwargs:
-            optional_params['temperature'] = kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(kwargs)
 
         if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
             raise ValueError(
diff --git a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py
index 5c498f8e0..d03eb7229 100644
--- a/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py
@@ -22,11 +22,7 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         vllm_chat_open_ai = VllmChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
index c549710e5..181ad2971 100644
--- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
@@ -12,11 +12,7 @@ class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return VolcanicEngineChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
index 8c634c6d0..e9b69d781 100644
--- a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
@@ -26,11 +26,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_output_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return QianfanChatModel(model=model_name,
                                 qianfan_ak=model_credential.get('api_key'),
                                 qianfan_sk=model_credential.get('secret_key'),
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py
index 598af07ac..7c3b39d31 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py
@@ -6,11 +6,10 @@
 @date:2024/04/19 15:55
 @desc:
 """
-import json
 from typing import List, Optional, Any, Iterator, Dict
 
-from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \
-    ChatSparkLLM
+from langchain_community.chat_models.sparkllm import \
+    ChatSparkLLM, _convert_message_to_dict, _convert_delta_to_message_chunk
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.messages import BaseMessage, AIMessageChunk
 from langchain_core.outputs import ChatGenerationChunk
@@ -25,11 +24,7 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return XFChatSparkLLM(
             spark_app_id=model_credential.get('spark_app_id'),
             spark_api_key=model_credential.get('spark_api_key'),
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
index b0fb4d16e..16996b907 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
@@ -1,8 +1,12 @@
 # coding=utf-8
-from typing import Dict
+from typing import Dict, Optional, List, Any, Iterator
 from urllib.parse import urlparse, ParseResult
 
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.messages import BaseMessageChunk
+from langchain_core.runnables import RunnableConfig
+
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
@@ -26,11 +30,7 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
         api_base = model_credential.get('api_base', '')
         base_url = get_base_url(api_base)
         base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return XinferenceChatModel(
             model=model_name,
             openai_api_base=base_url,
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
index c86c2e3a3..03699321c 100644
--- a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
@@ -7,43 +7,41 @@
 @desc:
 """
 
-from langchain_community.chat_models import ChatZhipuAI
-from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
-    _convert_delta_to_message_chunk
-from setting.models_provider.base_model_provider import MaxKBBaseModel
 import json
 from collections.abc import Iterator
 from typing import Any, Dict, List, Optional
 
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
+    _convert_delta_to_message_chunk
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
-
 from langchain_core.messages import (
     AIMessageChunk,
     BaseMessage
 )
 from langchain_core.outputs import ChatGenerationChunk
 
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
 
 class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
+    optional_params: dict
+
     @staticmethod
     def is_cache_model():
         return False
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
-
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         zhipuai_chat = ZhipuChatModel(
             api_key=model_credential.get('api_key'),
             model=model_name,
             streaming=model_kwargs.get('streaming', False),
-            **optional_params
+            optional_params=optional_params,
+            **optional_params,
         )
         return zhipuai_chat
 
@@ -71,7 +69,7 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
         if self.zhipuai_api_base is None:
             raise ValueError("Did not find zhipu_api_base.")
         message_dicts, params = self._create_message_dicts(messages, stop)
-        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
        _truncate_params(payload)
         headers = {
             "Authorization": _get_jwt_token(self.zhipuai_api_key),
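[Reviewer note, not part of the patch] In ZhipuChatModel._stream, self.optional_params
is now merged after params and kwargs, so user-configured custom parameters win any
key conflict (later entries in a dict literal override earlier ones). A standalone
sketch of the merge order, with hypothetical values:

    params = {'model': 'glm-4', 'temperature': 0.95}       # client defaults
    kwargs = {'temperature': 0.7}                          # per-call overrides
    optional_params = {'temperature': 0.3, 'top_p': 0.9}   # custom parameters
    payload = {**params, **kwargs, **optional_params, 'stream': True}
    assert payload == {'model': 'glm-4', 'temperature': 0.3, 'top_p': 0.9, 'stream': True}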