mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
feat: enhance image and video handling by supporting URLs and file IDs
This commit is contained in:
parent
30494cf6de
commit
9d4b2bf010
|
|
@ -74,7 +74,7 @@ class BaseImageToVideoNode(IImageToVideoNode):
|
|||
def get_file_base64(self, image_url):
|
||||
try:
|
||||
if isinstance(image_url, list):
|
||||
image_url = image_url[0].get('file_id')
|
||||
image_url = image_url[0].get('file_id') if 'file_id' in image_url[0] else image_url[0].get('url')
|
||||
if isinstance(image_url, str) and not image_url.startswith('http'):
|
||||
file = QuerySet(File).filter(id=image_url).first()
|
||||
file_bytes = file.get_bytes()
|
||||
|
|
|
|||
|
|
@ -131,11 +131,18 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
file_id_list = [image.get('file_id') for image in image_list]
|
||||
|
||||
file_id_list = []
|
||||
url_list = []
|
||||
for image in image_list:
|
||||
if 'file_id' in image:
|
||||
file_id_list.append(image.get('file_id'))
|
||||
elif 'url' in image:
|
||||
url_list.append(image.get('url'))
|
||||
return HumanMessage(content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
|
||||
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list],
|
||||
*[{'type': 'image_url', 'image_url': {'url': url}} for url in url_list]
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
|
|
@ -155,13 +162,22 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]
|
||||
file_id_list = []
|
||||
url_list = []
|
||||
for image in image_list:
|
||||
if 'file_id' in image:
|
||||
file_id_list.append(image.get('file_id'))
|
||||
elif 'url' in image:
|
||||
url_list.append(image.get('url'))
|
||||
image_base64_list = [file_id_to_base64(file_id) for file_id in file_id_list]
|
||||
|
||||
return HumanMessage(
|
||||
content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'image_url',
|
||||
'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
|
||||
base64_image in image_base64_list]
|
||||
base64_image in image_base64_list],
|
||||
*[{'type': 'image_url', 'image_url': url} for url in url_list]
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
|
|
@ -177,13 +193,17 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
images.append({'type': 'image_url', 'image_url': {'url': image}})
|
||||
elif image is not None and len(image) > 0:
|
||||
for img in image:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
image_bytes = file.get_bytes()
|
||||
base64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_format = what(None, image_bytes)
|
||||
images.append(
|
||||
{'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
|
||||
if 'file_id' in img:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
image_bytes = file.get_bytes()
|
||||
base64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_format = what(None, image_bytes)
|
||||
images.append(
|
||||
{'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
|
||||
elif 'url' in img and img['url'].startswith('http'):
|
||||
images.append(
|
||||
{'type': 'image_url', 'image_url': {'url': img["url"]}})
|
||||
return images
|
||||
|
||||
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
|
||||
|
|
|
|||
|
|
@ -131,11 +131,17 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
|
|||
# 增加对 None 和空列表的检查
|
||||
if not video_list or len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
file_id_list = [video.get('file_id') for video in video_list]
|
||||
file_id_list = []
|
||||
url_list = []
|
||||
for image in video_list:
|
||||
if 'file_id' in image:
|
||||
file_id_list.append(image.get('file_id'))
|
||||
elif 'url' in image:
|
||||
url_list.append(image.get('url'))
|
||||
return HumanMessage(content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'video_url', 'video_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
|
||||
|
||||
*[{'type': 'video_url', 'video_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list],
|
||||
*[{'type': 'video_url', 'video_url': {'url': url}} for url in url_list],
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
|
|
@ -155,6 +161,13 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
|
|||
video_list = data['video_list']
|
||||
if len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
file_id_list = []
|
||||
url_list = []
|
||||
for image in video_list:
|
||||
if 'file_id' in image:
|
||||
file_id_list.append(image.get('file_id'))
|
||||
elif 'url' in image:
|
||||
url_list.append(image.get('url'))
|
||||
video_base64_list = [file_id_to_base64(video.get('file_id'), video_model) for video in video_list]
|
||||
return HumanMessage(
|
||||
content=[
|
||||
|
|
@ -174,11 +187,15 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
|
|||
videos.append({'type': 'video_url', 'video_url': {'url': image}})
|
||||
elif image is not None and len(image) > 0:
|
||||
for img in image:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
|
||||
videos.append(
|
||||
{'type': 'video_url', 'video_url': {'url': url}})
|
||||
if 'file_id' in img:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
|
||||
videos.append(
|
||||
{'type': 'video_url', 'video_url': {'url': url}})
|
||||
elif 'url' in img and img['url'].startswith('http'):
|
||||
videos.append(
|
||||
{'type': 'video_url', 'video_url': {'url': img['url']}})
|
||||
return videos
|
||||
|
||||
def generate_message_list(self, video_model, system: str, prompt: str, history_message, video):
|
||||
|
|
|
|||
|
|
@ -17,5 +17,7 @@ urlpatterns = [
|
|||
views.FileRetrievalView.as_view()),
|
||||
re_path(rf'oss/file/(?P<file_id>[\w-]+)/?$',
|
||||
views.FileRetrievalView.as_view()),
|
||||
re_path(rf'^/oss/get_url/(?P<url>[\w-]+)?$',
|
||||
views.GetUrlView.as_view()),
|
||||
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,4 +6,5 @@ app_name = 'oss'
|
|||
|
||||
urlpatterns = [
|
||||
path('oss/file', views.FileView.as_view()),
|
||||
path('oss/get_url', views.GetUrlView.as_view()),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
|
||||
import requests
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.parsers import MultiPartParser
|
||||
|
|
@ -66,3 +69,33 @@ class FileView(APIView):
|
|||
@log(menu='file', operate='Delete file')
|
||||
def delete(self, request: Request, file_id: str):
|
||||
return result.success(FileSerializer.Operate(data={'id': file_id}).delete())
|
||||
|
||||
|
||||
class GetUrlView(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@extend_schema(
|
||||
methods=['GET'],
|
||||
summary=_('Get url'),
|
||||
description=_('Get url'),
|
||||
operation_id=_('Get url'), # type: ignore
|
||||
tags=[_('Chat')] # type: ignore
|
||||
)
|
||||
def get(self, request: Request):
|
||||
url = request.query_params.get('url')
|
||||
response = requests.get(url)
|
||||
# 返回状态码 响应内容大小 响应的contenttype 还有字节流
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
# 根据内容类型决定如何处理
|
||||
if 'text' in content_type or 'json' in content_type:
|
||||
content = response.text
|
||||
else:
|
||||
# 二进制内容使用Base64编码
|
||||
content = base64.b64encode(response.content).decode('utf-8')
|
||||
|
||||
return result.success({
|
||||
'status_code': response.status_code,
|
||||
'Content-Length': response.headers.get('Content-Length', 0),
|
||||
'Content-Type': content_type,
|
||||
'content': content,
|
||||
})
|
||||
|
|
|
|||
Loading…
Reference in New Issue