From 4c28ff12f510c841c6ef8031dd0a30a27a195c16 Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Thu, 19 Sep 2024 16:03:51 +0800
Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E4=BB=A3?=
 =?UTF-8?q?=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../models_provider/impl/base_chat_open_ai.py | 67 +++----------------
 .../impl/qwen_model_provider/model/llm.py     | 12 ++--
 .../tencent_model_provider/model/hunyuan.py   | 30 ++++-----
 .../impl/tencent_model_provider/model/llm.py  |  6 +-
 .../impl/wenxin_model_provider/model/llm.py   | 10 +--
 .../impl/xf_model_provider/model/llm.py       | 10 +--
 .../impl/zhipu_model_provider/model/llm.py    | 10 +--
 .../views/application/ApplicationAccess.vue   |  6 +-
 .../component/AccessSettingDrawer.vue         | 16 ++---
 9 files changed, 64 insertions(+), 103 deletions(-)

diff --git a/apps/setting/models_provider/impl/base_chat_open_ai.py b/apps/setting/models_provider/impl/base_chat_open_ai.py
index e7f4990f1..eb0c7558e 100644
--- a/apps/setting/models_provider/impl/base_chat_open_ai.py
+++ b/apps/setting/models_provider/impl/base_chat_open_ai.py
@@ -11,73 +11,26 @@ from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
 
 
 class BaseChatOpenAI(ChatOpenAI):
+    usage_metadata: dict = {}
 
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.__dict__.get('_last_generation_info')
+        return self.usage_metadata
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.get_last_generation_info().get('prompt_tokens', 0)
+        return self.usage_metadata.get('input_tokens', 0)
 
     def get_num_tokens(self, text: str) -> int:
-        return self.get_last_generation_info().get('completion_tokens', 0)
+        return self.get_last_generation_info().get('output_tokens', 0)
 
     def _stream(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
+            self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> Iterator[ChatGenerationChunk]:
         kwargs["stream"] = True
         kwargs["stream_options"] = {"include_usage": True}
-        payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
-        else:
-            response = self.client.create(**payload)
-            base_generation_info = {}
-        with response:
-            is_first_chunk = True
-            for chunk in response:
-                if not isinstance(chunk, dict):
-                    chunk = chunk.model_dump()
-                if (len(chunk["choices"]) == 0 or chunk["choices"][0]["finish_reason"] == "length" or
-                        chunk["choices"][0]["finish_reason"] == "stop") and chunk.get("usage") is not None:
-                    if token_usage := chunk.get("usage"):
-                        self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
-                    logprobs = None
-                    continue
-                else:
-                    choice = chunk["choices"][0]
-                    if choice["delta"] is None:
-                        continue
-                    message_chunk = _convert_delta_to_message_chunk(
-                        choice["delta"], default_chunk_class
-                    )
-                    generation_info = {**base_generation_info} if is_first_chunk else {}
-                    if finish_reason := choice.get("finish_reason"):
-                        generation_info["finish_reason"] = finish_reason
-                    if model_name := chunk.get("model"):
-                        generation_info["model_name"] = model_name
-                    if system_fingerprint := chunk.get("system_fingerprint"):
-                        generation_info["system_fingerprint"] = system_fingerprint
-
-                logprobs = choice.get("logprobs")
-                if logprobs:
-                    generation_info["logprobs"] = logprobs
-                default_chunk_class = message_chunk.__class__
-                generation_chunk = ChatGenerationChunk(
-                    message=message_chunk, generation_info=generation_info or None
-                )
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
-                    )
-                is_first_chunk = False
-                yield generation_chunk
+        for chunk in super()._stream(*args, stream_usage=stream_usage, **kwargs):
+            if chunk.message.usage_metadata is not None:
+                self.usage_metadata = chunk.message.usage_metadata
+            yield chunk
 
     def invoke(
         self,
@@ -101,5 +54,5 @@ class BaseChatOpenAI(ChatOpenAI):
             **kwargs,
         ).generations[0][0],
         ).message
