MaxKB/apps/models_provider/impl/wenxin_model_provider/model/llm.py


# coding=utf-8
"""
@project: maxkb
@Author:
@file: llm.py
@date: 2023/11/10 17:45
@desc:
"""
from typing import List, Dict, Optional, Any, Iterator
from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message, QianfanChatEndpoint
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import (
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.outputs import ChatGenerationChunk
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint):
    # Token usage reported by the last generation; populated in _stream().
    usage_metadata: dict = {}

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return QianfanChatModelQianfan(model=model_name,
                                       qianfan_ak=model_credential.get('api_key'),
                                       qianfan_sk=model_credential.get('secret_key'),
                                       streaming=model_kwargs.get('streaming', False),
                                       init_kwargs=optional_params)

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        return self.usage_metadata.get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        return self.usage_metadata.get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        kwargs = {**self.init_kwargs, **kwargs}
        params = self._convert_prompt_msg_params(messages, **kwargs)
        params["stop"] = stop
        params["stream"] = True
        for res in self.client.do(**params):
            if res:
                msg = _convert_dict_to_message(res)
                additional_kwargs = msg.additional_kwargs.get("function_call", {})
                # The usage block only appears on the final (or empty) chunk.
                if msg.content == "" or res.get("body").get("is_end"):
                    token_usage = res.get("body").get("usage")
                    # Guard against a missing usage block so the token getters
                    # above never call .get() on None.
                    self.usage_metadata = token_usage or {}
                chunk = ChatGenerationChunk(
                    text=res["result"],
                    message=AIMessageChunk(  # type: ignore[call-arg]
                        content=msg.content,
                        role="assistant",
                        additional_kwargs=additional_kwargs,
                    ),
                    generation_info=msg.additional_kwargs,
                )
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)
                yield chunk


class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI):
    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return QianfanChatModelOpenai(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            extra_body=optional_params
        )


class QianfanChatModel(MaxKBBaseModel):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # v1 uses the native Qianfan endpoint (AK/SK pair); v2 uses the
        # OpenAI-compatible endpoint (api_base + api_key).
        api_version = model_credential.get('api_version', 'v1')
        if api_version == "v1":
            return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs)
        elif api_version == "v2":
            return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)
        # Fail loudly instead of implicitly returning None for unknown versions.
        raise ValueError(f"Unsupported api_version: {api_version}")
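

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how the
# QianfanChatModel dispatcher above selects an implementation. The credential
# keys mirror the branches in new_instance(); the model names and all
# credential values below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # api_version 'v1' -> native Qianfan endpoint, authenticated with AK/SK.
    v1_model = QianfanChatModel.new_instance(
        'LLM', 'ERNIE-Bot-4',
        {'api_version': 'v1', 'api_key': '<qianfan-ak>', 'secret_key': '<qianfan-sk>'},
        streaming=True,
    )
    # api_version 'v2' -> OpenAI-compatible endpoint, base URL plus bearer key.
    v2_model = QianfanChatModel.new_instance(
        'LLM', 'ernie-4.0-8k',
        {'api_version': 'v2', 'api_base': '<openai-compatible-base-url>', 'api_key': '<bearer-key>'},
    )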