refactor: aws

wxg0103 2024-08-16 10:24:54 +08:00 committed by wxg0103
parent e266dd9d99
commit c332a6cacc
9 changed files with 78 additions and 141 deletions

View File

@@ -93,12 +93,6 @@ def _initialize_model_info():
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
_create_model_info(
'amazon.titan-embed-text-v2:0',
'Amazon Titan Text Embeddings V2 is a lightweight, efficient model well suited to high-accuracy retrieval tasks across different dimensions. It supports flexible embedding sizes (1024, 512, and 256) and prioritizes accuracy at the smaller dimensions, so storage costs can be reduced without compromising accuracy. Titan Text Embeddings V2 is suitable for a variety of tasks, including document retrieval, recommendation systems, search engines, and conversational systems.',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
_create_model_info(
'mistral.mistral-7b-instruct-v0:2',
'A 7B dense transformer that deploys quickly and is easy to customize. Small in size yet powerful for a wide range of use cases. Supports English and code, with a 32k context window.',

View File

@@ -63,10 +63,19 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
def get_other_fields(self, model_name):
return {
'temperature': {
'value': 0.7,
'min': 0.1,
'max': 1,
'step': 0.01,
'label': 'Temperature',
'precision': 2,
'tooltip': 'Higher values make the output more random, while lower values make it more focused and deterministic'
},
'max_tokens': {
'value': 1024,
'min': 1,
'max': 8192,
'max': 4096,
'step': 1,
'label': 'Max output tokens',
'precision': 0,
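For context, a minimal sketch (not part of this commit) of how field definitions like the ones above could be consumed; the credential variable and the model ID are illustrative only:

# Illustrative only: read the defaults out of the field definitions above.
# 'credential' is assumed to be a BedrockLLMModelCredential instance.
fields = credential.get_other_fields('anthropic.claude-v2')
defaults = {name: spec['value'] for name, spec in fields.items()}
# e.g. defaults might contain {'temperature': 0.7, 'max_tokens': 1024, ...}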

View File

@@ -1,12 +1,29 @@
from typing import List, Dict, Any, Optional, Iterator
from typing import List, Dict
from langchain_community.chat_models import BedrockChat
from langchain_community.chat_models.bedrock import ChatPromptAdapter
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from setting.models_provider.base_model_provider import MaxKBBaseModel
def get_max_tokens_keyword(model_name):
"""
Return the correct max_tokens keyword for the given model name
:param model_name: model name string
:return: the corresponding max_tokens keyword string
"""
if 'amazon' in model_name:
return 'maxTokenCount'
elif 'anthropic' in model_name:
return 'max_tokens_to_sample'
elif 'ai21' in model_name:
return 'maxTokens'
elif 'cohere' in model_name or 'mistral' in model_name:
return 'max_tokens'
elif 'meta' in model_name:
return 'max_gen_len'
else:
raise ValueError("Unsupported model supplier in model_name.")
class BedrockModel(MaxKBBaseModel, BedrockChat):
@staticmethod
@@ -23,7 +40,8 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
**model_kwargs) -> 'BedrockModel':
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
keyword = get_max_tokens_keyword(model_name)
optional_params[keyword] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
@@ -31,42 +49,6 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
model_id=model_name,
region_name=model_credential['region_name'],
credentials_profile_name=model_credential['credentials_profile_name'],
streaming=model_kwargs.pop('streaming', False),
**optional_params
streaming=model_kwargs.pop('streaming', True),
model_kwargs=optional_params
)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return sum(self._get_num_tokens(get_buffer_string([message])) for message in messages)
def get_num_tokens(self, text: str) -> int:
return self._get_num_tokens(text)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None
if provider == "anthropic":
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
for chunk in self._prepare_input_and_invoke_stream(
prompt=prompt,
system=system,
messages=formatted_messages,
stop=stop,
run_manager=run_manager,
**kwargs,
):
delta = chunk.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
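As an aside, a minimal usage sketch of the provider-to-keyword mapping introduced above (not part of this commit); the model IDs are examples only and get_max_tokens_keyword is assumed to be in scope:

# Illustrative only: print which keyword each provider's max-token limit is passed under.
for name in ('amazon.titan-text-express-v1', 'anthropic.claude-v2',
             'mistral.mistral-7b-instruct-v0:2', 'meta.llama3-8b-instruct-v1:0'):
    print(f"{name} -> {get_max_tokens_keyword(name)}")
# amazon.titan-text-express-v1 -> maxTokenCount
# anthropic.claude-v2 -> max_tokens_to_sample
# mistral.mistral-7b-instruct-v0:2 -> max_tokens
# meta.llama3-8b-instruct-v1:0 -> max_gen_len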

View File

@@ -8,12 +8,14 @@
"""
from typing import List, Dict, Optional, Any, Iterator, Type
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
from langchain_openai import AzureChatOpenAI
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import AzureChatOpenAI
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -36,72 +38,20 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
deployment_name=model_credential.get('deployment_name'),
openai_api_key=model_credential.get('api_key'),
openai_api_type="azure",
**optional_params
**optional_params,
streaming=True,
)
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return self.get_last_generation_info().get('prompt_tokens', 0)
try:
return super().get_num_tokens_from_messages(messages)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
return self.get_last_generation_info().get('completion_tokens', 0)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
base_generation_info = {}
with response:
is_first_chunk = True
for chunk in response:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
if token_usage := chunk.get("usage"):
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
logprobs = None
else:
continue
else:
choice = chunk["choices"][0]
if choice["delta"] is None:
continue
message_chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {**base_generation_info} if is_first_chunk else {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if model_name := chunk.get("model"):
generation_info["model_name"] = model_name
if system_fingerprint := chunk.get("system_fingerprint"):
generation_info["system_fingerprint"] = system_fingerprint
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = message_chunk.__class__
generation_chunk = ChatGenerationChunk(
message=message_chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
)
is_first_chunk = False
yield generation_chunk
try:
return super().get_num_tokens(text)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
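For clarity, a small standalone sketch (not part of this commit) of the fallback pattern the Azure change adopts; count_tokens_with_fallback is a hypothetical helper name, while TokenizerManage is the project helper imported above:

from typing import List
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage  # project helper, per the imports above

def count_tokens_with_fallback(chat_model, messages: List[BaseMessage]) -> int:
    # Prefer the provider-side token count; fall back to the local tokenizer on any failure.
    try:
        return chat_model.get_num_tokens_from_messages(messages)
    except Exception:
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)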

View File

@@ -10,10 +10,10 @@ from typing import List, Dict
from urllib.parse import urlparse, ParseResult
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai.chat_models import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
def get_base_url(url: str):
@@ -24,7 +24,7 @@ def get_base_url(url: str):
return result_url[:-1] if result_url.endswith("/") else result_url
class OllamaChatModel(MaxKBBaseModel, BaseChatOpenAI):
class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
api_base = model_credential.get('api_base', '')

View File

@ -6,16 +6,9 @@
@date:2024/4/18 15:28
@desc:
"""
from typing import List, Dict, Optional, Iterator, Any, Type
from typing import List, Dict
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
from common.config.tokenizer_manage_config import TokenizerManage
from langchain_openai.chat_models import ChatOpenAI
from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -32,6 +25,7 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
**optional_params,
stream_usage=True
streaming=True,
stream_usage=True,
)
return azure_chat_open_ai

View File

@@ -7,7 +7,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
class TencentLLMModelCredential(BaseForm, BaseModelCredential):
REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key']
REQUIRED_FIELDS = ['hunyuan_secret_id', 'hunyuan_secret_key']
@classmethod
def _validate_model_type(cls, model_type, provider, raise_exception=False):
@@ -42,7 +42,6 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential):
def encryption_dict(self, model):
return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}
hunyuan_app_id = forms.TextInputField('APP ID', required=True)
hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)

View File

@@ -65,4 +65,13 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
'precision': 2,
'tooltip': 'Higher values make the output more random, while lower values make it more focused and deterministic'
},
'max_tokens': {
'value': 2048,
'min': 2,
'max': 2048,
'step': 1,
'label': 'Max output tokens',
'precision': 0,
'tooltip': 'Specifies the maximum number of tokens the model can generate'
}
}

View File

@@ -9,13 +9,9 @@
import uuid
from typing import List, Dict, Optional, Any, Iterator
from langchain.schema.messages import get_buffer_string
from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message
from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message, QianfanChatEndpoint
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs import ChatGenerationChunk
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from langchain_core.messages import (
AIMessageChunk,
@@ -24,6 +20,9 @@ from langchain_core.messages import (
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
@@ -36,7 +35,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
qianfan_ak=model_credential.get('api_key'),
qianfan_sk=model_credential.get('secret_key'),
streaming=model_kwargs.get('streaming', False),
**optional_params)
init_kwargs=optional_params)
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
return self.__dict__.get('_last_generation_info')
@@ -54,6 +53,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
kwargs = {**self.init_kwargs, **kwargs}
params = self._convert_prompt_msg_params(messages, **kwargs)
params["stop"] = stop
params["stream"] = True
@@ -61,7 +61,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
if res:
msg = _convert_dict_to_message(res)
additional_kwargs = msg.additional_kwargs.get("function_call", {})
if msg.content == "":
if msg.content == "" or res.get("body").get("is_end"):
token_usage = res.get("body").get("usage")
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
chunk = ChatGenerationChunk(