-        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        self.usage_metadata = chat_result.response_metadata['token_usage']
         return chat_result
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
index 71e5ed8f7..8ac8347aa 100644
--- a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
@@ -39,14 +39,16 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
         )
         return chat_tong_yi
 
+    usage_metadata: dict = {}
+
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.__dict__.get('_last_generation_info')
+        return self.usage_metadata
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.get_last_generation_info().get('input_tokens', 0)
+        return self.usage_metadata.get('input_tokens', 0)
 
     def get_num_tokens(self, text: str) -> int:
-        return self.get_last_generation_info().get('output_tokens', 0)
+        return self.usage_metadata.get('output_tokens', 0)
 
     def _stream(
             self,
@@ -69,7 +71,7 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
                     and message["content"] == ""
                 ) or (choice["finish_reason"] == "length"):
                     token_usage = stream_resp["usage"]
-                    self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
+                    self.usage_metadata = token_usage
                 if (
                     choice["finish_reason"] == "null"
                     and message["content"] == ""
@@ -108,5 +110,5 @@ class QwenChatModel(MaxKBBaseModel, ChatTongyi):
             **kwargs,
         ).generations[0][0],
         ).message
-        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        self.usage_metadata = chat_result.response_metadata['token_usage']
         return chat_result
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py b/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py
index 219920125..38983a72f 100644
--- a/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py
+++ b/apps/setting/models_provider/impl/tencent_model_provider/model/hunyuan.py
@@ -54,7 +54,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
 
 
 def _convert_delta_to_message_chunk(
-    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+        _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
     role = _dict.get("Role")
     content = _dict.get("Content") or ""
@@ -198,11 +198,11 @@ class ChatHunyuan(BaseChatModel):
         return {**normal_params, **self.model_kwargs}
 
     def _generate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         if self.streaming:
             stream_iter = self._stream(
@@ -213,12 +213,14 @@ class ChatHunyuan(BaseChatModel):
         res = self._chat(messages, **kwargs)
         return _create_chat_result(json.loads(res.to_json_string()))
 
+    usage_metadata: dict = {}
+
     def _stream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         res = self._chat(messages, **kwargs)
 
@@ -238,9 +240,7 @@ class ChatHunyuan(BaseChatModel):
             default_chunk_class = chunk.__class__
             # FinishReason === stop
             if choice.get("FinishReason") == "stop":
-                self.__dict__.setdefault("_last_generation_info", {}).update(
-                    response.get("Usage", {})
-                )
+                self.usage_metadata = response.get("Usage", {})
             cg_chunk = ChatGenerationChunk(message=chunk)
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
@@ -275,4 +275,4 @@
 
     @property
     def _llm_type(self) -> str:
-        return "hunyuan-chat"
\ No newline at end of file
+        return "hunyuan-chat"
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
index 0f879f73b..cfe673b50 100644
--- a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
@@ -38,10 +38,10 @@ class TencentModel(MaxKBBaseModel, ChatHunyuan):
         return TencentModel(model_name=model_name, credentials=model_credential, streaming=streaming, **model_kwargs)
 
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.__dict__.get('_last_generation_info')
+        return self.usage_metadata
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.get_last_generation_info().get('PromptTokens', 0)
+        return self.usage_metadata.get('PromptTokens', 0)
 
     def get_num_tokens(self, text: str) -> int:
-        return self.get_last_generation_info().get('CompletionTokens', 0)
+        return self.usage_metadata.get('CompletionTokens', 0)
diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
index fc5b5335d..8c634c6d0 100644
--- a/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/llm.py
@@ -37,14 +37,16 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
                                  streaming=model_kwargs.get('streaming', False),
                                  init_kwargs=optional_params)
 
+    usage_metadata: dict = {}
+
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.__dict__.get('_last_generation_info')
+        return self.usage_metadata
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.get_last_generation_info().get('prompt_tokens', 0)
+        return self.usage_metadata.get('prompt_tokens', 0)
 
     def get_num_tokens(self, text: str) -> int:
-        return self.get_last_generation_info().get('completion_tokens', 0)
+        return self.usage_metadata.get('completion_tokens', 0)
 
     def _stream(
         self,
@@ -63,7 +65,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
             additional_kwargs = msg.additional_kwargs.get("function_call", {})
             if msg.content == "" or res.get("body").get("is_end"):
                 token_usage = res.get("body").get("usage")
-                self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
+                self.usage_metadata = token_usage
             chunk = ChatGenerationChunk(
                 text=res["result"],
                 message=AIMessageChunk(  # type: ignore[call-arg]
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py
index f8d55acb5..598af07ac 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/llm.py
@@ -40,14 +40,16 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
             **optional_params
         )
 
+    usage_metadata: dict = {}
+
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.__dict__.get('_last_generation_info')
+        return self.usage_metadata
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.get_last_generation_info().get('prompt_tokens', 0)
+        return self.usage_metadata.get('prompt_tokens', 0)
 
     def get_num_tokens(self, text: str) -> int:
-        return self.get_last_generation_info().get('completion_tokens', 0)
+        return self.usage_metadata.get('completion_tokens', 0)
 
     def _stream(
         self,
@@ -71,7 +73,7 @@ class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
                 cg_chunk = ChatGenerationChunk(message=chunk)
             elif "usage" in content:
                 generation_info = content["usage"]
-                self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
+                self.usage_metadata = generation_info
                 continue
             else:
                 continue
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
index 5722aead3..c86c2e3a3 100644
--- a/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py
@@ -47,14 +47,16 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
         )
         return zhipuai_chat
 
+    usage_metadata: dict = {}
+
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
-        return self.__dict__.get('_last_generation_info')
+        return self.usage_metadata
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        return self.get_last_generation_info().get('prompt_tokens', 0)
+        return self.usage_metadata.get('prompt_tokens', 0)
 
     def get_num_tokens(self, text: str) -> int:
-        return self.get_last_generation_info().get('completion_tokens', 0)
+        return self.usage_metadata.get('completion_tokens', 0)
 
     def _stream(
         self,
@@ -91,7 +93,7 @@ class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI):
             generation_info = {}
             if "usage" in chunk:
                 generation_info = chunk["usage"]
-                self.__dict__.setdefault('_last_generation_info', {}).update(generation_info)
+                self.usage_metadata = generation_info
             chunk = _convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
diff --git a/ui/src/views/application/ApplicationAccess.vue b/ui/src/views/application/ApplicationAccess.vue
index e7adf5005..0461d6b10 100644
--- a/ui/src/views/application/ApplicationAccess.vue
+++ b/ui/src/views/application/ApplicationAccess.vue
@@ -52,7 +52,7 @@ const platforms = reactive([
   {
     key: 'wecom',
     logoSrc: new URL(`../../assets/logo_wechat-work.svg`, import.meta.url).href,
-    name: '企业微信',
+    name: '企业微信应用',
     description: '打造企业微信智能应用',
     isActive: false,
     exists: false
@@ -60,7 +60,7 @@ const platforms = reactive([
   {
     key: 'dingtalk',
     logoSrc: new URL(`../../assets/logo_dingtalk.svg`, import.meta.url).href,
-    name: '钉钉',
+    name: '钉钉应用',
     description: '打造钉钉智能应用',
     isActive: false,
     exists: false
@@ -76,7 +76,7 @@ const platforms = reactive([
   {
     key: 'feishu',
     logoSrc: new URL(`../../assets/logo_lark.svg`, import.meta.url).href,
-    name: '飞书',
+    name: '飞书应用',
     description: '打造飞书智能应用',
     isActive: false,
     exists: false
diff --git a/ui/src/views/application/component/AccessSettingDrawer.vue b/ui/src/views/application/component/AccessSettingDrawer.vue
index 3e30a0903..e57c3a5b7 100644
--- a/ui/src/views/application/component/AccessSettingDrawer.vue
+++ b/ui/src/views/application/component/AccessSettingDrawer.vue
@@ -29,8 +29,8 @@