refactor: tokens

This commit is contained in:
wxg0103 2025-04-25 10:17:50 +08:00
parent 3d1c43c020
commit ff41b1ff6e
10 changed files with 27 additions and 12 deletions

View File

@@ -20,5 +20,5 @@ class BaiLianChatModel(MaxKBBaseModel, BaseChatOpenAI):
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
**optional_params
extra_body=optional_params
)

View File

@@ -1,10 +1,12 @@
import os
import re
from typing import Dict
from typing import Dict, List
from botocore.config import Config
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
@@ -72,6 +74,19 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
config=config
)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Return the token count for *messages*.

    Prefers the parent class's counter; if that raises for any reason
    (e.g. the provider-side tokenizer is unavailable), approximates by
    encoding each message with the shared local tokenizer.
    """
    try:
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Fallback: encode each rendered message locally and total the lengths.
        local_tokenizer = TokenizerManage.get_tokenizer()
        total = 0
        for message in messages:
            total += len(local_tokenizer.encode(get_buffer_string([message])))
        return total
def get_num_tokens(self, text: str) -> int:
    """Return the token count for *text*.

    Uses the parent class's counter when possible; on any failure, falls
    back to the shared local tokenizer's encoding length.
    """
    try:
        return super().get_num_tokens(text)
    except Exception:
        # Provider-side counting failed; approximate with the local tokenizer.
        fallback_tokenizer = TokenizerManage.get_tokenizer()
        return len(fallback_tokenizer.encode(text))
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")

View File

@@ -26,6 +26,6 @@ class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
model=model_name,
openai_api_base='https://api.deepseek.com',
openai_api_key=model_credential.get('api_key'),
**optional_params
extra_body=optional_params
)
return deepseek_chat_open_ai

View File

@@ -26,6 +26,6 @@ class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
openai_api_base=model_credential['api_base'],
openai_api_key=model_credential['api_key'],
model_name=model_name,
**optional_params
extra_body=optional_params,
)
return kimi_chat_open_ai

View File

@@ -25,7 +25,7 @@ class OllamaLLMModelParams(BaseForm):
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
num_predict = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=1024,

View File

@@ -33,15 +33,15 @@ class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
streaming = model_kwargs.get('streaming', True)
if 'o1' in model_name:
streaming = False
azure_chat_open_ai = OpenAIChatModel(
chat_open_ai = OpenAIChatModel(
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
**optional_params,
extra_body=optional_params,
streaming=streaming,
custom_get_token_ids=custom_get_token_ids
)
return azure_chat_open_ai
return chat_open_ai
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:

View File

@@ -34,5 +34,5 @@ class SiliconCloudChatModel(MaxKBBaseModel, BaseChatOpenAI):
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
**optional_params
extra_body=optional_params
)

View File

@@ -31,7 +31,7 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
**optional_params,
extra_body=optional_params,
streaming=True,
stream_usage=True,
)

View File

@@ -17,5 +17,5 @@ class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
**optional_params
extra_body=optional_params
)

View File

@@ -34,7 +34,7 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
model=model_name,
openai_api_base=base_url,
openai_api_key=model_credential.get('api_key'),
**optional_params
extra_body=optional_params
)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: