MaxKB/apps/setting/models_provider/impl/base_chat_open_ai.py
wxg0103 f9784dcbd1
fix: stop local Xinference models from repeating the last few characters of an answer
--bug=1045125 --user=王孝刚 [Model Management] The last few characters of a local Xinference model's answer are repeated once https://www.tapd.cn/57709429/s/1565685
2024-08-20 11:51:36 +08:00


# coding=utf-8
from typing import List, Dict, Optional, Any, Iterator, Type

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk


class BaseChatOpenAI(ChatOpenAI):
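    # Caches the token usage reported in the final streamed chunk so MaxKB can
    # read prompt/completion token counts back after a call finishes, and skips
    # the provider's trailing usage chunk so its content is never yielded twice.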
    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.__dict__.get('_last_generation_info')
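    # The two overrides below replace langchain's local tiktoken-based counting
    # with the figures the provider itself reported for the last response.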
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # `or {}` guards against the cache being absent before any call has run.
        return (self.get_last_generation_info() or {}).get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        return (self.get_last_generation_info() or {}).get('completion_tokens', 0)
    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
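        # Always stream and ask the server to include usage statistics, which
        # arrive on a final chunk after the content is finished.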
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
payload = self._get_request_payload(messages, stop=stop, **kwargs)
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        if self.include_response_headers:
            raw_response = self.client.with_raw_response.create(**payload)
            response = raw_response.parse()
            base_generation_info = {"headers": dict(raw_response.headers)}
        else:
            response = self.client.create(**payload)
            base_generation_info = {}
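        # Any captured headers are attached to the first generation chunk only,
        # via base_generation_info below.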
        with response:
            is_first_chunk = True
            for chunk in response:
                if not isinstance(chunk, dict):
                    chunk = chunk.model_dump()
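                # The Xinference fix from this commit: a final chunk carrying
                # usage statistics is cached for the token-count getters and
                # then skipped, instead of being converted and yielded again,
                # which previously made the last few characters appear twice.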
                if (len(chunk["choices"]) == 0 or chunk["choices"][0]["finish_reason"] == "length" or
                        chunk["choices"][0]["finish_reason"] == "stop") and chunk.get("usage") is not None:
                    if token_usage := chunk.get("usage"):
                        self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
                    logprobs = None
                    continue
                else:
                    choice = chunk["choices"][0]
                    if choice["delta"] is None:
                        continue
                    message_chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    generation_info = {**base_generation_info} if is_first_chunk else {}
                    if finish_reason := choice.get("finish_reason"):
                        generation_info["finish_reason"] = finish_reason
                        if model_name := chunk.get("model"):
                            generation_info["model_name"] = model_name
                        if system_fingerprint := chunk.get("system_fingerprint"):
                            generation_info["system_fingerprint"] = system_fingerprint
                    logprobs = choice.get("logprobs")
                    if logprobs:
                        generation_info["logprobs"] = logprobs
                default_chunk_class = message_chunk.__class__
                generation_chunk = ChatGenerationChunk(
                    message=message_chunk, generation_info=generation_info or None
                )
                if run_manager:
                    run_manager.on_llm_new_token(
                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                    )
                is_first_chunk = False
                yield generation_chunk