fix: resolve the defect where non-streaming responses raise an error

wxg0103 2024-09-19 11:41:29 +08:00 committed by wxg0103
parent 90a7a9d085
commit 88f6e336e7
2 changed files with 58 additions and 4 deletions


@@ -1,9 +1,11 @@
 # coding=utf-8
-from typing import List, Dict, Optional, Any, Iterator, Type
+from typing import List, Dict, Optional, Any, Iterator, Type, cast
 from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.runnables import RunnableConfig, ensure_config
 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
@@ -76,3 +78,28 @@ class BaseChatOpenAI(ChatOpenAI):
                     )
                 is_first_chunk = False
                 yield generation_chunk
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> BaseMessage:
+        config = ensure_config(config)
+        chat_result = cast(
+            ChatGeneration,
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
+                **kwargs,
+            ).generations[0][0],
+        ).message
+        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        return chat_result
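
A minimal usage sketch of the patched non-streaming path (the model name, credential and endpoint below are placeholders, not part of this commit): invoke() now resolves the call through generate_prompt() and caches the token usage of the call on the instance.

    # Sketch only; BaseChatOpenAI is the class patched above, everything else is a placeholder.
    from langchain_core.messages import HumanMessage

    model = BaseChatOpenAI(
        model="gpt-4o-mini",                           # placeholder model name
        openai_api_key="sk-...",                       # placeholder credential
        openai_api_base="https://api.openai.com/v1",   # placeholder endpoint
    )

    # Non-streaming call: returns a BaseMessage instead of raising on the
    # streaming-only code path.
    answer = model.invoke([HumanMessage(content="Hello")])
    print(answer.content)

    # Token usage of the last call, cached by the override.
    print(model.__dict__.get('_last_generation_info'))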


@@ -6,13 +6,15 @@
     @date：2024/4/28 11:44
     @desc:
 """
-from typing import List, Dict, Optional, Iterator, Any
+from typing import List, Dict, Optional, Iterator, Any, cast
 from langchain_community.chat_models import ChatTongyi
 from langchain_community.llms.tongyi import generate_with_last_element_mark
 from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.runnables import RunnableConfig, ensure_config
 from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -83,3 +85,28 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             yield chunk
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> BaseMessage:
+        config = ensure_config(config)
+        chat_result = cast(
+            ChatGeneration,
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
+                **kwargs,
+            ).generations[0][0],
+        ).message
+        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        return chat_result
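
Both overrides read chat_result.response_metadata['token_usage'] directly and cache it through __dict__. A plausible reading of that pattern (an interpretation, not stated in the commit): the chat model classes are Pydantic models, so assigning an undeclared attribute such as _last_generation_info can be rejected, while writing into __dict__ stores the cache without declaring a new model field. A self-contained sketch of that behaviour, using a hypothetical MinimalModel:

    # Sketch under the assumption above; MinimalModel stands in for ChatOpenAI/ChatTongyi.
    from pydantic import BaseModel

    class MinimalModel(BaseModel):
        name: str

    m = MinimalModel(name="qwen")
    # Assigning an undeclared attribute on a Pydantic model is typically rejected;
    # going through __dict__ sidesteps the field check.
    m.__dict__.setdefault('_last_generation_info', {}).update({'total_tokens': 42})
    print(m.__dict__['_last_generation_info'])   # {'total_tokens': 42}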