--bug=1045125 --user=王孝刚 [Model Management] A local Xinference model repeats the last few characters of its answer https://www.tapd.cn/57709429/s/1565685
# coding=utf-8
from typing import List, Dict, Optional, Any, Iterator, Type

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _convert_delta_to_message_chunk

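# Streaming-capable ChatOpenAI subclass used by MaxKB. It asks OpenAI-compatible
# backends to attach token usage to the stream and drops the trailing
# usage-bearing chunk instead of yielding it; per --bug=1045125, a local
# Xinference server resends the final delta inside that chunk, which made the
# last few characters of an answer repeat once.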
class BaseChatOpenAI(ChatOpenAI):

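    # Token usage is stashed on __dict__ directly because ChatOpenAI is a
    # pydantic model and does not accept ad-hoc instance attributes; this
    # returns None until the first streamed response has completed.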
    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        return self.__dict__.get('_last_generation_info')

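    # Both counters report what the server put in the usage payload of the
    # most recent streamed response; the `messages`/`text` arguments are not
    # tokenized locally.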
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        return (self.get_last_generation_info() or {}).get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        return (self.get_last_generation_info() or {}).get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
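        # Force a streamed completion and ask the server to append a final
        # usage-bearing chunk (OpenAI's `stream_options` extension).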
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
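        # Optionally surface the HTTP response headers on the first chunk's
        # generation_info.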
        if self.include_response_headers:
            raw_response = self.client.with_raw_response.create(**payload)
            response = raw_response.parse()
            base_generation_info = {"headers": dict(raw_response.headers)}
        else:
            response = self.client.create(**payload)
            base_generation_info = {}
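        # Iterate inside a context manager so the HTTP stream is closed even
        # if the consumer stops early.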
        with response:
            is_first_chunk = True
            for chunk in response:
                if not isinstance(chunk, dict):
                    chunk = chunk.model_dump()
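                # Fix for --bug=1045125: when the usage chunk also carries the
                # final "stop"/"length" choice (as a local Xinference server
                # does), record the usage and skip the chunk entirely instead
                # of yielding its delta again, which duplicated the last few
                # characters of the answer.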
if (len(chunk["choices"]) == 0 or chunk["choices"][0]["finish_reason"] == "length" or
|
|
chunk["choices"][0]["finish_reason"] == "stop") and chunk.get("usage") is not None:
|
|
if token_usage := chunk.get("usage"):
|
|
self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
|
|
logprobs = None
|
|
continue
|
|
else:
|
|
choice = chunk["choices"][0]
|
|
if choice["delta"] is None:
|
|
continue
|
|
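                # Convert the raw delta into a LangChain message chunk and
                # collect per-chunk generation metadata.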
                message_chunk = _convert_delta_to_message_chunk(
                    choice["delta"], default_chunk_class
                )
                generation_info = {**base_generation_info} if is_first_chunk else {}
                if finish_reason := choice.get("finish_reason"):
                    generation_info["finish_reason"] = finish_reason
                    if model_name := chunk.get("model"):
                        generation_info["model_name"] = model_name
                    if system_fingerprint := chunk.get("system_fingerprint"):
                        generation_info["system_fingerprint"] = system_fingerprint

                logprobs = choice.get("logprobs")
                if logprobs:
                    generation_info["logprobs"] = logprobs
                default_chunk_class = message_chunk.__class__
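                # Yield the chunk and notify the callback manager so streaming
                # handlers see each token as it arrives.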
                generation_chunk = ChatGenerationChunk(
                    message=message_chunk, generation_info=generation_info or None
                )
                if run_manager:
                    run_manager.on_llm_new_token(
                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                    )
                is_first_chunk = False
                yield generation_chunk
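
# Usage sketch (assumptions: the endpoint, model name, and api_key below are
# placeholders for any OpenAI-compatible server, e.g. a local Xinference
# deployment; they are not values taken from this repository):
#
#     llm = BaseChatOpenAI(model="my-model",
#                          base_url="http://127.0.0.1:9997/v1",
#                          api_key="EMPTY")
#     for piece in llm.stream("Hello"):
#         print(piece.content, end="", flush=True)
#     print()
#     print(llm.get_last_generation_info())  # usage dict from the final chunk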