Mirror of https://github.com/1Panel-dev/MaxKB.git, synced 2025-12-26 18:32:48 +00:00
feat: enhance video processing by adding a video model parameter and implementing file upload to OSS
This commit is contained in:
parent 8bfbac671c
commit 267bdae924
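In short, videos attached from file storage are no longer inlined into the prompt as base64 `data:` URIs; the video model now uploads them to DashScope's temporary OSS storage and references them by URL. A rough before/after sketch of one message part (values illustrative, not taken from the commit):

# Before this commit: the whole file is embedded in the request payload.
before = {'type': 'video_url',
          'video_url': {'url': 'data:video/mp4;base64,AAAAIGZ0eXBpc29t...'}}

# After this commit: the file is uploaded once and referenced by an oss:// URL.
after = {'type': 'video_url',
         'video_url': {'url': 'oss://upload_dir/demo.mp4'}}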
@@ -168,7 +168,7 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
     def generate_prompt_question(self, prompt):
         return HumanMessage(self.workflow_manage.generate_prompt(prompt))
 
-    def _process_videos(self, image):
+    def _process_videos(self, image, video_model):
         videos = []
         if isinstance(image, str) and image.startswith('http'):
             videos.append({'type': 'video_url', 'video_url': {'url': image}})
@@ -176,16 +176,14 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
         for img in image:
             file_id = img['file_id']
             file = QuerySet(File).filter(id=file_id).first()
-            video_bytes = file.get_bytes()
-            base64_video = base64.b64encode(video_bytes).decode("utf-8")
-            video_format = mimetypes.guess_type(file.file_name)[0]  # get the MIME type
+            url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
             videos.append(
-                {'type': 'video_url', 'video_url': {'url': f'data:{video_format};base64,{base64_video}'}})
+                {'type': 'video_url', 'video_url': {'url': url}})
         return videos
 
     def generate_message_list(self, video_model, system: str, prompt: str, history_message, video):
         prompt_text = self.workflow_manage.generate_prompt(prompt)
-        videos = self._process_videos(video)
+        videos = self._process_videos(video, video_model)
 
         if videos:
             messages = [HumanMessage(content=[{'type': 'text', 'text': prompt_text}, *videos])]
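With this change `_process_videos` delegates the upload to whatever model object is passed in; the only contract it relies on is an `upload_file_and_get_url(file_bytes, file_name) -> str` method. A minimal sketch of that contract (the `StubVideoModel` name and returned URL are illustrative, not part of the commit):

# Illustrative stub: any object exposing upload_file_and_get_url works here.
class StubVideoModel:
    def upload_file_and_get_url(self, file_bytes: bytes, file_name: str) -> str:
        # A real provider uploads to OSS and returns the object's URL.
        return f"oss://upload_dir/{file_name}"

# _process_videos(image, StubVideoModel()) now yields entries such as
# {'type': 'video_url', 'video_url': {'url': 'oss://upload_dir/demo.mp4'}}
# instead of a multi-megabyte data:video/mp4;base64,... URI.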
@@ -143,8 +143,10 @@ class DebugChatSerializers(serializers.Serializer):
             "application_id": chat_info.application.id, "debug": True
         }).chat(instance, base_to_response)
 
 
+SYSTEM_ROLE = get_file_content(os.path.join(PROJECT_DIR, "apps", "chat", 'template', 'generate_prompt_system'))
+
 
 class PromptGenerateSerializer(serializers.Serializer):
     workspace_id = serializers.CharField(required=False, label=_('Workspace ID'))
     model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model"))
@@ -156,13 +158,13 @@ class PromptGenerateSerializer(serializers.Serializer):
         query_set = QuerySet(Application).filter(id=self.data.get('application_id'))
         if workspace_id:
             query_set = query_set.filter(workspace_id=workspace_id)
-        application=query_set.first()
+        application = query_set.first()
         if application is None:
             raise AppApiException(500, _('Application id does not exist'))
         return application
 
     def generate_prompt(self, instance: dict):
-        application=self.is_valid(raise_exception=True)
+        application = self.is_valid(raise_exception=True)
         GeneratePromptSerializers(data=instance).is_valid(raise_exception=True)
         workspace_id = self.data.get('workspace_id')
         model_id = self.data.get('model_id')
@@ -171,14 +173,14 @@ class PromptGenerateSerializer(serializers.Serializer):
 
         message = messages[-1]['content']
         q = prompt.replace("{userInput}", message)
-        q = q.replace("{application_name}",application.name)
-        q = q.replace("{detail}",application.desc)
+        q = q.replace("{application_name}", application.name)
+        q = q.replace("{detail}", application.desc)
 
         messages[-1]['content'] = q
 
+        SUPPORTED_MODEL_TYPES = ["LLM", "IMAGE"]
         model_exist = QuerySet(Model).filter(
             id=model_id,
-            model_type="LLM"
+            model_type__in=SUPPORTED_MODEL_TYPES
         ).exists()
         if not model_exist:
             raise Exception(_("Model does not exists or is not an LLM model"))
@@ -186,14 +188,17 @@ class PromptGenerateSerializer(serializers.Serializer):
         system_content = SYSTEM_ROLE.format(application_name=application.name, detail=application.desc)
 
         def process():
-            model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id,**application.model_params_setting)
+            model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id,
+                                                             **application.model_params_setting)
             try:
                 for r in model.stream([SystemMessage(content=system_content),
-                                       *[HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage(
+                                       *[HumanMessage(content=m.get('content')) if m.get(
+                                           'role') == 'user' else AIMessage(
                                            content=m.get('content')) for m in messages]]):
                     yield 'data: ' + json.dumps({'content': r.content}) + '\n\n'
             except Exception as e:
                 yield 'data: ' + json.dumps({'error': str(e)}) + '\n\n'
 
         return to_stream_response_simple(process())
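For reference, `process()` emits standard server-sent-event frames, one JSON object per `data:` line, which `to_stream_response_simple` passes through to the client. On the wire a successful run looks roughly like this (content illustrative):

data: {"content": "You are a prompt"}

data: {"content": " engineer..."}

with a single `data: {"error": "..."}` frame emitted instead if the model stream raises.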
@@ -1,13 +1,17 @@
 # coding=utf-8
 
-from typing import Dict
+from typing import Dict, Optional, Any, Iterator
 
+import requests
 from langchain_community.chat_models import ChatTongyi
-from langchain_core.messages import HumanMessage
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.messages import HumanMessage, BaseMessageChunk, AIMessage
 from django.utils.translation import gettext
+from langchain_core.runnables import RunnableConfig
 
 from models_provider.base_model_provider import MaxKBBaseModel
 from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
+import json
+
 
 class QwenVLChatModel(MaxKBBaseModel, BaseChatOpenAI):
@@ -32,3 +36,113 @@ class QwenVLChatModel(MaxKBBaseModel, BaseChatOpenAI):
     def check_auth(self, api_key):
         chat = ChatTongyi(api_key=api_key, model_name='qwen-max')
         chat.invoke([HumanMessage([{"type": "text", "text": gettext('Hello')}])])
+
+    def get_upload_policy(self, api_key, model_name):
+        """Fetch a file-upload credential (policy) from DashScope."""
+        url = "https://dashscope.aliyuncs.com/api/v1/uploads"
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json"
+        }
+        params = {
+            "action": "getPolicy",
+            "model": model_name
+        }
+
+        response = requests.get(url, headers=headers, params=params)
+        if response.status_code != 200:
+            raise Exception(f"Failed to get upload policy: {response.text}")
+
+        return response.json()['data']
+
+    def upload_file_to_oss(self, policy_data, file_stream, file_name):
+        """Upload the file stream to DashScope's temporary OSS storage."""
+        # Build the target object key for the OSS upload
+        key = f"{policy_data['upload_dir']}/{file_name}"
+
+        # Build the multipart form fields required by the upload policy
+        files = {
+            'OSSAccessKeyId': (None, policy_data['oss_access_key_id']),
+            'Signature': (None, policy_data['signature']),
+            'policy': (None, policy_data['policy']),
+            'x-oss-object-acl': (None, policy_data['x_oss_object_acl']),
+            'x-oss-forbid-overwrite': (None, policy_data['x_oss_forbid_overwrite']),
+            'key': (None, key),
+            'success_action_status': (None, '200'),
+            'file': (file_name, file_stream)
+        }
+
+        # Perform the upload request
+        response = requests.post(policy_data['upload_host'], files=files)
+        if response.status_code != 200:
+            raise Exception(f"Failed to upload file: {response.text}")
+
+        return f"oss://{key}"
+
+    def upload_file_and_get_url(self, file_stream, file_name):
+        """Upload a file and return the resulting OSS URL."""
+        # 1. Fetch the upload policy. Note that the policy endpoint is
+        #    rate-limited; exceeding the limit makes the request fail.
+        policy_data = self.get_upload_policy(self.openai_api_key.get_secret_value(), self.model_name)
+        # 2. Upload the file to OSS
+        oss_url = self.upload_file_to_oss(policy_data, file_stream, file_name)
+
+        return oss_url
+
+    def stream(
+            self,
+            input: LanguageModelInput,
+            config: Optional[RunnableConfig] = None,
+            *,
+            stop: Optional[list[str]] = None,
+            **kwargs: Any,
+    ) -> Iterator[BaseMessageChunk]:
+        url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
+
+        headers = {
+            "Authorization": f"Bearer {self.openai_api_key.get_secret_value()}",
+            "Content-Type": "application/json",
+            # Ask DashScope to resolve oss:// references server-side
+            "X-DashScope-OssResourceResolve": "enable"
+        }
+
+        data = {
+            "model": self.model_name,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": input[0].content
+                }
+            ],
+            **(self.extra_body or {}),  # guard against extra_body being unset
+            "stream": True,
+        }
+        # stream=True is needed so iter_lines() yields SSE frames incrementally
+        response = requests.post(url, headers=headers, json=data, stream=True)
+        if response.status_code != 200:
+            raise Exception(f"Failed to get response: {response.text}")
+        for line in response.iter_lines():
+            if line:
+                try:
+                    decoded_line = line.decode('utf-8')
+                    # Only process valid SSE data lines
+                    if decoded_line.startswith('data: '):
+                        # Extract the JSON payload, dropping the 'data: ' prefix
+                        json_str = decoded_line[6:]
+                        # Skip the end-of-stream marker
+                        if json_str.strip() == '[DONE]':
+                            continue
+
+                        chunk_data = json.loads(json_str)
+
+                        if 'choices' in chunk_data and chunk_data['choices']:
+                            delta = chunk_data['choices'][0].get('delta', {})
+                            content = delta.get('content', '')
+                            if content:
+                                yield AIMessage(content=content)
+                except json.JSONDecodeError:
+                    # Skip lines that are not valid JSON
+                    continue
+                except Exception:
+                    # Ignore any other per-line processing errors
+                    continue
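Taken together, a caller can now push a local file through the model and reference it in a chat turn. A hedged end-to-end sketch; the constructor arguments are assumptions inferred from the fields the class reads (`openai_api_key`, `model_name`, `extra_body`), and the file name is illustrative:

# Hedged usage sketch -- requires real DashScope credentials to actually run.
from langchain_core.messages import HumanMessage

model = QwenVLChatModel(
    model_name='qwen-vl-max',   # illustrative model name
    openai_api_key='sk-xxx',    # assumed field; read via get_secret_value()
    extra_body={},              # stream() spreads this into the request body
)

with open('demo.mp4', 'rb') as f:  # illustrative local file
    oss_url = model.upload_file_and_get_url(f.read(), 'demo.mp4')

# stream() sends only input[0].content, so pass a single HumanMessage whose
# content mixes a text part with the oss:// video reference.
for chunk in model.stream([HumanMessage(content=[
        {'type': 'text', 'text': 'Describe this video'},
        {'type': 'video_url', 'video_url': {'url': oss_url}}])]):
    print(chunk.content, end='')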