Mirror of https://github.com/1Panel-dev/MaxKB.git
Synced 2025-12-26 10:02:46 +00:00
feat: support custom input parameters for large language models
This commit is contained in:
parent cbdf4b7fd9
commit 36efb58a05
@@ -93,6 +93,14 @@ class MaxKBBaseModel(ABC):
     def is_cache_model():
         return True
 
+    @staticmethod
+    def filter_optional_params(model_kwargs):
+        optional_params = {}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params[key] = value
+        return optional_params
+
 
 class BaseModelCredential(ABC):
 
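The new helper strips MaxKB's own bookkeeping keys ('model_id', 'use_local', 'streaming') and forwards every remaining kwarg to the underlying client. A minimal sketch of the behavior (the parameter values below are illustrative, not taken from this commit):

# Illustrative call against the helper added above.
kwargs = {'model_id': '42', 'streaming': True,
          'temperature': 0.7, 'top_p': 0.9, 'max_tokens': 1024}
MaxKBBaseModel.filter_optional_params(kwargs)
# -> {'temperature': 0.7, 'top_p': 0.9, 'max_tokens': 1024}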
@@ -40,12 +40,7 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
     @classmethod
     def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                      **model_kwargs) -> 'BedrockModel':
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            keyword = get_max_tokens_keyword(model_name)
-            optional_params[keyword] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         return cls(
             model_id=model_name,
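The removed Bedrock branch translated max_tokens through get_max_tokens_keyword, since Bedrock model families name their token-limit field differently. That helper's body is not part of this diff; the sketch below is only a guess at its shape, and the mapping entries are assumptions:

# Hypothetical reconstruction; the repo's actual get_max_tokens_keyword may differ.
def get_max_tokens_keyword(model_name: str) -> str:
    # e.g. Amazon Titan models expect 'maxTokenCount', Anthropic models 'max_tokens'
    if model_name.startswith('amazon.titan'):
        return 'maxTokenCount'
    return 'max_tokens'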
@@ -26,11 +26,7 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         return AzureChatModel(
             azure_endpoint=model_credential.get('api_base'),
@@ -20,11 +20,7 @@ class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         deepseek_chat_open_ai = DeepSeekChatModel(
             model=model_name,
@@ -30,11 +30,7 @@ class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'temperature' in model_kwargs:
-            optional_params['temperature'] = model_kwargs['temperature']
-        if 'max_tokens' in model_kwargs:
-            optional_params['max_output_tokens'] = model_kwargs['max_tokens']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         gemini_chat = GeminiChatModel(
             model=model_name,
@@ -20,11 +20,7 @@ class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         kimi_chat_open_ai = KimiChatModel(
             openai_api_base=model_credential['api_base'],
@@ -34,11 +34,7 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
         api_base = model_credential.get('api_base', '')
         base_url = get_base_url(api_base)
         base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
 
         return OllamaChatModel(model=model_name, openai_api_base=base_url,
                                openai_api_key=model_credential.get('api_key'),
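A side note on the Ollama context lines kept above: the adapter normalizes whatever address the user configured into an OpenAI-compatible /v1 endpoint (get_base_url is the repo's helper; the example inputs are illustrative):

# 'http://localhost:11434'    -> 'http://localhost:11434/v1'
# 'http://localhost:11434/v1' -> 'http://localhost:11434/v1'
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')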
@@ -30,11 +30,7 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         azure_chat_open_ai = OpenAIChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
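With the shared helper in place, any extra keyword argument now flows through new_instance to the provider client unchanged. A hedged usage sketch (the model name, credential values, and the custom parameters top_p and presence_penalty are placeholders, not values from this commit):

chat = OpenAIChatModel.new_instance(
    'LLM', 'gpt-4o',
    {'api_base': 'https://api.openai.com/v1', 'api_key': 'sk-...'},
    streaming=True,        # dropped by filter_optional_params
    temperature=0.2,       # forwarded to the client
    top_p=0.9,             # forwarded: the kind of custom parameter this commit enables
    presence_penalty=0.5,  # forwarded
)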
@@ -27,11 +27,7 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         chat_tong_yi = QwenChatModel(
             model_name=model_name,
             dashscope_api_key=model_credential.get('api_key'),
@@ -18,9 +18,7 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
         hunyuan_secret_id = credentials.get('hunyuan_secret_id')
         hunyuan_secret_key = credentials.get('hunyuan_secret_key')
 
-        optional_params = {}
-        if 'temperature' in kwargs:
-            optional_params['temperature'] = kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(kwargs)
 
         if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
             raise ValueError(
@@ -22,11 +22,7 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         vllm_chat_open_ai = VllmChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
@@ -12,11 +12,7 @@ class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return VolcanicEngineChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
@@ -26,11 +26,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_output_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return QianfanChatModel(model=model_name,
                                 qianfan_ak=model_credential.get('api_key'),
                                 qianfan_sk=model_credential.get('secret_key'),
@@ -6,11 +6,10 @@
 @date:2024/04/19 15:55
 @desc:
 """
 import json
 from typing import List, Optional, Any, Iterator, Dict
 
-from langchain_community.chat_models.sparkllm import _convert_message_to_dict, _convert_delta_to_message_chunk, \
-    ChatSparkLLM
+from langchain_community.chat_models.sparkllm import \
+    ChatSparkLLM, _convert_message_to_dict, _convert_delta_to_message_chunk
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.messages import BaseMessage, AIMessageChunk
 from langchain_core.outputs import ChatGenerationChunk
@@ -25,11 +24,7 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return XFChatSparkLLM(
             spark_app_id=model_credential.get('spark_app_id'),
             spark_api_key=model_credential.get('spark_api_key'),
@@ -1,8 +1,12 @@
 # coding=utf-8
 
-from typing import Dict
+from typing import Dict, Optional, List, Any, Iterator
 from urllib.parse import urlparse, ParseResult
 
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.messages import BaseMessageChunk
+from langchain_core.runnables import RunnableConfig
+
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
@@ -26,11 +30,7 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
         api_base = model_credential.get('api_base', '')
         base_url = get_base_url(api_base)
         base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return XinferenceChatModel(
             model=model_name,
             openai_api_base=base_url,
@@ -7,43 +7,41 @@
 @desc:
 """
 
-from langchain_community.chat_models import ChatZhipuAI
-from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
-    _convert_delta_to_message_chunk
-from setting.models_provider.base_model_provider import MaxKBBaseModel
 import json
 from collections.abc import Iterator
 from typing import Any, Dict, List, Optional
 
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \
+    _convert_delta_to_message_chunk
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
-
 from langchain_core.messages import (
     AIMessageChunk,
     BaseMessage
 )
 from langchain_core.outputs import ChatGenerationChunk
 
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
 
 class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
+    optional_params: dict
+
     @staticmethod
     def is_cache_model():
         return False
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
-
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         zhipuai_chat = ZhipuChatModel(
             api_key=model_credential.get('api_key'),
             model=model_name,
             streaming=model_kwargs.get('streaming', False),
-            **optional_params
+            optional_params=optional_params,
+            **optional_params,
         )
         return zhipuai_chat
 
@@ -71,7 +69,7 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
         if self.zhipuai_api_base is None:
             raise ValueError("Did not find zhipu_api_base.")
         message_dicts, params = self._create_message_dicts(messages, stop)
-        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
+        payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True}
         _truncate_params(payload)
         headers = {
             "Authorization": _get_jwt_token(self.zhipuai_api_key),
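ZhipuChatModel keeps the filtered kwargs on the instance because ChatZhipuAI assembles its own HTTP payload; the _stream override above splices them in last, so user-supplied values win over the client's defaults. A small sketch of the merge semantics (values illustrative):

params = {'model': 'glm-4', 'temperature': 0.95}      # defaults from the client
optional_params = {'temperature': 0.2, 'top_p': 0.7}  # user-supplied custom params
payload = {**params, **optional_params, 'stream': True}
# Later unpacking wins: {'model': 'glm-4', 'temperature': 0.2, 'top_p': 0.7, 'stream': True}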