diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index 23537f289..c1ebd2224 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -16,9 +16,12 @@ from .direct_reply_node import * from .function_lib_node import * from .function_node import * from .reranker_node import * +from .document_extract_node import * +from .image_understand_step_node import * node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode, - BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode] + BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode, + BaseImageUnderstandNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/document_extract_node/__init__.py b/apps/application/flow/step_node/document_extract_node/__init__.py new file mode 100644 index 000000000..ce8f10f3e --- /dev/null +++ b/apps/application/flow/step_node/document_extract_node/__init__.py @@ -0,0 +1 @@ +from .impl import * \ No newline at end of file diff --git a/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py new file mode 100644 index 000000000..d030e0078 --- /dev/null +++ b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py @@ -0,0 +1,30 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class DocumentExtractNodeSerializer(serializers.Serializer): + # 需要查询的数据集id列表 + file_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("数据集id列表")) + + def is_valid(self, *, raise_exception=False): + 
super().is_valid(raise_exception=True) + + +class IDocumentExtractNode(INode): + type = 'document-extract-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return DocumentExtractNodeSerializer + + def _run(self): + return self.execute(**self.flow_params_serializer.data) + + def execute(self, file_list, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/document_extract_node/impl/__init__.py b/apps/application/flow/step_node/document_extract_node/impl/__init__.py new file mode 100644 index 000000000..cf9d55ecd --- /dev/null +++ b/apps/application/flow/step_node/document_extract_node/impl/__init__.py @@ -0,0 +1 @@ +from .base_document_extract_node import BaseDocumentExtractNode diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py new file mode 100644 index 000000000..bf900ffe2 --- /dev/null +++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py @@ -0,0 +1,11 @@ +# coding=utf-8 + +from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode + + +class BaseDocumentExtractNode(IDocumentExtractNode): + def execute(self, file_list, **kwargs): + pass + + def get_details(self, index: int, **kwargs): + pass diff --git a/apps/application/flow/step_node/image_understand_step_node/__init__.py b/apps/application/flow/step_node/image_understand_step_node/__init__.py new file mode 100644 index 000000000..f3feecc9c --- /dev/null +++ b/apps/application/flow/step_node/image_understand_step_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py new file mode 100644 index 000000000..4c15ad8cd --- 
/dev/null +++ b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py @@ -0,0 +1,39 @@ +# coding=utf-8 + +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class ImageUnderstandNodeSerializer(serializers.Serializer): + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) + system = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char("角色设定")) + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词")) + # 多轮对话数量 + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + + image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张")) + + +class IImageUnderstandNode(INode): + type = 'image-understand-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ImageUnderstandNodeSerializer + + def _run(self): + res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('image_list')[0], + self.node_params_serializer.data.get('image_list')[1:]) + return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, + chat_record_id, + image, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/__init__.py b/apps/application/flow/step_node/image_understand_step_node/impl/__init__.py new file mode 100644 index 000000000..ba2512839 --- /dev/null +++ b/apps/application/flow/step_node/image_understand_step_node/impl/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from 
.base_image_understand_node import BaseImageUnderstandNode diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py new file mode 100644 index 000000000..9731fbd80 --- /dev/null +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -0,0 +1,147 @@ +# coding=utf-8 +import base64 +import os +import time +from functools import reduce +from typing import List, Dict + +from django.db.models import QuerySet +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage + +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode +from dataset.models import File +from setting.models_provider.tools import get_model_instance_by_model_user_id + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + chat_model = node_variable.get('chat_model') + message_tokens = node_variable['usage_metadata']['output_tokens'] if 'usage_metadata' in node_variable else 0 + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + node.context['run_time'] = time.time() - node.context['start_time'] + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): + workflow.answer += answer + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 (流式) + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + answer = '' 
+ for chunk in response: + answer += chunk.content + yield chunk.content + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点实例对象 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + answer = response.content + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +class BaseImageUnderstandNode(IImageUnderstandNode): + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + image, + **kwargs) -> NodeResult: + image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) + history_message = self.get_history_message(history_chat_record, dialogue_number) + self.context['history_message'] = history_message + question = self.generate_prompt_question(prompt) + self.context['question'] = question.content + # todo 处理上传图片 + message_list = self.generate_message_list(image_model, system, prompt, history_message, image) + self.context['message_list'] = message_list + self.context['image_list'] = image + if stream: + r = image_model.stream(message_list) + return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context_stream) + else: + r = image_model.invoke(message_list) + return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context) + + @staticmethod + def get_history_message(history_chat_record, dialogue_number): + start_index = len(history_chat_record) - dialogue_number + history_message = reduce(lambda x, y: [*x, *y], [ + 
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))], []) + return history_message + + def generate_prompt_question(self, prompt): + return HumanMessage(self.workflow_manage.generate_prompt(prompt)) + + def generate_message_list(self, image_model, system: str, prompt: str, history_message, image): + if image is not None and len(image) > 0: + file_id = image[0]['file_id'] + file = QuerySet(File).filter(id=file_id).first() + base64_image = base64.b64encode(file.get_byte()).decode("utf-8") + messages = [HumanMessage( + content=[ + {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)}, + {'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}}, + ])] + else: + messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))] + + if system is not None and len(system) > 0: + return [ + SystemMessage(self.workflow_manage.generate_prompt(system)), + *history_message, + *messages + ] + else: + return [ + *history_message, + *messages + ] + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + 'run_time': self.context.get('run_time'), + 'system': self.node_params.get('system'), + 'history_message': [{'content': message.content, 'role': message.type} for message in + (self.context.get('history_message') if self.context.get( + 'history_message') is not None else [])], + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 
'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message, + 'image_list': self.context.get('image_list') + } diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index 39fbfe76a..186623123 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -41,7 +41,7 @@ class BaseStartStepNode(IStarNode): """ 开始节点 初始化全局变量 """ - return NodeResult({'question': question}, + return NodeResult({'question': question, 'image': self.workflow_manage.image_list}, workflow_variable) def get_details(self, index: int, **kwargs): @@ -61,5 +61,6 @@ class BaseStartStepNode(IStarNode): 'type': self.node.type, 'status': self.status, 'err_message': self.err_message, + 'image_list': self.context.get('image'), 'global_fields': global_fields } diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 0e5cb15fb..5de8e0ddc 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -240,10 +240,13 @@ class NodeChunk: class WorkflowManage: def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler, - base_to_response: BaseToResponse = SystemToResponse(), form_data=None): + base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None): if form_data is None: form_data = {} + if image_list is None: + image_list = [] self.form_data = form_data + self.image_list = image_list self.params = params self.flow = flow self.lock = threading.Lock() diff --git a/apps/application/migrations/0019_application_file_upload_enable_and_more.py b/apps/application/migrations/0019_application_file_upload_enable_and_more.py new file mode 100644 index 000000000..0f934d6bf --- /dev/null +++ 
b/apps/application/migrations/0019_application_file_upload_enable_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.15 on 2024-11-07 11:22 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0018_workflowversion_name'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='file_upload_enable', + field=models.BooleanField(default=False, verbose_name='文件上传是否启用'), + ), + migrations.AddField( + model_name='application', + name='file_upload_setting', + field=models.JSONField(default=dict, verbose_name='文件上传相关设置'), + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index e65278d4a..8df9ac190 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -66,6 +66,9 @@ class Application(AppModelMixin): stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False) tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER") clean_time = models.IntegerField(verbose_name="清理时间", default=180) + file_upload_enable = models.BooleanField(verbose_name="文件上传是否启用", default=False) + file_upload_setting = models.JSONField(verbose_name="文件上传相关设置", default=dict) + @staticmethod def get_default_model_prompt(): diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 86460f093..7e5cc8deb 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -823,6 +823,8 @@ class ApplicationSerializer(serializers.Serializer): 'stt_model_enable': application.stt_model_enable, 'tts_model_enable': application.tts_model_enable, 'tts_type': application.tts_type, + 'file_upload_enable': application.file_upload_enable, + 'file_upload_setting': application.file_upload_setting, 'work_flow': application.work_flow, 'show_source': 
application_access_token.show_source, **application_setting_dict}) @@ -876,6 +878,7 @@ class ApplicationSerializer(serializers.Serializer): update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number', 'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', 'tts_type', + 'file_upload_enable', 'file_upload_setting', 'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting', 'problem_optimization_prompt', 'clean_time'] for update_key in update_keys: @@ -941,6 +944,10 @@ class ApplicationSerializer(serializers.Serializer): instance['tts_type'] = node_data['tts_type'] if 'tts_model_params_setting' in node_data: instance['tts_model_params_setting'] = node_data['tts_model_params_setting'] + if 'file_upload_enable' in node_data: + instance['file_upload_enable'] = node_data['file_upload_enable'] + if 'file_upload_setting' in node_data: + instance['file_upload_setting'] = node_data['file_upload_setting'] break def speech_to_text(self, file, with_valid=True): diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 61051f96d..e0576b247 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -222,6 +222,7 @@ class ChatMessageSerializer(serializers.Serializer): client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量")) + image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张")) def is_valid_application_workflow(self, *, raise_exception=False): self.is_valid_intraday_access_num() @@ -299,6 +300,7 @@ class 
ChatMessageSerializer(serializers.Serializer): client_id = self.data.get('client_id') client_type = self.data.get('client_type') form_data = self.data.get('form_data') + image_list = self.data.get('image_list') user_id = chat_info.application.user_id work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), {'history_chat_record': chat_info.chat_record_list, 'question': message, @@ -308,7 +310,7 @@ class ChatMessageSerializer(serializers.Serializer): 'client_id': client_id, 'client_type': client_type, 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), - base_to_response, form_data) + base_to_response, form_data, image_list) r = work_flow_manage.run() return r diff --git a/apps/application/urls.py b/apps/application/urls.py index 5bd551b7b..bca2725be 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -49,6 +49,7 @@ urlpatterns = [ path('application//chat//', views.ChatView.Page.as_view()), path('application//chat/', views.ChatView.Operate.as_view()), path('application//chat//chat_record/', views.ChatView.ChatRecord.as_view()), + path('application//chat//upload_file', views.ChatView.UploadFile.as_view()), path('application//chat//chat_record//', views.ChatView.ChatRecord.Page.as_view()), path('application//chat//chat_record/', diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 78eded773..927670bb2 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -22,6 +22,7 @@ from common.constants.permission_constants import Permission, Group, Operate, \ RoleConstants, ViewPermission, CompareConstants from common.response import result from common.util.common import query_params_to_single_dict +from dataset.serializers.file_serializers import FileSerializer class Openai(APIView): @@ -128,6 +129,7 @@ class ChatView(APIView): 'client_id': request.auth.client_id, 'form_data': (request.data.get( 'form_data') if 
'form_data' in request.data else {}), + 'image_list': request.data.get('image_list') if 'image_list' in request.data else [], 'client_type': request.auth.client_type}).chat() @action(methods=['GET'], detail=False) @@ -391,3 +393,28 @@ class ChatView(APIView): data={'chat_id': chat_id, 'chat_record_id': chat_record_id, 'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).delete()) + + class UploadFile(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="上传文件", + operation_id="上传文件", + manual_parameters=ChatRecordApi.get_request_params_api(), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, + RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def post(self, request: Request, application_id: str, chat_id: str): + files = request.FILES.getlist('file') + file_ids = [] + meta = {'application_id': application_id, 'chat_id': chat_id} + for file in files: + file_url = FileSerializer(data={'file': file, 'meta': meta}).upload() + file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]}) + return result.success(file_ids) + diff --git a/apps/common/job/__init__.py b/apps/common/job/__init__.py index 2f4ef2697..5d75a6c35 100644 --- a/apps/common/job/__init__.py +++ b/apps/common/job/__init__.py @@ -6,6 +6,7 @@ @date:2024/3/14 11:54 @desc: """ +from .clean_orphaned_file_job import * from .client_access_num_job import * from .clean_chat_job import * @@ -13,3 +14,4 @@ from .clean_chat_job import * def run(): client_access_num_job.run() clean_chat_job.run() + clean_orphaned_file_job.run() diff --git a/apps/common/job/clean_orphaned_file_job.py b/apps/common/job/clean_orphaned_file_job.py new file mode 100644 index 000000000..69185150b 
--- /dev/null +++ b/apps/common/job/clean_orphaned_file_job.py @@ -0,0 +1,40 @@ +# coding=utf-8 + +import logging + +from apscheduler.schedulers.background import BackgroundScheduler +from django.db.models import Q +from django_apscheduler.jobstores import DjangoJobStore + +from application.models import Chat +from common.lock.impl.file_lock import FileLock +from dataset.models import File + +scheduler = BackgroundScheduler() +scheduler.add_jobstore(DjangoJobStore(), "default") +lock = FileLock() + + +def clean_debug_file(): + logging.getLogger("max_kb").info('开始清理没有关联会话的上传文件') + existing_chat_ids = set(Chat.objects.values_list('id', flat=True)) + # UUID to str + existing_chat_ids = [str(chat_id) for chat_id in existing_chat_ids] + logging.getLogger("max_kb").debug(existing_chat_ids) + # 查找引用的不存在的 chat_id 并删除相关记录 + deleted_count, _ = File.objects.filter(~Q(meta__chat_id__in=existing_chat_ids)).delete() + + logging.getLogger("max_kb").info(f'结束清理没有关联会话的上传文件: {deleted_count}') + + +def run(): + if lock.try_lock('clean_orphaned_file_job', 30 * 30): + try: + scheduler.start() + clean_orphaned_file = scheduler.get_job(job_id='clean_orphaned_file') + if clean_orphaned_file is not None: + clean_orphaned_file.remove() + scheduler.add_job(clean_debug_file, 'cron', hour='2', minute='0', second='0', + id='clean_orphaned_file') + finally: + lock.un_lock('clean_orphaned_file_job') diff --git a/apps/dataset/migrations/0010_file_meta.py b/apps/dataset/migrations/0010_file_meta.py new file mode 100644 index 000000000..f227554ba --- /dev/null +++ b/apps/dataset/migrations/0010_file_meta.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-11-07 15:32 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0009_alter_document_status_alter_paragraph_status'), + ] + + operations = [ + migrations.AddField( + model_name='file', + name='meta', + field=models.JSONField(default=dict, verbose_name='文件关联数据'), + ), + ] diff --git 
a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py index 9fcb0d643..fb86c30e3 100644 --- a/apps/dataset/models/data_set.py +++ b/apps/dataset/models/data_set.py @@ -141,6 +141,9 @@ class File(AppModelMixin): loid = models.IntegerField(verbose_name="loid") + meta = models.JSONField(verbose_name="文件关联数据", default=dict) + + class Meta: db_table = "file" @@ -149,7 +152,6 @@ class File(AppModelMixin): ): result = select_one("SELECT lo_from_bytea(%s, %s::bytea) as loid", [0, bytea]) self.loid = result['loid'] - self.file_name = 'speech.mp3' super().save() def get_byte(self): diff --git a/apps/dataset/serializers/file_serializers.py b/apps/dataset/serializers/file_serializers.py index 894f14904..1bbd9feb2 100644 --- a/apps/dataset/serializers/file_serializers.py +++ b/apps/dataset/serializers/file_serializers.py @@ -56,12 +56,13 @@ mime_types = {"html": "text/html", "htm": "text/html", "shtml": "text/html", "cs class FileSerializer(serializers.Serializer): file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件")) + meta = serializers.JSONField(required=False) def upload(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) file_id = uuid.uuid1() - file = File(id=file_id, file_name=self.data.get('file').name) + file = File(id=file_id, file_name=self.data.get('file').name, meta=self.data.get('meta')) file.save(self.data.get('file').read()) return f'/api/file/{file_id}' diff --git a/apps/setting/models_provider/impl/base_image.py b/apps/setting/models_provider/impl/base_image.py deleted file mode 100644 index 70bc99595..000000000 --- a/apps/setting/models_provider/impl/base_image.py +++ /dev/null @@ -1,14 +0,0 @@ -# coding=utf-8 -from abc import abstractmethod - -from pydantic import BaseModel - - -class BaseImage(BaseModel): - @abstractmethod - def check_auth(self): - pass - - @abstractmethod - def image_understand(self, image_file, text): - pass diff --git 
a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py index 304040222..c400fcb65 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py @@ -1,6 +1,10 @@ # coding=utf-8 +import base64 +import os from typing import Dict +from langchain_core.messages import HumanMessage + from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm @@ -25,7 +29,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential): return False try: model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() + model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) except Exception as e: if isinstance(e, AppApiException): raise e diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/image.py b/apps/setting/models_provider/impl/openai_model_provider/model/image.py index 556b4a0ee..2ccb04f69 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/image.py @@ -1,12 +1,9 @@ -import base64 -import os from typing import Dict -from openai import OpenAI +from langchain_openai.chat_models import ChatOpenAI from common.config.tokenizer_manage_config import TokenizerManage from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_image import BaseImage def custom_get_token_ids(text: str): @@ -14,66 +11,15 @@ def custom_get_token_ids(text: str): return tokenizer.encode(text) -class OpenAIImage(MaxKBBaseModel, BaseImage): - api_base: str - api_key: str - model: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.api_key = kwargs.get('api_key') - self.api_base = kwargs.get('api_base') +class 
OpenAIImage(MaxKBBaseModel, ChatOpenAI): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) return OpenAIImage( model=model_name, - api_base=model_credential.get('api_base'), - api_key=model_credential.get('api_key'), + openai_api_base=model_credential.get('api_base'), + openai_api_key=model_credential.get('api_key'), + stream_options={"include_usage": True}, **optional_params, ) - - def check_auth(self): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - response_list = client.models.with_raw_response.list() - # print(response_list) - # cwd = os.path.dirname(os.path.abspath(__file__)) - # with open(f'{cwd}/img_1.png', 'rb') as f: - # self.image_understand(f, "一句话概述这个图片") - - def image_understand(self, image_file, text): - client = OpenAI( - base_url=self.api_base, - api_key=self.api_key - ) - base64_image = base64.b64encode(image_file.read()).decode('utf-8') - - response = client.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": text, - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" - }, - }, - ], - } - ], - ) - return response.choices[0].message.content diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/image.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/image.py new file mode 100644 index 000000000..6f8b2f702 --- /dev/null +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/image.py @@ -0,0 +1,69 @@ +# coding=utf-8 +""" 
+ @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 18:41 + @desc: +""" +import base64 +import os +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class QwenModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=1.0, + _min=0.1, + _max=1.9, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class QwenVLModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, 
model_name): + return QwenModelParams() diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/image.py b/apps/setting/models_provider/impl/qwen_model_provider/model/image.py new file mode 100644 index 000000000..60e209154 --- /dev/null +++ b/apps/setting/models_provider/impl/qwen_model_provider/model/image.py @@ -0,0 +1,22 @@ +# coding=utf-8 + +from typing import Dict + +from langchain_community.chat_models import ChatOpenAI + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI): + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + chat_tong_yi = QwenVLChatModel( + model=model_name, + openai_api_key=model_credential.get('api_key'), + openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1', + stream_options={"include_usage": True}, + model_kwargs=optional_params, + ) + return chat_tong_yi diff --git a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py index dd0a924b8..0a24ca35c 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/qwen_model_provider.py @@ -11,21 +11,33 @@ import os from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage +from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential +from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel from setting.models_provider.impl.qwen_model_provider.model.llm import 
QwenChatModel from smartdoc.conf import PROJECT_DIR qwen_model_credential = OpenAILLMModelCredential() +qwenvl_model_credential = QwenVLModelCredential() module_info_list = [ ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel), ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel), ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel) ] +module_info_vl_list = [ + ModelInfo('qwen-vl-max', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel), + ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel), + ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel), +] -model_info_manage = ModelInfoManage.builder().append_model_info_list(module_info_list).append_default_model_info( - ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)).build() +model_info_manage = (ModelInfoManage.builder() + .append_model_info_list(module_info_list) + .append_default_model_info( + ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)) + .append_model_info_list(module_info_vl_list) + .build()) class QwenModelProvider(IModelProvider): diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/image.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/image.py new file mode 100644 index 000000000..f2224546d --- /dev/null +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/image.py @@ -0,0 +1,69 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 18:41 + @desc: +""" +import base64 +import os +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from 
setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class TencentVisionModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=1.0, + _min=0.1, + _max=1.9, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class TencentVisionModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name): + return TencentVisionModelParams() diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/image.py b/apps/setting/models_provider/impl/tencent_model_provider/model/image.py new file mode 100644 index 000000000..eb7a00f61 --- /dev/null +++ 
b/apps/setting/models_provider/impl/tencent_model_provider/model/image.py @@ -0,0 +1,25 @@ +from typing import Dict + +from langchain_openai.chat_models import ChatOpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class TencentVision(MaxKBBaseModel, ChatOpenAI): + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return TencentVision( + model=model_name, + openai_api_base='https://api.hunyuan.cloud.tencent.com/v1', + openai_api_key=model_credential.get('api_key'), + stream_options={"include_usage": True}, + **optional_params, + ) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py b/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py index 47841a032..b37809eb5 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/tencent_model_provider.py @@ -7,8 +7,10 @@ from setting.models_provider.base_model_provider import ( IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage ) from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential +from setting.models_provider.impl.tencent_model_provider.credential.image import TencentVisionModelCredential from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel +from setting.models_provider.impl.tencent_model_provider.model.image import TencentVision from 
setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel from smartdoc.conf import PROJECT_DIR @@ -78,9 +80,18 @@ def _initialize_model_info(): model_info_embedding_list = [tencent_embedding_model_info] + model_info_vision_list = [_create_model_info( + 'hunyuan-vision', + '混元视觉模型', + ModelTypeConst.IMAGE, + TencentVisionModelCredential, + TencentVision)] + + model_info_manage = ModelInfoManage.builder() \ .append_model_info_list(model_info_list) \ .append_model_info_list(model_info_embedding_list) \ + .append_model_info_list(model_info_vision_list) \ .append_default_model_info(model_info_list[0]) \ .build() diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/image.py b/apps/setting/models_provider/impl/xf_model_provider/credential/image.py index 88345a545..449f32129 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/image.py @@ -1,11 +1,15 @@ # coding=utf-8 - +import base64 +import os from typing import Dict +from langchain_core.messages import HumanMessage + from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.xf_model_provider.model.image import ImageMessage class XunFeiImageModelCredential(BaseForm, BaseModelCredential): @@ -28,7 +32,10 @@ class XunFeiImageModelCredential(BaseForm, BaseModelCredential): return False try: model = provider.get_model(model_type, model_name, model_credential) - model.check_auth() + cwd = os.path.dirname(os.path.abspath(__file__)) + with open(f'{cwd}/img_1.png', 'rb') as f: + message_list = [ImageMessage(str(base64.b64encode(f.read()), 'utf-8')), HumanMessage('请概述这张图片')] + model.stream(message_list) except Exception as e: if isinstance(e, AppApiException): raise e diff --git 
a/apps/setting/models_provider/impl/openai_model_provider/model/img_1.png b/apps/setting/models_provider/impl/xf_model_provider/credential/img_1.png similarity index 100% rename from apps/setting/models_provider/impl/openai_model_provider/model/img_1.png rename to apps/setting/models_provider/impl/xf_model_provider/credential/img_1.png diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/image.py b/apps/setting/models_provider/impl/xf_model_provider/model/image.py index 91e8fa0cc..b7813287d 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/model/image.py +++ b/apps/setting/models_provider/impl/xf_model_provider/model/image.py @@ -1,50 +1,58 @@ # coding=utf-8 - -import asyncio import base64 -import datetime -import hashlib -import hmac -import json import os -import ssl -from datetime import datetime, UTC -from typing import Dict -from urllib.parse import urlencode -from urllib.parse import urlparse +from typing import Dict, Any, List, Optional, Iterator -import websockets +from langchain_core.messages import SystemMessage +from langchain_community.chat_models.sparkllm import ChatSparkLLM, _convert_delta_to_message_chunk +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage, ChatMessage, HumanMessage, AIMessage, AIMessageChunk +from langchain_core.outputs import ChatGenerationChunk from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_image import BaseImage - -ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -ssl_context.check_hostname = False -ssl_context.verify_mode = ssl.CERT_NONE -class XFSparkImage(MaxKBBaseModel, BaseImage): +class ImageMessage(HumanMessage): + content: str + + +def convert_message_to_dict(message: BaseMessage) -> dict: + message_dict: Dict[str, Any] + if isinstance(message, ChatMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, ImageMessage): + 
message_dict = {"role": "user", "content": message.content, "content_type": "image"} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + # If function call only, content is None not empty string + if message_dict["content"] == "": + message_dict["content"] = None + if "tool_calls" in message.additional_kwargs: + message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] + # If tool calls only, content is None not empty string + if message_dict["content"] == "": + message_dict["content"] = None + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + + return message_dict + + +class XFSparkImage(MaxKBBaseModel, ChatSparkLLM): spark_app_id: str spark_api_key: str spark_api_secret: str spark_api_url: str - params: dict - - # 初始化 - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.spark_api_url = kwargs.get('spark_api_url') - self.spark_app_id = kwargs.get('spark_app_id') - self.spark_api_key = kwargs.get('spark_api_key') - self.spark_api_secret = kwargs.get('spark_api_secret') - self.params = kwargs.get('params') @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {'params': {}} - for key, value in model_kwargs.items(): - if key not in ['model_id', 'use_local', 'streaming']: - optional_params['params'][key] = value + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) return XFSparkImage( spark_app_id=model_credential.get('spark_app_id'), spark_api_key=model_credential.get('spark_api_key'), @@ -53,118 +61,36 @@ class XFSparkImage(MaxKBBaseModel, BaseImage): 
**optional_params ) - def create_url(self): - url = self.spark_api_url - host = urlparse(url).hostname - # 生成RFC1123格式的时间戳 - gmt_format = '%a, %d %b %Y %H:%M:%S GMT' - date = datetime.now(UTC).strftime(gmt_format) - - # 拼接字符串 - signature_origin = "host: " + host + "\n" - signature_origin += "date: " + date + "\n" - signature_origin += "GET " + "/v2.1/image " + "HTTP/1.1" - # 进行hmac-sha256进行加密 - signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() - signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') - - authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( - self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') - # 将请求的鉴权参数组合为字典 - v = { - "authorization": authorization, - "date": date, - "host": host - } - # 拼接鉴权参数,生成url - url = url + '?' 
+ urlencode(v) - # print("date: ",date) - # print("v: ",v) - # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 - # print('websocket url :', url) - return url - - def check_auth(self): - cwd = os.path.dirname(os.path.abspath(__file__)) - with open(f'{cwd}/img_1.png', 'rb') as f: - self.image_understand(f,"一句话概述这个图片") - - def image_understand(self, image_file, question): - async def handle(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - # 发送 full client request - await self.send(ws, image_file, question) - return await self.handle_message(ws) - - return asyncio.run(handle()) - - # 收到websocket消息的处理 @staticmethod - async def handle_message(ws): - # print(message) - answer = '' - while True: - res = await ws.recv() - data = json.loads(res) - code = data['header']['code'] - if code != 0: - return f'请求错误: {code}, {data}' - else: - choices = data["payload"]["choices"] - status = choices["status"] - content = choices["text"][0]["content"] - # print(content, end="") - answer += content - # print(1) - if status == 2: - break - return answer + def generate_message(prompt: str, image) -> list[BaseMessage]: + if image is None: + cwd = os.path.dirname(os.path.abspath(__file__)) + with open(f'{cwd}/img_1.png', 'rb') as f: + base64_image = base64.b64encode(f.read()).decode("utf-8") + return [ImageMessage(f'data:image/jpeg;base64,{base64_image}'), HumanMessage(prompt)] + return [HumanMessage(prompt)] - async def send(self, ws, image_file, question): - text = [ - {"role": "user", "content": str(base64.b64encode(image_file.read()), 'utf-8'), "content_type": "image"}, - {"role": "user", "content": question} - ] + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + default_chunk_class = AIMessageChunk - data = { - "header": { - "app_id": self.spark_app_id - }, - 
"parameter": { - "chat": { - "domain": "image", - "temperature": 0.5, - "top_k": 4, - "max_tokens": 2028, - "auditing": "default" - } - }, - "payload": { - "message": { - "text": text - } - } - } - - d = json.dumps(data) - await ws.send(d) - - def is_cache_model(self): - return False - - @staticmethod - def get_len(text): - length = 0 - for content in text: - temp = content["content"] - leng = len(temp) - length += leng - return length - - def check_len(self, text): - print("text-content-tokens:", self.get_len(text[1:])) - while (self.get_len(text[1:]) > 8000): - del text[1] - return text + self.client.arun( + [convert_message_to_dict(m) for m in messages], + self.spark_user_id, + self.model_kwargs, + streaming=True, + ) + for content in self.client.subscribe(timeout=self.request_timeout): + if "data" not in content: + continue + delta = content["data"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + cg_chunk = ChatGenerationChunk(message=chunk) + if run_manager: + run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk) + yield cg_chunk diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py index 37bd3e435..dc8fde919 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py +++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -37,7 +37,6 @@ model_info_list = [ ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech), - ModelInfo('image', '', ModelTypeConst.IMAGE, image_model_credential, XFSparkImage), ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding) ] diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index 
f41268e15..bf384ad49 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -286,6 +286,14 @@ const getApplicationTTSModel: ( return get(`${prefix}/${application_id}/model`, { model_type: 'TTS' }, loading) } +const getApplicationImageModel: ( + application_id: string, + loading?: Ref +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/${application_id}/model`, { model_type: 'IMAGE' }, loading) +} + + /** * 发布应用 * @param 参数 @@ -350,6 +358,19 @@ const getModelParamsForm: ( return get(`${prefix}/${application_id}/model_params_form/${model_id}`, undefined, loading) } +/** + * 上传文档图片附件 + */ +const uploadFile: ( + application_id: String, + chat_id: String, + data: any, + loading?: Ref +) => Promise> = (application_id, chat_id, data, loading) => { + return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading) +} + + /** * 语音转文本 */ @@ -501,6 +522,7 @@ export default { getApplicationRerankerModel, getApplicationSTTModel, getApplicationTTSModel, + getApplicationImageModel, postSpeechToText, postTextToSpeech, getPlatformStatus, @@ -513,5 +535,6 @@ export default { putWorkFlowVersion, playDemoText, getUserList, - getApplicationList + getApplicationList, + uploadFile } diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index 72dfbcb6c..f63c7392d 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -54,6 +54,18 @@
{{ f.label }}: {{ f.value }}
+
+ 上传的文档: +
+ {{ f.name }} +
+
+
+ 上传的图片: +
+ {{ f.name }} +
+
@@ -96,7 +108,8 @@ v-if=" item.type == WorkflowType.AiChat || item.type == WorkflowType.Question || - item.type == WorkflowType.Application + item.type == WorkflowType.Application || + item.type == WorkflowType.ImageUnderstandNode " >
+ + + + + + + + @@ -790,7 +803,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) { is_stop: false, record_id: '', vote_status: '-1', - status: undefined + status: undefined, }) chatList.value.push(chat) ChatManagement.addChatRecord(chat, 50, loading) @@ -809,7 +822,8 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) { const obj = { message: chat.problem_text, re_chat: re_chat || false, - form_data: { ...form_data.value, ...api_form_data.value } + form_data: { ...form_data.value, ...api_form_data.value }, + image_list: uploadFileList.value, } // 对话 applicationApi @@ -832,6 +846,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) { nextTick(() => { // 将滚动条滚动到最下面 scrollDiv.value.setScrollTop(getMaxHeight()) + uploadFileList.value = [] }) const reader = response.body.getReader() // 处理流数据 @@ -916,6 +931,42 @@ const handleScroll = () => { } } +// 保存上传文件列表 +const uploadFileList = ref([]) +const uploadFile = async (file: any, fileList: any) => { + const { maxFiles, fileLimit } = props.data.file_upload_setting + if (fileList.length > maxFiles) { + MsgWarning('最多上传' + maxFiles + '个文件') + return + } + if (fileList.filter((f: any) => f.size > fileLimit * 1024 * 1024).length > 0) { // MB + MsgWarning('单个文件大小不能超过' + fileLimit + 'MB') + fileList.splice(0, fileList.length) + return + } + const formData = new FormData() + for (const file of fileList) { + formData.append('file', file.raw, file.name) + uploadFileList.value.push(file) + } + + if (props.chatId === 'new' || !chartOpenId.value) { + const res = await applicationApi.getChatOpen(props.data.id as string) + chartOpenId.value = res.data + } + applicationApi.uploadFile(props.data.id as string, chartOpenId.value, formData, loading).then((response) => { + fileList.splice(0, fileList.length) + uploadFileList.value.forEach((file: any) => { + const f = response.data.filter((f: any) => f.name === file.name) + if (f.length > 0) { + file.url = f[0].url + 
file.file_id = f[0].file_id + } + }) + console.log(uploadFileList.value) + }) +} + // 定义响应式引用 const mediaRecorder = ref(null) diff --git a/ui/src/enums/workflow.ts b/ui/src/enums/workflow.ts index aaf636cc2..5c4584d01 100644 --- a/ui/src/enums/workflow.ts +++ b/ui/src/enums/workflow.ts @@ -9,5 +9,7 @@ export enum WorkflowType { FunctionLib = 'function-lib-node', FunctionLibCustom = 'function-node', RrerankerNode = 'reranker-node', - Application = 'application-node' + Application = 'application-node', + DocumentExtractNode = 'document-extract-node', + ImageUnderstandNode = 'image-understand-node', } diff --git a/ui/src/workflow/common/data.ts b/ui/src/workflow/common/data.ts index cfb9fe2a4..97495cbd1 100644 --- a/ui/src/workflow/common/data.ts +++ b/ui/src/workflow/common/data.ts @@ -168,13 +168,49 @@ export const rerankerNode = { } } } +export const documentExtractNode = { + type: WorkflowType.DocumentExtractNode, + text: '提取文档中的内容', + label: '文档内容提取', + height: 252, + properties: { + stepName: '文档内容提取', + config: { + fields: [ + { + label: '文件内容', + value: 'content' + }, + ] + } + } +} +export const imageUnderstandNode = { + type: WorkflowType.ImageUnderstandNode, + text: '识别出图片中的对象、场景等信息回答用户问题', + label: '图片理解', + height: 252, + properties: { + stepName: '图片理解', + config: { + fields: [ + { + label: 'AI 回答内容', + value: 'content' + }, + ] + } + } +} export const menuNodes = [ aiChatNode, searchDatasetNode, questionNode, conditionNode, replyNode, - rerankerNode + rerankerNode, + documentExtractNode, + imageUnderstandNode ] /** @@ -261,7 +297,9 @@ export const nodeDict: any = { [WorkflowType.FunctionLib]: functionLibNode, [WorkflowType.FunctionLibCustom]: functionNode, [WorkflowType.RrerankerNode]: rerankerNode, - [WorkflowType.Application]: applicationNode + [WorkflowType.Application]: applicationNode, + [WorkflowType.DocumentExtractNode]: documentExtractNode, + [WorkflowType.ImageUnderstandNode]: imageUnderstandNode, } export function isWorkFlow(type: string | 
undefined) { return type === 'WORK_FLOW' diff --git a/ui/src/workflow/icons/document-extract-node-icon.vue b/ui/src/workflow/icons/document-extract-node-icon.vue new file mode 100644 index 000000000..86152c376 --- /dev/null +++ b/ui/src/workflow/icons/document-extract-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/ui/src/workflow/icons/image-understand-node-icon.vue b/ui/src/workflow/icons/image-understand-node-icon.vue new file mode 100644 index 000000000..a0b8a8f41 --- /dev/null +++ b/ui/src/workflow/icons/image-understand-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/ui/src/workflow/nodes/base-node/component/FileUploadSettingDialog.vue b/ui/src/workflow/nodes/base-node/component/FileUploadSettingDialog.vue new file mode 100644 index 000000000..685f24635 --- /dev/null +++ b/ui/src/workflow/nodes/base-node/component/FileUploadSettingDialog.vue @@ -0,0 +1,101 @@ + + + + + \ No newline at end of file diff --git a/ui/src/workflow/nodes/base-node/index.vue b/ui/src/workflow/nodes/base-node/index.vue index eb5df113a..93ded67c4 100644 --- a/ui/src/workflow/nodes/base-node/index.vue +++ b/ui/src/workflow/nodes/base-node/index.vue @@ -45,6 +45,36 @@ @submitDialog="submitDialog" /> + + + @@ -139,7 +169,6 @@ - 设置
@@ -212,6 +241,7 @@ + + + \ No newline at end of file diff --git a/ui/src/workflow/nodes/image-understand/index.ts b/ui/src/workflow/nodes/image-understand/index.ts new file mode 100644 index 000000000..06695d3c9 --- /dev/null +++ b/ui/src/workflow/nodes/image-understand/index.ts @@ -0,0 +1,14 @@ +import ImageUnderstandNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' + +class RerankerNode extends AppNode { + constructor(props: any) { + super(props, ImageUnderstandNodeVue) + } +} + +export default { + type: 'image-understand-node', + model: AppNodeModel, + view: RerankerNode +} diff --git a/ui/src/workflow/nodes/image-understand/index.vue b/ui/src/workflow/nodes/image-understand/index.vue new file mode 100644 index 000000000..4d9fff21b --- /dev/null +++ b/ui/src/workflow/nodes/image-understand/index.vue @@ -0,0 +1,277 @@ + + + + + + \ No newline at end of file diff --git a/ui/src/workflow/nodes/start-node/index.vue b/ui/src/workflow/nodes/start-node/index.vue index 883495bba..b5a6a93f1 100644 --- a/ui/src/workflow/nodes/start-node/index.vue +++ b/ui/src/workflow/nodes/start-node/index.vue @@ -61,8 +61,31 @@ const refreshFieldList = () => { } props.nodeModel.graphModel.eventCenter.on('refreshFieldList', refreshFieldList) +const refreshFileUploadConfig = () => { + let fields = cloneDeep(props.nodeModel.properties.config.fields) + const form_data = props.nodeModel.graphModel.nodes + .filter((v: any) => v.id === 'base-node') + .map((v: any) => cloneDeep(v.properties.node_data.file_upload_setting)) + .filter((v: any) => v) + if (form_data.length === 0) { + return + } + fields = fields.filter((item: any) => item.value !== 'image' && item.value !== 'document') + let fileUploadFields = [] + if (form_data[0].document) { + fileUploadFields.push({ label: '文档', value: 'document' }) + } + if (form_data[0].image) { + fileUploadFields.push({ label: '图片', value: 'image' }) + } + + set(props.nodeModel.properties.config, 'fields', 
[...fields, ...fileUploadFields]) +} +props.nodeModel.graphModel.eventCenter.on('refreshFileUploadConfig', refreshFileUploadConfig) + onMounted(() => { refreshFieldList() + refreshFileUploadConfig() })