feat: 大语言模型支持自定义参数入参

This commit is contained in:
shaohuzhang1 2024-10-25 16:40:22 +08:00 committed by shaohuzhang1
parent cbdf4b7fd9
commit 36efb58a05
16 changed files with 40 additions and 86 deletions

View File

@ -93,6 +93,14 @@ class MaxKBBaseModel(ABC):
def is_cache_model():
return True
@staticmethod
def filter_optional_params(model_kwargs):
optional_params = {}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params[key] = value
return optional_params
class BaseModelCredential(ABC):

View File

@ -40,12 +40,7 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
@classmethod
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
**model_kwargs) -> 'BedrockModel':
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
keyword = get_max_tokens_keyword(model_name)
optional_params[keyword] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return cls(
model_id=model_name,

View File

@ -26,11 +26,7 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return AzureChatModel(
azure_endpoint=model_credential.get('api_base'),

View File

@ -20,11 +20,7 @@ class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
deepseek_chat_open_ai = DeepSeekChatModel(
model=model_name,

View File

@ -30,11 +30,7 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'temperature' in model_kwargs:
optional_params['temperature'] = model_kwargs['temperature']
if 'max_tokens' in model_kwargs:
optional_params['max_output_tokens'] = model_kwargs['max_tokens']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
gemini_chat = GeminiChatModel(
model=model_name,

View File

@ -20,11 +20,7 @@ class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
kimi_chat_open_ai = KimiChatModel(
openai_api_base=model_credential['api_base'],

View File

@ -34,11 +34,7 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return OllamaChatModel(model=model_name, openai_api_base=base_url,
openai_api_key=model_credential.get('api_key'),

View File

@ -30,11 +30,7 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
azure_chat_open_ai = OpenAIChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),

View File

@ -27,11 +27,7 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
chat_tong_yi = QwenChatModel(
model_name=model_name,
dashscope_api_key=model_credential.get('api_key'),

View File

@ -18,9 +18,7 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
hunyuan_secret_id = credentials.get('hunyuan_secret_id')
hunyuan_secret_key = credentials.get('hunyuan_secret_key')
optional_params = {}
if 'temperature' in kwargs:
optional_params['temperature'] = kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(kwargs)
if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
raise ValueError(

View File

@ -22,11 +22,7 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
vllm_chat_open_ai = VllmChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),

View File

@ -12,11 +12,7 @@ class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return VolcanicEngineChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),

View File

@ -26,11 +26,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_output_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return QianfanChatModel(model=model_name,
qianfan_ak=model_credential.get('api_key'),
qianfan_sk=model_credential.get('secret_key'),

View File

@ -6,11 +6,10 @@
@date2024/04/19 15:55
@desc:
"""
import json
from typing import List, Optional, Any, Iterator, Dict
from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \
ChatSparkLLM
from langchain_community.chat_models.sparkllm import \
ChatSparkLLM, _convert_message_to_dict, _convert_delta_to_message_chunk
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
@ -25,11 +24,7 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return XFChatSparkLLM(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),

View File

@ -1,8 +1,12 @@
# coding=utf-8
from typing import Dict
from typing import Dict, Optional, List, Any, Iterator
from urllib.parse import urlparse, ParseResult
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessageChunk
from langchain_core.runnables import RunnableConfig
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
@ -26,11 +30,7 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return XinferenceChatModel(
model=model_name,
openai_api_base=base_url,

View File

@ -7,43 +7,41 @@
@desc:
"""
from langchain_community.chat_models import ChatZhipuAI
from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
_convert_delta_to_message_chunk
from setting.models_provider.base_model_provider import MaxKBBaseModel
import json
from collections.abc import Iterator
from typing import Any, Dict, List, Optional
from langchain_community.chat_models import ChatZhipuAI
from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
_convert_delta_to_message_chunk
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.messages import (
AIMessageChunk,
BaseMessage
)
from langchain_core.outputs import ChatGenerationChunk
from setting.models_provider.base_model_provider import MaxKBBaseModel
class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
optional_params: dict
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
zhipuai_chat = ZhipuChatModel(
api_key=model_credential.get('api_key'),
model=model_name,
streaming=model_kwargs.get('streaming', False),
**optional_params
optional_params=optional_params,
**optional_params,
)
return zhipuai_chat
@ -71,7 +69,7 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
if self.zhipuai_api_base is None:
raise ValueError("Did not find zhipu_api_base.")
message_dicts, params = self._create_message_dicts(messages, stop)
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
_truncate_params(payload)
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),