diff --git a/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py b/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py
index 644e9ea3e..2fb9a21e0 100644
--- a/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py
+++ b/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py
@@ -168,7 +168,7 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
     def generate_prompt_question(self, prompt):
         return HumanMessage(self.workflow_manage.generate_prompt(prompt))
 
-    def _process_videos(self, image):
+    def _process_videos(self, image, video_model):
         videos = []
         if isinstance(image, str) and image.startswith('http'):
             videos.append({'type': 'video_url', 'video_url': {'url': image}})
@@ -176,16 +176,14 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
             for img in image:
                 file_id = img['file_id']
                 file = QuerySet(File).filter(id=file_id).first()
-                video_bytes = file.get_bytes()
-                base64_video = base64.b64encode(video_bytes).decode("utf-8")
-                video_format = mimetypes.guess_type(file.file_name)[0]  # get the MIME type
+                url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
                 videos.append(
-                    {'type': 'video_url', 'video_url': {'url': f'data:{video_format};base64,{base64_video}'}})
+                    {'type': 'video_url', 'video_url': {'url': url}})
         return videos
 
     def generate_message_list(self, video_model, system: str, prompt: str, history_message, video):
         prompt_text = self.workflow_manage.generate_prompt(prompt)
-        videos = self._process_videos(video)
+        videos = self._process_videos(video, video_model)
         if videos:
             messages = [HumanMessage(content=[{'type': 'text', 'text': prompt_text}, *videos])]
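
The hunks above replace inline data:<mime>;base64,... payloads with provider-hosted URLs, so large video files no longer travel inside the prompt itself. A minimal sketch of the message the node now builds, assuming a video_model that implements upload_file_and_get_url (the item['bytes']/item['name'] fields are illustrative stand-ins, not MaxKB's File API):

from langchain_core.messages import HumanMessage

def build_video_message(prompt_text, video, video_model):
    videos = []
    if isinstance(video, str) and video.startswith('http'):
        # Remote videos pass through unchanged.
        videos.append({'type': 'video_url', 'video_url': {'url': video}})
    elif isinstance(video, list):
        for item in video:
            # Locally stored files are uploaded first; the returned URL
            # replaces the old base64 data URL.
            url = video_model.upload_file_and_get_url(item['bytes'], item['name'])
            videos.append({'type': 'video_url', 'video_url': {'url': url}})
    return HumanMessage(content=[{'type': 'text', 'text': prompt_text}, *videos])
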
diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py
index d064a0c9e..28e361292 100644
--- a/apps/chat/serializers/chat.py
+++ b/apps/chat/serializers/chat.py
@@ -143,8 +143,10 @@ class DebugChatSerializers(serializers.Serializer):
             "application_id": chat_info.application.id,
             "debug": True
         }).chat(instance, base_to_response)
 
+
 SYSTEM_ROLE = get_file_content(os.path.join(PROJECT_DIR, "apps", "chat", 'template', 'generate_prompt_system'))
+
 class PromptGenerateSerializer(serializers.Serializer):
     workspace_id = serializers.CharField(required=False, label=_('Workspace ID'))
     model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model"))
@@ -156,13 +158,13 @@ class PromptGenerateSerializer(serializers.Serializer):
         query_set = QuerySet(Application).filter(id=self.data.get('application_id'))
         if workspace_id:
             query_set = query_set.filter(workspace_id=workspace_id)
-        application=query_set.first()
+        application = query_set.first()
         if application is None:
             raise AppApiException(500, _('Application id does not exist'))
         return application
 
     def generate_prompt(self, instance: dict):
-        application=self.is_valid(raise_exception=True)
+        application = self.is_valid(raise_exception=True)
        GeneratePromptSerializers(data=instance).is_valid(raise_exception=True)
         workspace_id = self.data.get('workspace_id')
         model_id = self.data.get('model_id')
@@ -171,14 +173,14 @@ class PromptGenerateSerializer(serializers.Serializer):
         message = messages[-1]['content']
 
         q = prompt.replace("{userInput}", message)
-        q = q.replace("{application_name}",application.name)
-        q = q.replace("{detail}",application.desc)
+        q = q.replace("{application_name}", application.name)
+        q = q.replace("{detail}", application.desc)
         messages[-1]['content'] = q
-
+        SUPPORTED_MODEL_TYPES = ["LLM", "IMAGE"]
         model_exist = QuerySet(Model).filter(
             id=model_id,
-            model_type="LLM"
+            model_type__in=SUPPORTED_MODEL_TYPES
         ).exists()
         if not model_exist:
             raise Exception(_("Model does not exists or is not an LLM model"))
 
@@ -186,14 +188,17 @@ class PromptGenerateSerializer(serializers.Serializer):
 
         system_content = SYSTEM_ROLE.format(application_name=application.name, detail=application.desc)
 
         def process():
-            model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id,**application.model_params_setting)
+            model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id,
+                                                             **application.model_params_setting)
             try:
                 for r in model.stream([SystemMessage(content=system_content),
-                                       *[HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage(
+                                       *[HumanMessage(content=m.get('content')) if m.get(
+                                           'role') == 'user' else AIMessage(
                                            content=m.get('content')) for m in messages]]):
                     yield 'data: ' + json.dumps({'content': r.content}) + '\n\n'
             except Exception as e:
                 yield 'data: ' + json.dumps({'error': str(e)}) + '\n\n'
+
         return to_stream_response_simple(process())
 
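
The process() generator above frames its output as Server-Sent Events: one 'data: <json>' line per chunk, terminated by a blank line. A standalone sketch of producing and consuming that framing (no MaxKB imports; the payload shape mirrors the serializer's {'content': ...} events):

import json

def sse_events(chunks):
    # Emit frames the way process() does: 'data: <json>' plus a blank line.
    for content in chunks:
        yield 'data: ' + json.dumps({'content': content}) + '\n\n'

def parse_sse(frames):
    # Strip the 'data: ' prefix and decode the JSON body of each frame.
    for frame in frames:
        body = frame[len('data: '):].strip()
        if body:
            yield json.loads(body)['content']

assert list(parse_sse(sse_events(['Hel', 'lo']))) == ['Hel', 'lo']
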
diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py
index a03cccfa3..89bf7b4dd 100644
--- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py
+++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/image.py
@@ -1,13 +1,17 @@
 # coding=utf-8
-from typing import Dict
+from typing import Dict, Optional, Any, Iterator
 
+import requests
 from langchain_community.chat_models import ChatTongyi
-from langchain_core.messages import HumanMessage
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.messages import HumanMessage, BaseMessageChunk, AIMessage
 from django.utils.translation import gettext
+from langchain_core.runnables import RunnableConfig
+
 
 from models_provider.base_model_provider import MaxKBBaseModel
 from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
-
+import json
 
 class QwenVLChatModel(MaxKBBaseModel, BaseChatOpenAI):
 
@@ -32,3 +36,111 @@ class QwenVLChatModel(MaxKBBaseModel, BaseChatOpenAI):
     def check_auth(self, api_key):
         chat = ChatTongyi(api_key=api_key, model_name='qwen-max')
         chat.invoke([HumanMessage([{"type": "text", "text": gettext('Hello')}])])
+
+    def get_upload_policy(self, api_key, model_name):
+        """Fetch the file-upload policy (credentials) from DashScope."""
+        url = "https://dashscope.aliyuncs.com/api/v1/uploads"
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json"
+        }
+        params = {
+            "action": "getPolicy",
+            "model": model_name
+        }
+
+        response = requests.get(url, headers=headers, params=params)
+        if response.status_code != 200:
+            raise Exception(f"Failed to get upload policy: {response.text}")
+
+        return response.json()['data']
+
+    def upload_file_to_oss(self, policy_data, file_stream, file_name):
+        """Upload the file stream to the temporary OSS storage."""
+        # Build the target object key for the OSS upload
+        key = f"{policy_data['upload_dir']}/{file_name}"
+
+        # Build the multipart form fields required by the upload policy
+        files = {
+            'OSSAccessKeyId': (None, policy_data['oss_access_key_id']),
+            'Signature': (None, policy_data['signature']),
+            'policy': (None, policy_data['policy']),
+            'x-oss-object-acl': (None, policy_data['x_oss_object_acl']),
+            'x-oss-forbid-overwrite': (None, policy_data['x_oss_forbid_overwrite']),
+            'key': (None, key),
+            'success_action_status': (None, '200'),
+            'file': (file_name, file_stream)
+        }
+
+        # Execute the upload request
+        response = requests.post(policy_data['upload_host'], files=files)
+        if response.status_code != 200:
+            raise Exception(f"Failed to upload file: {response.text}")
+
+        return f"oss://{key}"
+
+    def upload_file_and_get_url(self, file_stream, file_name):
+        """Upload a file and return its URL."""
+        # 1. Get the upload policy; this endpoint is rate-limited, and exceeding the limit fails the request
+        policy_data = self.get_upload_policy(self.openai_api_key.get_secret_value(), self.model_name)
+        # 2. Upload the file to OSS
+        oss_url = self.upload_file_to_oss(policy_data, file_stream, file_name)
+
+        return oss_url
+
+    def stream(
+            self,
+            input: LanguageModelInput,
+            config: Optional[RunnableConfig] = None,
+            *,
+            stop: Optional[list[str]] = None,
+            **kwargs: Any,
+    ) -> Iterator[BaseMessageChunk]:
+        url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
+
+        headers = {
+            "Authorization": f"Bearer {self.openai_api_key.get_secret_value()}",
+            "Content-Type": "application/json",
+            "X-DashScope-OssResourceResolve": "enable"
+        }
+
+        data = {
+            "model": self.model_name,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": input[0].content
+                }
+            ],
+            **(self.extra_body or {}),
+            "stream": True,
+        }
+        response = requests.post(url, headers=headers, json=data, stream=True)
+        if response.status_code != 200:
+            raise Exception(f"Failed to get response: {response.text}")
+        for line in response.iter_lines():
+            if line:
+                try:
+                    decoded_line = line.decode('utf-8')
+                    # Only process valid SSE data lines
+                    if decoded_line.startswith('data: '):
+                        # Extract the JSON part
+                        json_str = decoded_line[6:]  # strip the 'data: ' prefix
+                        # Skip the end-of-stream marker
+                        if json_str.strip() == '[DONE]':
+                            continue
+
+                        # Try to parse the JSON chunk
+                        chunk_data = json.loads(json_str)
+
+                        if 'choices' in chunk_data and chunk_data['choices']:
+                            delta = chunk_data['choices'][0].get('delta', {})
+                            content = delta.get('content', '')
+                            if content:
+                                yield AIMessage(content=content)
+                except json.JSONDecodeError:
+                    # Ignore lines that cannot be parsed as JSON
+                    continue
+                except Exception:
+                    # Swallow any other unexpected errors and keep streaming
+                    continue
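
Together, upload_file_and_get_url and the overridden stream let a caller hand DashScope a local video and stream the answer back: the file goes to temporary OSS storage, and the X-DashScope-OssResourceResolve: enable header tells the service to resolve the returned oss:// URL server-side. A hypothetical usage sketch (the constructor arguments are assumptions; MaxKB normally builds the instance through new_instance and its provider factory):

from langchain_core.messages import HumanMessage

# Hypothetical construction; real code goes through the provider factory.
model = QwenVLChatModel(model_name='qwen-vl-max', openai_api_key='sk-...',
                        openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1')

with open('demo.mp4', 'rb') as f:
    # Fetch an upload policy, push the file to temporary OSS storage,
    # and receive an oss:// URL that DashScope can dereference.
    url = model.upload_file_and_get_url(f, 'demo.mp4')

message = HumanMessage(content=[
    {'type': 'text', 'text': 'Describe this video'},
    {'type': 'video_url', 'video_url': {'url': url}},
])
for chunk in model.stream([message]):
    print(chunk.content, end='')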