mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 10:12:51 +00:00
refactor: tokens
This commit is contained in:
parent
3d1c43c020
commit
ff41b1ff6e
|
|
@ -20,5 +20,5 @@ class BaiLianChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||
model=model_name,
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
**optional_params
|
||||
extra_body=optional_params
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import os
|
||||
import re
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
from botocore.config import Config
|
||||
from langchain_community.chat_models import BedrockChat
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
|
|
@ -72,6 +74,19 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
|
|||
config=config
|
||||
)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
try:
|
||||
return super().get_num_tokens_from_messages(messages)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
try:
|
||||
return super().get_num_tokens(text)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
|
||||
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
|
||||
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
|
||||
|
|
|
|||
|
|
@ -26,6 +26,6 @@ class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||
model=model_name,
|
||||
openai_api_base='https://api.deepseek.com',
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
**optional_params
|
||||
extra_body=optional_params
|
||||
)
|
||||
return deepseek_chat_open_ai
|
||||
|
|
|
|||
|
|
@ -26,6 +26,6 @@ class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||
openai_api_base=model_credential['api_base'],
|
||||
openai_api_key=model_credential['api_key'],
|
||||
model_name=model_name,
|
||||
**optional_params
|
||||
extra_body=optional_params,
|
||||
)
|
||||
return kimi_chat_open_ai
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class OllamaLLMModelParams(BaseForm):
|
|||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
num_predict = forms.SliderField(
|
||||
TooltipLabel(_('Output the maximum Tokens'),
|
||||
_('Specify the maximum number of tokens that the model can generate')),
|
||||
required=True, default_value=1024,
|
||||
|
|
|
|||
|
|
@ -33,15 +33,15 @@ class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||
streaming = model_kwargs.get('streaming', True)
|
||||
if 'o1' in model_name:
|
||||
streaming = False
|
||||
azure_chat_open_ai = OpenAIChatModel(
|
||||
chat_open_ai = OpenAIChatModel(
|
||||
model=model_name,
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
**optional_params,
|
||||
extra_body=optional_params,
|
||||
streaming=streaming,
|
||||
custom_get_token_ids=custom_get_token_ids
|
||||
)
|
||||
return azure_chat_open_ai
|
||||
return chat_open_ai
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -34,5 +34,5 @@ class SiliconCloudChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||
model=model_name,
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
**optional_params
|
||||
extra_body=optional_params
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||
model=model_name,
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
**optional_params,
|
||||
extra_body=optional_params,
|
||||
streaming=True,
|
||||
stream_usage=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,5 +17,5 @@ class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||
model=model_name,
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
**optional_params
|
||||
extra_body=optional_params
|
||||
)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
|||
model=model_name,
|
||||
openai_api_base=base_url,
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
**optional_params
|
||||
extra_body=optional_params
|
||||
)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
|
|
|
|||
Loading…
Reference in New Issue