From c332a6cacc95f27acdde2d7172ed411606526351 Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Fri, 16 Aug 2024 10:24:54 +0800
Subject: [PATCH] refactor: aws

---
 .../aws_bedrock_model_provider.py             |  6 --
 .../credential/llm.py                         | 11 ++-
 .../aws_bedrock_model_provider/model/llm.py   | 70 ++++++---------
 .../model/azure_chat_model.py                 | 88 ++++----------------
 .../impl/ollama_model_provider/model/llm.py   |  4 +-
 .../impl/openai_model_provider/model/llm.py   | 14 +--
 .../tencent_model_provider/credential/llm.py  |  3 +-
 .../wenxin_model_provider/credential/llm.py   |  9 ++
 .../impl/wenxin_model_provider/model/llm.py   | 14 +--
 9 files changed, 78 insertions(+), 141 deletions(-)

diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py
index a3a969564..c15bdf7e5 100644
--- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py
@@ -93,12 +93,6 @@ def _initialize_model_info():
             ModelTypeConst.LLM,
             BedrockLLMModelCredential,
             BedrockModel),
-        _create_model_info(
-            'amazon.titan-embed-text-v2:0',
-            'Amazon Titan Text Embeddings V2 是一种轻量级、高效的模型,非常适合在不同维度上执行高精度检索任务。该模型支持灵活的嵌入大小(1024、512 和 256),并优先考虑在较小维度上保持准确性,从而可以在不影响准确性的情况下降低存储成本。Titan Text Embeddings V2 适用于各种任务,包括文档检索、推荐系统、搜索引擎和对话式系统。',
-            ModelTypeConst.LLM,
-            BedrockLLMModelCredential,
-            BedrockModel),
         _create_model_info(
             'mistral.mistral-7b-instruct-v0:2',
             '7B 密集型转换器,可快速部署,易于定制。体积虽小,但功能强大,适用于各种用例。支持英语和代码,以及 32k 的上下文窗口。',
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py
index 27bb88f28..08829834b 100644
--- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py
@@ -63,10 +63,19 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
 
     def get_other_fields(self, model_name):
         return {
+            'temperature': {
+                'value': 0.7,
+                'min': 0.1,
+                'max': 1,
+                'step': 0.01,
+                'label': '温度',
+                'precision': 2,
+                'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
+            },
             'max_tokens': {
                 'value': 1024,
                 'min': 1,
-                'max': 8192,
+                'max': 4096,
                 'step': 1,
                 'label': '输出最大Tokens',
                 'precision': 0,
diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
index 41546efb6..510ef1339 100644
--- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
@@ -1,12 +1,29 @@
-from typing import List, Dict, Any, Optional, Iterator
+from typing import List, Dict
 
 from langchain_community.chat_models import BedrockChat
-from langchain_community.chat_models.bedrock import ChatPromptAdapter
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, get_buffer_string, AIMessageChunk
-from langchain_core.outputs import ChatGenerationChunk
 
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 
 
+def get_max_tokens_keyword(model_name):
+    """
+    根据模型名称返回正确的 max_tokens 关键字。
+
+    :param model_name: 模型名称字符串
+    :return: 对应的 max_tokens 关键字字符串
+    """
+    if 'amazon' in model_name:
+        return 'maxTokenCount'
+    elif 'anthropic' in model_name:
+        return 'max_tokens_to_sample'
+    elif 'ai21' in model_name:
+        return 'maxTokens'
+    elif 'cohere' in model_name or 'mistral' in model_name:
+        return 'max_tokens'
+    elif 'meta' in model_name:
+        return 'max_gen_len'
+    else:
+        raise ValueError("Unsupported model supplier in model_name.")
+
+
 class BedrockModel(MaxKBBaseModel, BedrockChat):
 
     @staticmethod
@@ -23,7 +40,8 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
                      **model_kwargs) -> 'BedrockModel':
         optional_params = {}
         if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
+            keyword = get_max_tokens_keyword(model_name)
+            optional_params[keyword] = model_kwargs['max_tokens']
         if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
             optional_params['temperature'] = model_kwargs['temperature']
 
@@ -31,42 +49,6 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
         return BedrockModel(
             model_id=model_name,
             region_name=model_credential['region_name'],
             credentials_profile_name=model_credential['credentials_profile_name'],
-            streaming=model_kwargs.pop('streaming', False),
-            **optional_params
+            streaming=model_kwargs.pop('streaming', True),
+            model_kwargs=optional_params
         )
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return sum(self._get_num_tokens(get_buffer_string([message])) for message in messages)
-
-    def get_num_tokens(self, text: str) -> int:
-        return self._get_num_tokens(text)
-
-    def _stream(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        provider = self._get_provider()
-        prompt, system, formatted_messages = None, None, None
-
-        if provider == "anthropic":
-            system, formatted_messages = ChatPromptAdapter.format_messages(
-                provider, messages
-            )
-        else:
-            prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
-            )
-
-        for chunk in self._prepare_input_and_invoke_stream(
-                prompt=prompt,
-                system=system,
-                messages=formatted_messages,
-                stop=stop,
-                run_manager=run_manager,
-                **kwargs,
-        ):
-            delta = chunk.text
-            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
diff --git a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
index 9b93b437e..110d72da8 100644
--- a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
+++ b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
@@ -8,12 +8,14 @@
 """
 from typing import List, Dict, Optional, Any, Iterator, Type
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
-from langchain_core.outputs import ChatGenerationChunk
-from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
-from langchain_openai import AzureChatOpenAI
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
+from langchain_openai import AzureChatOpenAI
+from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
+
+from common.config.tokenizer_manage_config import TokenizerManage
 
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 
@@ -36,72 +38,20 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
             deployment_name=model_credential.get('deployment_name'),
             openai_api_key=model_credential.get('api_key'),
             openai_api_type="azure",
-            **optional_params
+            **optional_params,
+            streaming=True,
         )
 
-    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.__dict__.get('_last_generation_info')
-
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.get_last_generation_info().get('prompt_tokens', 0)
+        try:
+            return super().get_num_tokens_from_messages(messages)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
 
     def get_num_tokens(self, text: str) -> int:
-        return self.get_last_generation_info().get('completion_tokens', 0)
-
-    def _stream(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        kwargs["stream"] = True
-        kwargs["stream_options"] = {"include_usage": True}
-        payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
-        else:
-            response = self.client.create(**payload)
-            base_generation_info = {}
-        with response:
-            is_first_chunk = True
-            for chunk in response:
-                if not isinstance(chunk, dict):
-                    chunk = chunk.model_dump()
-                if len(chunk["choices"]) == 0:
-                    if token_usage := chunk.get("usage"):
-                        self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
-                        logprobs = None
-                    else:
-                        continue
-                else:
-                    choice = chunk["choices"][0]
-                    if choice["delta"] is None:
-                        continue
-                    message_chunk = _convert_delta_to_message_chunk(
-                        choice["delta"], default_chunk_class
-                    )
-                    generation_info = {**base_generation_info} if is_first_chunk else {}
-                    if finish_reason := choice.get("finish_reason"):
-                        generation_info["finish_reason"] = finish_reason
-                    if model_name := chunk.get("model"):
-                        generation_info["model_name"] = model_name
-                    if system_fingerprint := chunk.get("system_fingerprint"):
-                        generation_info["system_fingerprint"] = system_fingerprint
-
-                    logprobs = choice.get("logprobs")
-                    if logprobs:
-                        generation_info["logprobs"] = logprobs
-                    default_chunk_class = message_chunk.__class__
-                    generation_chunk = ChatGenerationChunk(
-                        message=message_chunk, generation_info=generation_info or None
-                    )
-                    if run_manager:
-                        run_manager.on_llm_new_token(
-                            generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
-                        )
-                    is_first_chunk = False
-                    yield generation_chunk
+        try:
+            return super().get_num_tokens(text)
+        except Exception as e:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
index 08e8ee9e0..9ae88558b 100644
--- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
@@ -10,10 +10,10 @@
 from typing import List, Dict
 from urllib.parse import urlparse, ParseResult
 
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_openai.chat_models import ChatOpenAI
 
 from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
-from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
 
 def get_base_url(url: str):
@@ -24,7 +24,7 @@ def get_base_url(url: str):
     return result_url[:-1] if result_url.endswith("/") else result_url
 
 
-class OllamaChatModel(MaxKBBaseModel, BaseChatOpenAI):
+class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         api_base = model_credential.get('api_base', '')
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
index f7e48da45..0708803b9 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
@@ -6,16 +6,9 @@
 @date:2024/4/18 15:28
 @desc:
 """
-from typing import List, Dict, Optional, Iterator, Any, Type
+from typing import List, Dict
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
-from langchain_core.messages.ai import UsageMetadata
-from langchain_core.outputs import ChatGenerationChunk
-from langchain_openai import ChatOpenAI
-from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
-
-from common.config.tokenizer_manage_config import TokenizerManage
+from langchain_openai.chat_models import ChatOpenAI
 
 from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -32,6 +25,7 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
             openai_api_base=model_credential.get('api_base'),
             openai_api_key=model_credential.get('api_key'),
             **optional_params,
-            stream_usage=True
+            streaming=True,
+            stream_usage=True,
         )
         return azure_chat_open_ai
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py
index 06c892675..3a5677cff 100644
--- a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py
+++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py
@@ -7,7 +7,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
 
 
 class TencentLLMModelCredential(BaseForm, BaseModelCredential):
-    REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key']
+    REQUIRED_FIELDS = ['hunyuan_secret_id', 'hunyuan_secret_key']
 
     @classmethod
     def _validate_model_type(cls, model_type, provider, raise_exception=False):
@@ -42,7 +42,6 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential):
     def encryption_dict(self, model):
         return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}
 
-    hunyuan_app_id = forms.TextInputField('APP ID', required=True)
     hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
     hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py
index ccdc04d59..f2b8b7b59 100644
--- a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py
+++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py
@@ -65,4 +65,13 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
                 'precision': 2,
                 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
             },
+            'max_tokens': {
+                'value': 2048,
+                'min': 2,
+                'max': 4096,
+                'step': 1,
+                'label': '输出最大Tokens',
+                'precision': 0,
+                'tooltip': '指定模型可生成的最大token个数'
+            }
         }
diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
index 159f0b470..fc5b5335d 100644
--- a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
@@ -9,13 +9,9 @@
 import uuid
 from typing import List, Dict, Optional, Any, Iterator
 
-from langchain.schema.messages import get_buffer_string
-from langchain_community.chat_models import QianfanChatEndpoint
-from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message
+from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message, QianfanChatEndpoint
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.outputs import ChatGenerationChunk
-
-from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from langchain_core.messages import (
     AIMessageChunk,
@@ -24,6 +20,9 @@ from langchain_core.messages import (
 
 
 class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
+    @staticmethod
+    def is_cache_model():
+        return False
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
@@ -36,7 +35,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
             qianfan_ak=model_credential.get('api_key'),
             qianfan_sk=model_credential.get('secret_key'),
             streaming=model_kwargs.get('streaming', False),
-            **optional_params)
+            init_kwargs=optional_params)
 
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
         return self.__dict__.get('_last_generation_info')
@@ -54,6 +53,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
             run_manager: Optional[CallbackManagerForLLMRun] = None,
             **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+        kwargs = {**self.init_kwargs, **kwargs}
         params = self._convert_prompt_msg_params(messages, **kwargs)
         params["stop"] = stop
         params["stream"] = True
@@ -61,7 +61,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
             if res:
                 msg = _convert_dict_to_message(res)
                 additional_kwargs = msg.additional_kwargs.get("function_call", {})
-                if msg.content == "":
+                if msg.content == "" or res.get("body").get("is_end"):
                     token_usage = res.get("body").get("usage")
                     self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
                 chunk = ChatGenerationChunk(
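
The helper added in apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py exists because each Bedrock model family names its output-length parameter differently, and the refactored BedrockModel.new_instance now resolves that name before building the chat model. A minimal usage sketch of the mapping follows, assuming the MaxKB apps directory is on the import path; the model ids below are illustrative examples and are not taken from this patch:

    from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import get_max_tokens_keyword

    # 'anthropic' ids resolve to the Claude text-completion keyword.
    assert get_max_tokens_keyword('anthropic.claude-v2') == 'max_tokens_to_sample'
    # 'meta' ids resolve to Llama's max_gen_len.
    assert get_max_tokens_keyword('meta.llama3-8b-instruct-v1:0') == 'max_gen_len'
    # 'amazon' ids resolve to Titan's maxTokenCount.
    assert get_max_tokens_keyword('amazon.titan-text-premier-v1:0') == 'maxTokenCount'
    # Ids outside the known families raise ValueError, so unsupported suppliers fail fast.

new_instance then stores the configured limit under that keyword (optional_params[keyword] = model_kwargs['max_tokens']) and passes the dict as model_kwargs= to BedrockChat, so the provider-specific parameter name reaches the Bedrock request body.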