From ff41b1ff6ef4102cf4d014ce614bdb8a6d0e91f5 Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Fri, 25 Apr 2025 10:17:50 +0800
Subject: [PATCH] refactor: tokens

---
Notes (this section is ignored by `git am`):
 * OpenAI-compatible providers (BaiLian, DeepSeek, Kimi, OpenAI,
   SiliconCloud, vLLM, VolcanicEngine, Xinference) now pass the
   `optional_params` dict as `extra_body=optional_params` instead of
   unpacking it into constructor keyword arguments.
 * BedrockModel gains `get_num_tokens_from_messages` and `get_num_tokens`
   overrides that fall back to `TokenizerManage.get_tokenizer()` when the
   parent implementation raises.
 * The Ollama parameter form field `max_tokens` is renamed to `num_predict`.
 * The OpenAI provider's local `azure_chat_open_ai` is renamed to
   `chat_open_ai`.
 * Review follow-up: both Bedrock fallbacks bind the caught exception to an
   unused name `e`, and only one blank line precedes the module-level
   `_update_aws_credentials` after this change.
 * The line structure of this file was rebuilt from a whitespace-collapsed
   copy; hunk line counts match the headers, but leading indentation was
   inferred from Python nesting. Run `git apply --check` before applying.

 .../aliyun_bai_lian_model_provider/model/llm.py |  2 +-
 .../aws_bedrock_model_provider/model/llm.py     | 17 ++++++++++++++++-
 .../impl/deepseek_model_provider/model/llm.py   |  2 +-
 .../impl/kimi_model_provider/model/llm.py       |  2 +-
 .../ollama_model_provider/credential/llm.py     |  2 +-
 .../impl/openai_model_provider/model/llm.py     |  6 +++---
 .../siliconCloud_model_provider/model/llm.py    |  2 +-
 .../impl/vllm_model_provider/model/llm.py       |  2 +-
 .../volcanic_engine_model_provider/model/llm.py |  2 +-
 .../impl/xinference_model_provider/model/llm.py |  2 +-
 10 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py
index ced57cf57..50b439c49 100644
--- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py
+++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py
@@ -20,5 +20,5 @@ class BaiLianChatModel(MaxKBBaseModel, BaseChatOpenAI):
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
             openai_api_key=model_credential.get('api_key'),
-            **optional_params
+            extra_body=optional_params
         )
diff --git a/apps/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/models_provider/impl/aws_bedrock_model_provider/model/llm.py
index 7b2b7668b..fcd5cad59 100644
--- a/apps/models_provider/impl/aws_bedrock_model_provider/model/llm.py
+++ b/apps/models_provider/impl/aws_bedrock_model_provider/model/llm.py
@@ -1,10 +1,12 @@
 import os
 import re
-from typing import Dict
+from typing import Dict, List
 
 from botocore.config import Config
 from langchain_community.chat_models import BedrockChat
+from langchain_core.messages import BaseMessage, get_buffer_string
 
+from common.config.tokenizer_manage_config import TokenizerManage
 from models_provider.base_model_provider import MaxKBBaseModel
 
 
@@ -72,6 +74,19 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
             config=config
         )
 
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        try:
+            return super().get_num_tokens_from_messages(messages)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        try:
+            return super().get_num_tokens(text)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
 
 def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
     credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
diff --git a/apps/models_provider/impl/deepseek_model_provider/model/llm.py b/apps/models_provider/impl/deepseek_model_provider/model/llm.py
index 0d77b15f7..91f49d402 100644
--- a/apps/models_provider/impl/deepseek_model_provider/model/llm.py
+++ b/apps/models_provider/impl/deepseek_model_provider/model/llm.py
@@ -26,6 +26,6 @@ class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
             model=model_name,
             openai_api_base='https://api.deepseek.com',
             openai_api_key=model_credential.get('api_key'),
-            **optional_params
+            extra_body=optional_params
         )
         return deepseek_chat_open_ai
diff --git a/apps/models_provider/impl/kimi_model_provider/model/llm.py b/apps/models_provider/impl/kimi_model_provider/model/llm.py
index 64fcabff2..e5ca5ed84 100644
--- a/apps/models_provider/impl/kimi_model_provider/model/llm.py
+++ b/apps/models_provider/impl/kimi_model_provider/model/llm.py
@@ -26,6 +26,6 @@ class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI):
             openai_api_base=model_credential['api_base'],
             openai_api_key=model_credential['api_key'],
             model_name=model_name,
