mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-27 20:42:52 +00:00
101 lines
4.0 KiB
Python
101 lines
4.0 KiB
Python
# coding=utf-8
|
|
import base64
|
|
import os
|
|
from typing import Dict, Any, List, Optional, Iterator
|
|
|
|
#from docutils.utils import SystemMessage
|
|
from langchain_community.chat_models.sparkllm import ChatSparkLLM, _convert_delta_to_message_chunk
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.messages import BaseMessage, ChatMessage, HumanMessage, AIMessage, AIMessageChunk
|
|
from langchain_core.outputs import ChatGenerationChunk
|
|
|
|
from models_provider.base_model_provider import MaxKBBaseModel
|
|
|
|
|
|
class ImageMessage(HumanMessage):
    """Human-role message whose ``content`` holds an encoded image.

    ``convert_message_to_dict`` serializes instances as a ``"user"`` entry with
    an extra ``"content_type": "image"`` key; presumably this is how the Spark
    image API distinguishes image input from text (confirm against the API docs).
    """

    # Declared as a plain ``str``: the image travels as one encoded string
    # (``XFSparkImage.generate_message`` passes a base64 data URI here).
    content: str
|
|
|
|
|
|
def convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into the dict shape sent to the Spark client.

    Mapping:
      * ``ChatMessage``  -> ``{"role": "user", ...}`` (its own ``role`` is not used)
      * ``ImageMessage`` -> ``{"role": "user", ..., "content_type": "image"}``
      * ``HumanMessage`` -> ``{"role": "user", ...}``
      * ``AIMessage``    -> ``{"role": "assistant", ...}`` plus any
        ``function_call`` / ``tool_calls`` found in ``additional_kwargs``

    :param message: the message to convert.
    :return: a new dict with at least ``role`` and ``content`` keys.
    :raises ValueError: for any other message type.
    """
    if isinstance(message, ChatMessage):
        return {"role": "user", "content": message.content}
    # ImageMessage subclasses HumanMessage, so it must be tested first.
    if isinstance(message, ImageMessage):
        return {"role": "user", "content": message.content, "content_type": "image"}
    if isinstance(message, HumanMessage):
        return {"role": "user", "content": message.content}
    if not isinstance(message, AIMessage):
        # NOTE: there is intentionally no SystemMessage branch here (it was
        # disabled), so system prompts end up in this error as well.
        raise ValueError(f"Got unknown type {message}")

    converted: Dict[str, Any] = {"role": "assistant", "content": message.content}
    extras = message.additional_kwargs
    for call_key in ("function_call", "tool_calls"):
        if call_key not in extras:
            continue
        converted[call_key] = extras[call_key]
        # A call-only reply is sent with content None rather than "".
        if converted["content"] == "":
            converted["content"] = None
    return converted
|
|
|
|
|
|
class XFSparkImage(MaxKBBaseModel, ChatSparkLLM):
    """iFlytek Spark image-understanding chat model wrapped for MaxKB.

    Inherits configuration and the client from ``ChatSparkLLM``. It overrides
    ``_stream`` so outgoing messages are serialized with this module's
    ``convert_message_to_dict``, which emits ``ImageMessage`` entries with
    ``"content_type": "image"``.
    """

    # Connection settings, populated from the credential dict in ``new_instance``.
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an ``XFSparkImage`` from MaxKB credentials.

        :param model_type: not used by this implementation.
        :param model_name: not used by this implementation.
        :param model_credential: mapping read for ``spark_app_id``,
            ``spark_api_key``, ``spark_api_secret`` and ``spark_api_url``
            (``.get`` yields ``None`` for any missing key).
        :param model_kwargs: extra options; only those kept by
            ``MaxKBBaseModel.filter_optional_params`` are forwarded.
        :return: the new model instance.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return XFSparkImage(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    @staticmethod
    def generate_message(prompt: str, image) -> list[BaseMessage]:
        """Build the message list for a request.

        When ``image`` is ``None``, the bundled ``img_1.png`` located next to
        this module is sent as an ``ImageMessage`` (a base64 data URI),
        followed by the prompt. Otherwise only the prompt is returned.

        :param prompt: user text, wrapped in a ``HumanMessage``.
        :param image: only compared against ``None``; its value is not used.
        :return: ``[ImageMessage, HumanMessage]`` or ``[HumanMessage]``.
        """
        if image is None:
            module_dir = os.path.dirname(os.path.abspath(__file__))
            with open(os.path.join(module_dir, 'img_1.png'), 'rb') as f:
                image_bytes = f.read()
            # Label the data URI from the actual bytes. The bundled file is a
            # .png but was always advertised as image/jpeg. Content without
            # the 8-byte PNG signature keeps the previous image/jpeg label.
            if image_bytes.startswith(b'\x89PNG\r\n\x1a\n'):
                mime_type = 'image/png'
            else:
                mime_type = 'image/jpeg'
            base64_image = base64.b64encode(image_bytes).decode("utf-8")
            return [ImageMessage(f'data:{mime_type};base64,{base64_image}'), HumanMessage(prompt)]
        # NOTE(review): a caller-supplied ``image`` is ignored and no image
        # message is produced. This looks unfinished; confirm the expected
        # type of ``image`` with the callers before wiring it through.
        return [HumanMessage(prompt)]

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream a reply from Spark, yielding one chunk per ``data`` event.

        Overrides ``ChatSparkLLM._stream`` so that ``messages`` are serialized
        with this module's ``convert_message_to_dict``.

        :param messages: conversation to send.
        :param stop: accepted for interface compatibility; not used.
        :param run_manager: if given, notified of every new token.
        :param kwargs: accepted for interface compatibility; not used.
        :return: iterator of ``ChatGenerationChunk`` built from each event's
            ``data`` payload.
        """
        default_chunk_class = AIMessageChunk

        # Kick off the streaming request; results are then pulled from subscribe().
        self.client.arun(
            [convert_message_to_dict(m) for m in messages],
            self.spark_user_id,
            self.model_kwargs,
            streaming=True,
        )
        for content in self.client.subscribe(timeout=self.request_timeout):
            # Events without a "data" payload are skipped.
            if "data" not in content:
                continue
            delta = content["data"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            cg_chunk = ChatGenerationChunk(message=chunk)
            if run_manager:
                run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
            yield cg_chunk

    @staticmethod
    def is_cache_model():
        """Return ``False``; presumably opts this model out of MaxKB model caching."""
        return False
|