-            **optional_params
+            extra_body=optional_params,
         )
         return kimi_chat_open_ai
diff --git a/apps/models_provider/impl/ollama_model_provider/credential/llm.py b/apps/models_provider/impl/ollama_model_provider/credential/llm.py
index 73af9b3e6..02558b0b9 100644
--- a/apps/models_provider/impl/ollama_model_provider/credential/llm.py
+++ b/apps/models_provider/impl/ollama_model_provider/credential/llm.py
@@ -25,7 +25,7 @@ class OllamaLLMModelParams(BaseForm):
                                     _step=0.01,
                                     precision=2)
 
-    max_tokens = forms.SliderField(
+    num_predict = forms.SliderField(
         TooltipLabel(_('Output the maximum Tokens'),
                      _('Specify the maximum number of tokens that the model can generate')),
         required=True, default_value=1024,
diff --git a/apps/models_provider/impl/openai_model_provider/model/llm.py b/apps/models_provider/impl/openai_model_provider/model/llm.py
index a645fb7c8..2d5b76af7 100644
--- a/apps/models_provider/impl/openai_model_provider/model/llm.py
+++ b/apps/models_provider/impl/openai_model_provider/model/llm.py
@@ -33,15 +33,15 @@ class OpenAIChatModel(MaxKBBaseModel, BaseChatOpenAI):
         streaming = model_kwargs.get('streaming', True)
         if 'o1' in model_name:
             streaming = False
-        azure_chat_open_ai = OpenAIChatModel(
+        chat_open_ai = OpenAIChatModel(
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
             openai_api_key=model_credential.get('api_key'),
-            **optional_params,
+            extra_body=optional_params,
             streaming=streaming,
             custom_get_token_ids=custom_get_token_ids
         )
-        return azure_chat_open_ai
+        return chat_open_ai
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         try:
diff --git a/apps/models_provider/impl/siliconCloud_model_provider/model/llm.py b/apps/models_provider/impl/siliconCloud_model_provider/model/llm.py
index 256fed80a..336d09d30 100644
--- a/apps/models_provider/impl/siliconCloud_model_provider/model/llm.py
+++ b/apps/models_provider/impl/siliconCloud_model_provider/model/llm.py
@@ -34,5 +34,5 @@ class SiliconCloudChatModel(MaxKBBaseModel, BaseChatOpenAI):
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
             openai_api_key=model_credential.get('api_key'),
-            **optional_params
+            extra_body=optional_params
         )
diff --git a/apps/models_provider/impl/vllm_model_provider/model/llm.py b/apps/models_provider/impl/vllm_model_provider/model/llm.py
index 5367aa56c..6304c9150 100644
--- a/apps/models_provider/impl/vllm_model_provider/model/llm.py
+++ b/apps/models_provider/impl/vllm_model_provider/model/llm.py
@@ -31,7 +31,7 @@ class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI):
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
             openai_api_key=model_credential.get('api_key'),
-            **optional_params,
+            extra_body=optional_params,
             streaming=True,
             stream_usage=True,
         )
diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/apps/models_provider/impl/volcanic_engine_model_provider/model/llm.py
index 88ae15490..b5bc0be6e 100644
--- a/apps/models_provider/impl/volcanic_engine_model_provider/model/llm.py
+++ b/apps/models_provider/impl/volcanic_engine_model_provider/model/llm.py
@@ -17,5 +17,5 @@ class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
             model=model_name,
             openai_api_base=model_credential.get('api_base'),
             openai_api_key=model_credential.get('api_key'),
-            **optional_params
+            extra_body=optional_params
         )
diff --git a/apps/models_provider/impl/xinference_model_provider/model/llm.py b/apps/models_provider/impl/xinference_model_provider/model/llm.py
index 9c9550920..225717414 100644
--- a/apps/models_provider/impl/xinference_model_provider/model/llm.py
+++ b/apps/models_provider/impl/xinference_model_provider/model/llm.py
@@ -34,7 +34,7 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
             model=model_name,
             openai_api_base=base_url,
             openai_api_key=model_credential.get('api_key'),
-            **optional_params
+            extra_body=optional_params
         )
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: