mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 高级编排支持文件上传(WIP)
This commit is contained in:
parent
72b91bee9a
commit
8a8305e75b
|
|
@ -16,9 +16,12 @@ from .direct_reply_node import *
|
|||
from .function_lib_node import *
|
||||
from .function_node import *
|
||||
from .reranker_node import *
|
||||
from .document_extract_node import *
|
||||
from .image_understand_step_node import *
|
||||
|
||||
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
|
||||
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode]
|
||||
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode,
|
||||
BaseImageUnderstandNode]
|
||||
|
||||
|
||||
def get_node(node_type):
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .impl import *
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class DocumentExtractNodeSerializer(serializers.Serializer):
|
||||
# 需要查询的数据集id列表
|
||||
file_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
error_messages=ErrMessage.list("数据集id列表"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
||||
|
||||
class IDocumentExtractNode(INode):
|
||||
type = 'document-extract-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return DocumentExtractNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, file_list, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .base_document_extract_node import BaseDocumentExtractNode
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
# coding=utf-8
|
||||
|
||||
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
|
||||
|
||||
|
||||
class BaseDocumentExtractNode(IDocumentExtractNode):
|
||||
def execute(self, file_list, **kwargs):
|
||||
pass
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
pass
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ImageUnderstandNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
|
||||
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char("角色设定"))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||
|
||||
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张"))
|
||||
|
||||
|
||||
class IImageUnderstandNode(INode):
|
||||
type = 'image-understand-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ImageUnderstandNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('image_list')[0],
|
||||
self.node_params_serializer.data.get('image_list')[1:])
|
||||
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
|
||||
chat_record_id,
|
||||
image,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .base_image_understand_node import BaseImageUnderstandNode
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
|
||||
from dataset.models import File
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = node_variable['usage_metadata']['output_tokens'] if 'usage_metadata' in node_variable else 0
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
|
||||
workflow.answer += answer
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据 (流式)
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
answer = ''
|
||||
for chunk in response:
|
||||
answer += chunk.content
|
||||
yield chunk.content
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点实例对象
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
answer = response.content
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
class BaseImageUnderstandNode(IImageUnderstandNode):
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
image,
|
||||
**kwargs) -> NodeResult:
|
||||
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
self.context['question'] = question.content
|
||||
# todo 处理上传图片
|
||||
message_list = self.generate_message_list(image_model, system, prompt, history_message, image)
|
||||
self.context['message_list'] = message_list
|
||||
self.context['image_list'] = image
|
||||
if stream:
|
||||
r = image_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
else:
|
||||
r = image_model.invoke(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context)
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
return history_message
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
|
||||
|
||||
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
|
||||
if image is not None and len(image) > 0:
|
||||
file_id = image[0]['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
|
||||
messages = [HumanMessage(
|
||||
content=[
|
||||
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
|
||||
{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
|
||||
])]
|
||||
else:
|
||||
messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
|
||||
if system is not None and len(system) > 0:
|
||||
return [
|
||||
SystemMessage(self.workflow_manage.generate_prompt(system)),
|
||||
*history_message,
|
||||
*messages
|
||||
]
|
||||
else:
|
||||
return [
|
||||
*history_message,
|
||||
*messages
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'system': self.node_params.get('system'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(self.context.get('history_message') if self.context.get(
|
||||
'history_message') is not None else [])],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context.get('message_tokens'),
|
||||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'image_list': self.context.get('image_list')
|
||||
}
|
||||
|
|
@ -41,7 +41,7 @@ class BaseStartStepNode(IStarNode):
|
|||
"""
|
||||
开始节点 初始化全局变量
|
||||
"""
|
||||
return NodeResult({'question': question},
|
||||
return NodeResult({'question': question, 'image': self.workflow_manage.image_list},
|
||||
workflow_variable)
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
|
|
@ -61,5 +61,6 @@ class BaseStartStepNode(IStarNode):
|
|||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'image_list': self.context.get('image'),
|
||||
'global_fields': global_fields
|
||||
}
|
||||
|
|
|
|||
|
|
@ -240,10 +240,13 @@ class NodeChunk:
|
|||
|
||||
class WorkflowManage:
|
||||
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
|
||||
base_to_response: BaseToResponse = SystemToResponse(), form_data=None):
|
||||
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None):
|
||||
if form_data is None:
|
||||
form_data = {}
|
||||
if image_list is None:
|
||||
image_list = []
|
||||
self.form_data = form_data
|
||||
self.image_list = image_list
|
||||
self.params = params
|
||||
self.flow = flow
|
||||
self.lock = threading.Lock()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
# Generated by Django 4.2.15 on 2024-11-07 11:22
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0018_workflowversion_name'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='application',
|
||||
name='file_upload_enable',
|
||||
field=models.BooleanField(default=False, verbose_name='文件上传是否启用'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='application',
|
||||
name='file_upload_setting',
|
||||
field=models.JSONField(default={}, verbose_name='文件上传相关设置'),
|
||||
),
|
||||
]
|
||||
|
|
@ -66,6 +66,9 @@ class Application(AppModelMixin):
|
|||
stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False)
|
||||
tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER")
|
||||
clean_time = models.IntegerField(verbose_name="清理时间", default=180)
|
||||
file_upload_enable = models.BooleanField(verbose_name="文件上传是否启用", default=False)
|
||||
file_upload_setting = models.JSONField(verbose_name="文件上传相关设置", default={})
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_default_model_prompt():
|
||||
|
|
|
|||
|
|
@ -823,6 +823,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
'stt_model_enable': application.stt_model_enable,
|
||||
'tts_model_enable': application.tts_model_enable,
|
||||
'tts_type': application.tts_type,
|
||||
'file_upload_enable': application.file_upload_enable,
|
||||
'file_upload_setting': application.file_upload_setting,
|
||||
'work_flow': application.work_flow,
|
||||
'show_source': application_access_token.show_source,
|
||||
**application_setting_dict})
|
||||
|
|
@ -876,6 +878,7 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
|
||||
'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
|
||||
'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', 'tts_type',
|
||||
'file_upload_enable', 'file_upload_setting',
|
||||
'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting',
|
||||
'problem_optimization_prompt', 'clean_time']
|
||||
for update_key in update_keys:
|
||||
|
|
@ -941,6 +944,10 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
instance['tts_type'] = node_data['tts_type']
|
||||
if 'tts_model_params_setting' in node_data:
|
||||
instance['tts_model_params_setting'] = node_data['tts_model_params_setting']
|
||||
if 'file_upload_enable' in node_data:
|
||||
instance['file_upload_enable'] = node_data['file_upload_enable']
|
||||
if 'file_upload_setting' in node_data:
|
||||
instance['file_upload_setting'] = node_data['file_upload_setting']
|
||||
break
|
||||
|
||||
def speech_to_text(self, file, with_valid=True):
|
||||
|
|
|
|||
|
|
@ -222,6 +222,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
|
||||
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
|
||||
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
|
||||
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张"))
|
||||
|
||||
def is_valid_application_workflow(self, *, raise_exception=False):
|
||||
self.is_valid_intraday_access_num()
|
||||
|
|
@ -299,6 +300,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
client_id = self.data.get('client_id')
|
||||
client_type = self.data.get('client_type')
|
||||
form_data = self.data.get('form_data')
|
||||
image_list = self.data.get('image_list')
|
||||
user_id = chat_info.application.user_id
|
||||
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
|
||||
{'history_chat_record': chat_info.chat_record_list, 'question': message,
|
||||
|
|
@ -308,7 +310,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
'client_id': client_id,
|
||||
'client_type': client_type,
|
||||
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
|
||||
base_to_response, form_data)
|
||||
base_to_response, form_data, image_list)
|
||||
r = work_flow_manage.run()
|
||||
return r
|
||||
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ urlpatterns = [
|
|||
path('application/<str:application_id>/chat/<int:current_page>/<int:page_size>', views.ChatView.Page.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>', views.ChatView.Operate.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/upload_file', views.ChatView.UploadFile.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
|
||||
views.ChatView.ChatRecord.Page.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<chat_record_id>',
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from common.constants.permission_constants import Permission, Group, Operate, \
|
|||
RoleConstants, ViewPermission, CompareConstants
|
||||
from common.response import result
|
||||
from common.util.common import query_params_to_single_dict
|
||||
from dataset.serializers.file_serializers import FileSerializer
|
||||
|
||||
|
||||
class Openai(APIView):
|
||||
|
|
@ -128,6 +129,7 @@ class ChatView(APIView):
|
|||
'client_id': request.auth.client_id,
|
||||
'form_data': (request.data.get(
|
||||
'form_data') if 'form_data' in request.data else {}),
|
||||
'image_list': request.data.get('image_list') if 'image_list' in request.data else [],
|
||||
'client_type': request.auth.client_type}).chat()
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
|
|
@ -391,3 +393,28 @@ class ChatView(APIView):
|
|||
data={'chat_id': chat_id, 'chat_record_id': chat_record_id,
|
||||
'dataset_id': dataset_id, 'document_id': document_id,
|
||||
'paragraph_id': paragraph_id}).delete())
|
||||
|
||||
class UploadFile(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="上传文件",
|
||||
operation_id="上传文件",
|
||||
manual_parameters=ChatRecordApi.get_request_params_api(),
|
||||
tags=["应用/对话日志"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
|
||||
RoleConstants.APPLICATION_ACCESS_TOKEN],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def post(self, request: Request, application_id: str, chat_id: str):
|
||||
files = request.FILES.getlist('file')
|
||||
file_ids = []
|
||||
meta = {'application_id': application_id, 'chat_id': chat_id}
|
||||
for file in files:
|
||||
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]})
|
||||
return result.success(file_ids)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@date:2024/3/14 11:54
|
||||
@desc:
|
||||
"""
|
||||
from .clean_orphaned_file_job import *
|
||||
from .client_access_num_job import *
|
||||
from .clean_chat_job import *
|
||||
|
||||
|
|
@ -13,3 +14,4 @@ from .clean_chat_job import *
|
|||
def run():
|
||||
client_access_num_job.run()
|
||||
clean_chat_job.run()
|
||||
clean_orphaned_file_job.run()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
# coding=utf-8
|
||||
|
||||
import logging
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from django.db.models import Q
|
||||
from django_apscheduler.jobstores import DjangoJobStore
|
||||
|
||||
from application.models import Chat
|
||||
from common.lock.impl.file_lock import FileLock
|
||||
from dataset.models import File
|
||||
|
||||
scheduler = BackgroundScheduler()
|
||||
scheduler.add_jobstore(DjangoJobStore(), "default")
|
||||
lock = FileLock()
|
||||
|
||||
|
||||
def clean_debug_file():
|
||||
logging.getLogger("max_kb").info('开始清理没有关联会话的上传文件')
|
||||
existing_chat_ids = set(Chat.objects.values_list('id', flat=True))
|
||||
# UUID to str
|
||||
existing_chat_ids = [str(chat_id) for chat_id in existing_chat_ids]
|
||||
print(existing_chat_ids)
|
||||
# 查找引用的不存在的 chat_id 并删除相关记录
|
||||
deleted_count, _ = File.objects.filter(~Q(meta__chat_id__in=existing_chat_ids)).delete()
|
||||
|
||||
logging.getLogger("max_kb").info(f'结束清理没有关联会话的上传文件: {deleted_count}')
|
||||
|
||||
|
||||
def run():
|
||||
if lock.try_lock('clean_orphaned_file_job', 30 * 30):
|
||||
try:
|
||||
scheduler.start()
|
||||
clean_orphaned_file = scheduler.get_job(job_id='clean_orphaned_file')
|
||||
if clean_orphaned_file is not None:
|
||||
clean_orphaned_file.remove()
|
||||
scheduler.add_job(clean_debug_file, 'cron', hour='2', minute='0', second='0',
|
||||
id='clean_orphaned_file')
|
||||
finally:
|
||||
lock.un_lock('clean_orphaned_file_job')
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.2.15 on 2024-11-07 15:32
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('dataset', '0009_alter_document_status_alter_paragraph_status'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='file',
|
||||
name='meta',
|
||||
field=models.JSONField(default={}, verbose_name='文件关联数据'),
|
||||
),
|
||||
]
|
||||
|
|
@ -141,6 +141,9 @@ class File(AppModelMixin):
|
|||
|
||||
loid = models.IntegerField(verbose_name="loid")
|
||||
|
||||
meta = models.JSONField(verbose_name="文件关联数据", default=dict)
|
||||
|
||||
|
||||
class Meta:
|
||||
db_table = "file"
|
||||
|
||||
|
|
@ -149,7 +152,6 @@ class File(AppModelMixin):
|
|||
):
|
||||
result = select_one("SELECT lo_from_bytea(%s, %s::bytea) as loid", [0, bytea])
|
||||
self.loid = result['loid']
|
||||
self.file_name = 'speech.mp3'
|
||||
super().save()
|
||||
|
||||
def get_byte(self):
|
||||
|
|
|
|||
|
|
@ -56,12 +56,13 @@ mime_types = {"html": "text/html", "htm": "text/html", "shtml": "text/html", "cs
|
|||
|
||||
class FileSerializer(serializers.Serializer):
|
||||
file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件"))
|
||||
meta = serializers.JSONField(required=False)
|
||||
|
||||
def upload(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
file_id = uuid.uuid1()
|
||||
file = File(id=file_id, file_name=self.data.get('file').name)
|
||||
file = File(id=file_id, file_name=self.data.get('file').name, meta=self.data.get('meta'))
|
||||
file.save(self.data.get('file').read())
|
||||
return f'/api/file/{file_id}'
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
# coding=utf-8
|
||||
from abc import abstractmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseImage(BaseModel):
|
||||
@abstractmethod
|
||||
def check_auth(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def image_understand(self, image_file, text):
|
||||
pass
|
||||
|
|
@ -1,6 +1,10 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
|
|
@ -25,7 +29,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
|||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.check_auth()
|
||||
model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -1,12 +1,9 @@
|
|||
import base64
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from openai import OpenAI
|
||||
from langchain_openai.chat_models import ChatOpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_image import BaseImage
|
||||
|
||||
|
||||
def custom_get_token_ids(text: str):
|
||||
|
|
@ -14,66 +11,15 @@ def custom_get_token_ids(text: str):
|
|||
return tokenizer.encode(text)
|
||||
|
||||
|
||||
class OpenAIImage(MaxKBBaseModel, BaseImage):
|
||||
api_base: str
|
||||
api_key: str
|
||||
model: str
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.api_key = kwargs.get('api_key')
|
||||
self.api_base = kwargs.get('api_base')
|
||||
class OpenAIImage(MaxKBBaseModel, ChatOpenAI):
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = {}
|
||||
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
|
||||
optional_params['max_tokens'] = model_kwargs['max_tokens']
|
||||
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
|
||||
optional_params['temperature'] = model_kwargs['temperature']
|
||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||
return OpenAIImage(
|
||||
model=model_name,
|
||||
api_base=model_credential.get('api_base'),
|
||||
api_key=model_credential.get('api_key'),
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
stream_options={"include_usage": True},
|
||||
**optional_params,
|
||||
)
|
||||
|
||||
def check_auth(self):
|
||||
client = OpenAI(
|
||||
base_url=self.api_base,
|
||||
api_key=self.api_key
|
||||
)
|
||||
response_list = client.models.with_raw_response.list()
|
||||
# print(response_list)
|
||||
# cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
# with open(f'{cwd}/img_1.png', 'rb') as f:
|
||||
# self.image_understand(f, "一句话概述这个图片")
|
||||
|
||||
def image_understand(self, image_file, text):
|
||||
client = OpenAI(
|
||||
base_url=self.api_base,
|
||||
api_key=self.api_key
|
||||
)
|
||||
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 18:41
|
||||
@desc:
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class QwenModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=1.0,
|
||||
_min=0.1,
|
||||
_max=1.9,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=100000,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return QwenModelParams()
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||
chat_tong_yi = QwenVLChatModel(
|
||||
model=model_name,
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
stream_options={"include_usage": True},
|
||||
model_kwargs=optional_params,
|
||||
)
|
||||
return chat_tong_yi
|
||||
|
|
@ -11,21 +11,33 @@ import os
|
|||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential
|
||||
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel
|
||||
|
||||
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
qwen_model_credential = OpenAILLMModelCredential()
|
||||
qwenvl_model_credential = QwenVLModelCredential()
|
||||
|
||||
module_info_list = [
|
||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||
ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||
ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)
|
||||
]
|
||||
module_info_vl_list = [
|
||||
ModelInfo('qwen-vl-max', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||
ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||
ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||
]
|
||||
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(module_info_list).append_default_model_info(
|
||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)).build()
|
||||
model_info_manage = (ModelInfoManage.builder()
|
||||
.append_model_info_list(module_info_list)
|
||||
.append_default_model_info(
|
||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel))
|
||||
.append_model_info_list(module_info_vl_list)
|
||||
.build())
|
||||
|
||||
|
||||
class QwenModelProvider(IModelProvider):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 18:41
|
||||
@desc:
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class QwenModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=1.0,
|
||||
_min=0.1,
|
||||
_max=1.9,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=100000,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class TencentVisionModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return QwenModelParams()
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from typing import Dict
|
||||
|
||||
from langchain_openai.chat_models import ChatOpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
def custom_get_token_ids(text: str):
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return tokenizer.encode(text)
|
||||
|
||||
|
||||
class TencentVision(MaxKBBaseModel, ChatOpenAI):
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||
return TencentVision(
|
||||
model=model_name,
|
||||
openai_api_base='https://api.hunyuan.cloud.tencent.com/v1',
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
stream_options={"include_usage": True},
|
||||
**optional_params,
|
||||
)
|
||||
|
|
@ -7,8 +7,10 @@ from setting.models_provider.base_model_provider import (
|
|||
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
|
||||
)
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.image import TencentVisionModelCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
||||
from setting.models_provider.impl.tencent_model_provider.model.image import TencentVision
|
||||
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -78,9 +80,18 @@ def _initialize_model_info():
|
|||
|
||||
model_info_embedding_list = [tencent_embedding_model_info]
|
||||
|
||||
model_info_vision_list = [_create_model_info(
|
||||
'hunyuan-vision',
|
||||
'混元视觉模型',
|
||||
ModelTypeConst.IMAGE,
|
||||
TencentVisionModelCredential,
|
||||
TencentVision)]
|
||||
|
||||
|
||||
model_info_manage = ModelInfoManage.builder() \
|
||||
.append_model_info_list(model_info_list) \
|
||||
.append_model_info_list(model_info_embedding_list) \
|
||||
.append_model_info_list(model_info_vision_list) \
|
||||
.append_default_model_info(model_info_list[0]) \
|
||||
.build()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
# coding=utf-8
|
||||
|
||||
import base64
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from setting.models_provider.impl.xf_model_provider.model.image import ImageMessage
|
||||
|
||||
|
||||
class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
|
||||
|
|
@ -28,7 +32,10 @@ class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
|
|||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.check_auth()
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
with open(f'{cwd}/img_1.png', 'rb') as f:
|
||||
message_list = [ImageMessage(str(base64.b64encode(f.read()), 'utf-8')), HumanMessage('请概述这张图片')]
|
||||
model.stream(message_list)
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
|
|
|
|||
|
Before Width: | Height: | Size: 354 KiB After Width: | Height: | Size: 354 KiB |
|
|
@ -1,50 +1,58 @@
|
|||
# coding=utf-8
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
from datetime import datetime, UTC
|
||||
from typing import Dict
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import urlparse
|
||||
from typing import Dict, Any, List, Optional, Iterator
|
||||
|
||||
import websockets
|
||||
from docutils.utils import SystemMessage
|
||||
from langchain_community.chat_models.sparkllm import ChatSparkLLM, _convert_delta_to_message_chunk
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.messages import BaseMessage, ChatMessage, HumanMessage, AIMessage, AIMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_image import BaseImage
|
||||
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
|
||||
class XFSparkImage(MaxKBBaseModel, BaseImage):
|
||||
class ImageMessage(HumanMessage):
|
||||
content: str
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, ImageMessage):
|
||||
message_dict = {"role": "user", "content": message.content, "content_type": "image"}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
# If function call only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
|
||||
# If tool calls only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
|
||||
class XFSparkImage(MaxKBBaseModel, ChatSparkLLM):
|
||||
spark_app_id: str
|
||||
spark_api_key: str
|
||||
spark_api_secret: str
|
||||
spark_api_url: str
|
||||
params: dict
|
||||
|
||||
# 初始化
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.spark_api_url = kwargs.get('spark_api_url')
|
||||
self.spark_app_id = kwargs.get('spark_app_id')
|
||||
self.spark_api_key = kwargs.get('spark_api_key')
|
||||
self.spark_api_secret = kwargs.get('spark_api_secret')
|
||||
self.params = kwargs.get('params')
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = {'params': {}}
|
||||
for key, value in model_kwargs.items():
|
||||
if key not in ['model_id', 'use_local', 'streaming']:
|
||||
optional_params['params'][key] = value
|
||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||
return XFSparkImage(
|
||||
spark_app_id=model_credential.get('spark_app_id'),
|
||||
spark_api_key=model_credential.get('spark_api_key'),
|
||||
|
|
@ -53,118 +61,36 @@ class XFSparkImage(MaxKBBaseModel, BaseImage):
|
|||
**optional_params
|
||||
)
|
||||
|
||||
def create_url(self):
|
||||
url = self.spark_api_url
|
||||
host = urlparse(url).hostname
|
||||
# 生成RFC1123格式的时间戳
|
||||
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
|
||||
date = datetime.now(UTC).strftime(gmt_format)
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + host + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + "/v2.1/image " + "HTTP/1.1"
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
||||
|
||||
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
|
||||
self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
|
||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {
|
||||
"authorization": authorization,
|
||||
"date": date,
|
||||
"host": host
|
||||
}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = url + '?' + urlencode(v)
|
||||
# print("date: ",date)
|
||||
# print("v: ",v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
# print('websocket url :', url)
|
||||
return url
|
||||
|
||||
def check_auth(self):
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
with open(f'{cwd}/img_1.png', 'rb') as f:
|
||||
self.image_understand(f,"一句话概述这个图片")
|
||||
|
||||
def image_understand(self, image_file, question):
|
||||
async def handle():
|
||||
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
|
||||
# 发送 full client request
|
||||
await self.send(ws, image_file, question)
|
||||
return await self.handle_message(ws)
|
||||
|
||||
return asyncio.run(handle())
|
||||
|
||||
# 收到websocket消息的处理
|
||||
@staticmethod
|
||||
async def handle_message(ws):
|
||||
# print(message)
|
||||
answer = ''
|
||||
while True:
|
||||
res = await ws.recv()
|
||||
data = json.loads(res)
|
||||
code = data['header']['code']
|
||||
if code != 0:
|
||||
return f'请求错误: {code}, {data}'
|
||||
else:
|
||||
choices = data["payload"]["choices"]
|
||||
status = choices["status"]
|
||||
content = choices["text"][0]["content"]
|
||||
# print(content, end="")
|
||||
answer += content
|
||||
# print(1)
|
||||
if status == 2:
|
||||
break
|
||||
return answer
|
||||
def generate_message(prompt: str, image) -> list[BaseMessage]:
|
||||
if image is None:
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
with open(f'{cwd}/img_1.png', 'rb') as f:
|
||||
base64_image = base64.b64encode(f.read()).decode("utf-8")
|
||||
return [ImageMessage(f'data:image/jpeg;base64,{base64_image}'), HumanMessage(prompt)]
|
||||
return [HumanMessage(prompt)]
|
||||
|
||||
async def send(self, ws, image_file, question):
|
||||
text = [
|
||||
{"role": "user", "content": str(base64.b64encode(image_file.read()), 'utf-8'), "content_type": "image"},
|
||||
{"role": "user", "content": question}
|
||||
]
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
default_chunk_class = AIMessageChunk
|
||||
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": self.spark_app_id
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": "image",
|
||||
"temperature": 0.5,
|
||||
"top_k": 4,
|
||||
"max_tokens": 2028,
|
||||
"auditing": "default"
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
d = json.dumps(data)
|
||||
await ws.send(d)
|
||||
|
||||
def is_cache_model(self):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_len(text):
|
||||
length = 0
|
||||
for content in text:
|
||||
temp = content["content"]
|
||||
leng = len(temp)
|
||||
length += leng
|
||||
return length
|
||||
|
||||
def check_len(self, text):
|
||||
print("text-content-tokens:", self.get_len(text[1:]))
|
||||
while (self.get_len(text[1:]) > 8000):
|
||||
del text[1]
|
||||
return text
|
||||
self.client.arun(
|
||||
[convert_message_to_dict(m) for m in messages],
|
||||
self.spark_user_id,
|
||||
self.model_kwargs,
|
||||
streaming=True,
|
||||
)
|
||||
for content in self.client.subscribe(timeout=self.request_timeout):
|
||||
if "data" not in content:
|
||||
continue
|
||||
delta = content["data"]
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
cg_chunk = ChatGenerationChunk(message=chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ model_info_list = [
|
|||
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
||||
ModelInfo('image', '', ModelTypeConst.IMAGE, image_model_credential, XFSparkImage),
|
||||
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -286,6 +286,14 @@ const getApplicationTTSModel: (
|
|||
return get(`${prefix}/${application_id}/model`, { model_type: 'TTS' }, loading)
|
||||
}
|
||||
|
||||
const getApplicationImageModel: (
|
||||
application_id: string,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<Array<any>>> = (application_id, loading) => {
|
||||
return get(`${prefix}/${application_id}/model`, { model_type: 'IMAGE' }, loading)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 发布应用
|
||||
* @param 参数
|
||||
|
|
@ -350,6 +358,19 @@ const getModelParamsForm: (
|
|||
return get(`${prefix}/${application_id}/model_params_form/${model_id}`, undefined, loading)
|
||||
}
|
||||
|
||||
/**
|
||||
* 上传文档图片附件
|
||||
*/
|
||||
const uploadFile: (
|
||||
application_id: String,
|
||||
chat_id: String,
|
||||
data: any,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<any>> = (application_id, chat_id, data, loading) => {
|
||||
return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 语音转文本
|
||||
*/
|
||||
|
|
@ -501,6 +522,7 @@ export default {
|
|||
getApplicationRerankerModel,
|
||||
getApplicationSTTModel,
|
||||
getApplicationTTSModel,
|
||||
getApplicationImageModel,
|
||||
postSpeechToText,
|
||||
postTextToSpeech,
|
||||
getPlatformStatus,
|
||||
|
|
@ -513,5 +535,6 @@ export default {
|
|||
putWorkFlowVersion,
|
||||
playDemoText,
|
||||
getUserList,
|
||||
getApplicationList
|
||||
getApplicationList,
|
||||
uploadFile
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,6 +54,18 @@
|
|||
<div v-for="(f, i) in item.global_fields" :key="i">
|
||||
{{ f.label }}: {{ f.value }}
|
||||
</div>
|
||||
<div v-if="item.document_list?.length > 0">
|
||||
上传的文档:
|
||||
<div v-for="(f, i) in item.document_list" :key="i">
|
||||
{{ f.name }}
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="item.image_list?.length > 0">
|
||||
上传的图片:
|
||||
<div v-for="(f, i) in item.image_list" :key="i">
|
||||
{{ f.name }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
|
@ -96,7 +108,8 @@
|
|||
v-if="
|
||||
item.type == WorkflowType.AiChat ||
|
||||
item.type == WorkflowType.Question ||
|
||||
item.type == WorkflowType.Application
|
||||
item.type == WorkflowType.Application ||
|
||||
item.type == WorkflowType.ImageUnderstandNode
|
||||
"
|
||||
>
|
||||
<div
|
||||
|
|
|
|||
|
|
@ -192,6 +192,19 @@
|
|||
/>
|
||||
|
||||
<div class="operate flex align-center">
|
||||
<span v-if="props.data.file_upload_enable" class="flex align-center">
|
||||
<el-upload
|
||||
action="#"
|
||||
:auto-upload="false"
|
||||
:show-file-list="false"
|
||||
:on-change="(file: any, fileList: any) => uploadFile(file, fileList)"
|
||||
>
|
||||
<el-button text>
|
||||
<el-icon><Paperclip /></el-icon>
|
||||
</el-button>
|
||||
</el-upload>
|
||||
<el-divider direction="vertical" />
|
||||
</span>
|
||||
<span v-if="props.data.stt_model_enable" class="flex align-center">
|
||||
<el-button text v-if="mediaRecorderStatus" @click="startRecording">
|
||||
<el-icon>
|
||||
|
|
@ -790,7 +803,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
|
|||
is_stop: false,
|
||||
record_id: '',
|
||||
vote_status: '-1',
|
||||
status: undefined
|
||||
status: undefined,
|
||||
})
|
||||
chatList.value.push(chat)
|
||||
ChatManagement.addChatRecord(chat, 50, loading)
|
||||
|
|
@ -809,7 +822,8 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
|
|||
const obj = {
|
||||
message: chat.problem_text,
|
||||
re_chat: re_chat || false,
|
||||
form_data: { ...form_data.value, ...api_form_data.value }
|
||||
form_data: { ...form_data.value, ...api_form_data.value },
|
||||
image_list: uploadFileList.value,
|
||||
}
|
||||
// 对话
|
||||
applicationApi
|
||||
|
|
@ -832,6 +846,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
|
|||
nextTick(() => {
|
||||
// 将滚动条滚动到最下面
|
||||
scrollDiv.value.setScrollTop(getMaxHeight())
|
||||
uploadFileList.value = []
|
||||
})
|
||||
const reader = response.body.getReader()
|
||||
// 处理流数据
|
||||
|
|
@ -916,6 +931,42 @@ const handleScroll = () => {
|
|||
}
|
||||
}
|
||||
|
||||
// 保存上传文件列表
|
||||
const uploadFileList = ref<any>([])
|
||||
const uploadFile = async (file: any, fileList: any) => {
|
||||
const { maxFiles, fileLimit } = props.data.file_upload_setting
|
||||
if (fileList.length > maxFiles) {
|
||||
MsgWarning('最多上传' + maxFiles + '个文件')
|
||||
return
|
||||
}
|
||||
if (fileList.filter((f: any) => f.size > fileLimit * 1024 * 1024).length > 0) { // MB
|
||||
MsgWarning('单个文件大小不能超过' + fileLimit + 'MB')
|
||||
fileList.splice(0, fileList.length)
|
||||
return
|
||||
}
|
||||
const formData = new FormData()
|
||||
for (const file of fileList) {
|
||||
formData.append('file', file.raw, file.name)
|
||||
uploadFileList.value.push(file)
|
||||
}
|
||||
|
||||
if (props.chatId === 'new' || !chartOpenId.value) {
|
||||
const res = await applicationApi.getChatOpen(props.data.id as string)
|
||||
chartOpenId.value = res.data
|
||||
}
|
||||
applicationApi.uploadFile(props.data.id as string, chartOpenId.value, formData, loading).then((response) => {
|
||||
fileList.splice(0, fileList.length)
|
||||
uploadFileList.value.forEach((file: any) => {
|
||||
const f = response.data.filter((f: any) => f.name === file.name)
|
||||
if (f.length > 0) {
|
||||
file.url = f[0].url
|
||||
file.file_id = f[0].file_id
|
||||
}
|
||||
})
|
||||
console.log(uploadFileList.value)
|
||||
})
|
||||
}
|
||||
|
||||
// 定义响应式引用
|
||||
const mediaRecorder = ref<any>(null)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,5 +9,7 @@ export enum WorkflowType {
|
|||
FunctionLib = 'function-lib-node',
|
||||
FunctionLibCustom = 'function-node',
|
||||
RrerankerNode = 'reranker-node',
|
||||
Application = 'application-node'
|
||||
Application = 'application-node',
|
||||
DocumentExtractNode = 'document-extract-node',
|
||||
ImageUnderstandNode = 'image-understand-node',
|
||||
}
|
||||
|
|
|
|||
|
|
@ -168,13 +168,49 @@ export const rerankerNode = {
|
|||
}
|
||||
}
|
||||
}
|
||||
export const documentExtractNode = {
|
||||
type: WorkflowType.DocumentExtractNode,
|
||||
text: '提取文档中的内容',
|
||||
label: '文档内容提取',
|
||||
height: 252,
|
||||
properties: {
|
||||
stepName: '文档内容提取',
|
||||
config: {
|
||||
fields: [
|
||||
{
|
||||
label: '文件内容',
|
||||
value: 'content'
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
export const imageUnderstandNode = {
|
||||
type: WorkflowType.ImageUnderstandNode,
|
||||
text: '识别出图片中的对象、场景等信息回答用户问题',
|
||||
label: '图片理解',
|
||||
height: 252,
|
||||
properties: {
|
||||
stepName: '图片理解',
|
||||
config: {
|
||||
fields: [
|
||||
{
|
||||
label: 'AI 回答内容',
|
||||
value: 'content'
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
export const menuNodes = [
|
||||
aiChatNode,
|
||||
searchDatasetNode,
|
||||
questionNode,
|
||||
conditionNode,
|
||||
replyNode,
|
||||
rerankerNode
|
||||
rerankerNode,
|
||||
documentExtractNode,
|
||||
imageUnderstandNode
|
||||
]
|
||||
|
||||
/**
|
||||
|
|
@ -261,7 +297,9 @@ export const nodeDict: any = {
|
|||
[WorkflowType.FunctionLib]: functionLibNode,
|
||||
[WorkflowType.FunctionLibCustom]: functionNode,
|
||||
[WorkflowType.RrerankerNode]: rerankerNode,
|
||||
[WorkflowType.Application]: applicationNode
|
||||
[WorkflowType.Application]: applicationNode,
|
||||
[WorkflowType.DocumentExtractNode]: documentExtractNode,
|
||||
[WorkflowType.ImageUnderstandNode]: imageUnderstandNode,
|
||||
}
|
||||
export function isWorkFlow(type: string | undefined) {
|
||||
return type === 'WORK_FLOW'
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
<template>
|
||||
<AppAvatar shape="square" style="background: #7F3BF5">
|
||||
<img src="@/assets/icon_document.svg" style="width: 65%" alt="" />
|
||||
</AppAvatar>
|
||||
</template>
|
||||
<script setup lang="ts"></script>
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
<template>
|
||||
<AppAvatar shape="square" style="background: #14C0FF;">
|
||||
<img src="@/assets/icon_document.svg" style="width: 65%" alt="" />
|
||||
</AppAvatar>
|
||||
</template>
|
||||
<script setup lang="ts"></script>
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
<template>
|
||||
<el-dialog
|
||||
title="文件上传设置"
|
||||
v-model="dialogVisible"
|
||||
:close-on-click-modal="false"
|
||||
:close-on-press-escape="false"
|
||||
:destroy-on-close="true"
|
||||
:before-close="close"
|
||||
append-to-body
|
||||
>
|
||||
<el-form
|
||||
label-position="top"
|
||||
ref="fieldFormRef"
|
||||
:model="form_data"
|
||||
require-asterisk-position="right">
|
||||
<el-form-item label="单次上传最多文件数">
|
||||
<el-slider v-model="form_data.maxFiles" show-input :show-input-controls="false" :min="1" :max="10" />
|
||||
</el-form-item>
|
||||
<el-form-item label="每个文件最大(MB)">
|
||||
<el-slider v-model="form_data.fileLimit" show-input :show-input-controls="false" :min="1" :max="100" />
|
||||
</el-form-item>
|
||||
<el-form-item label="上传的文件类型">
|
||||
<el-card style="width: 100%" class="mb-8">
|
||||
<div class="flex-between">
|
||||
<p>
|
||||
文档(TXT、MD、DOCX、HTML、CSV、XLSX、XLS、PDF)
|
||||
需要与文档内容提取节点配合使用
|
||||
</p>
|
||||
<el-checkbox v-model="form_data.document" />
|
||||
</div>
|
||||
</el-card>
|
||||
<el-card style="width: 100%" class="mb-8">
|
||||
<div class="flex-between">
|
||||
<p>
|
||||
图片(JPG、JPEG、PNG、GIF)
|
||||
所选模型需要支持接收图片
|
||||
</p>
|
||||
<el-checkbox v-model="form_data.image" />
|
||||
</div>
|
||||
</el-card>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
<template #footer>
|
||||
<span class="dialog-footer">
|
||||
<el-button @click.prevent="close"> 取消 </el-button>
|
||||
<el-button type="primary" @click="submit()" :loading="loading">
|
||||
确定
|
||||
</el-button>
|
||||
</span>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { nextTick, ref } from 'vue'
|
||||
|
||||
const emit = defineEmits(['refresh'])
|
||||
const props = defineProps<{ nodeModel: any }>()
|
||||
|
||||
const dialogVisible = ref(false)
|
||||
const loading = ref(false)
|
||||
const fieldFormRef = ref()
|
||||
const form_data = ref({
|
||||
maxFiles: 3,
|
||||
fileLimit: 50,
|
||||
document: true,
|
||||
image: false,
|
||||
audio: false,
|
||||
video: false
|
||||
})
|
||||
|
||||
|
||||
function open(data: any) {
|
||||
dialogVisible.value = true
|
||||
nextTick(() => {
|
||||
form_data.value = { ...form_data.value, ...data }
|
||||
})
|
||||
}
|
||||
|
||||
function close() {
|
||||
dialogVisible.value = false
|
||||
}
|
||||
|
||||
async function submit() {
|
||||
const formEl = fieldFormRef.value
|
||||
if (!formEl) return
|
||||
await formEl.validate().then(() => {
|
||||
emit('refresh', form_data.value)
|
||||
props.nodeModel.graphModel.eventCenter.emit('refreshFileUploadConfig')
|
||||
dialogVisible.value = false
|
||||
})
|
||||
}
|
||||
|
||||
defineExpose({
|
||||
open
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped lang="scss">
|
||||
|
||||
</style>
|
||||
|
|
@ -45,6 +45,36 @@
|
|||
@submitDialog="submitDialog"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item >
|
||||
<template #label>
|
||||
<div class="flex-between">
|
||||
<div class="flex align-center">
|
||||
<span class="mr-4">文件上传</span>
|
||||
<el-tooltip
|
||||
effect="dark"
|
||||
content="开启后,问答页面会显示上传文件的按钮。"
|
||||
placement="right"
|
||||
>
|
||||
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||
</el-tooltip>
|
||||
</div>
|
||||
<div>
|
||||
<el-button
|
||||
v-if="form_data.file_upload_enable"
|
||||
type="primary"
|
||||
link
|
||||
@click="openFileUploadSettingDialog"
|
||||
class="mr-4"
|
||||
>
|
||||
<el-icon class="mr-4">
|
||||
<Setting />
|
||||
</el-icon>
|
||||
</el-button>
|
||||
<el-switch size="small" v-model="form_data.file_upload_enable"/>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</el-form-item>
|
||||
<UserInputFieldTable ref="UserInputFieldTableFef" :node-model="nodeModel" />
|
||||
<ApiInputFieldTable ref="ApiInputFieldTableFef" :node-model="nodeModel" />
|
||||
<el-form-item>
|
||||
|
|
@ -139,7 +169,6 @@
|
|||
<el-icon class="mr-4">
|
||||
<Setting />
|
||||
</el-icon>
|
||||
设置
|
||||
</el-button>
|
||||
<el-switch size="small" v-model="form_data.tts_model_enable" @change="ttsModelEnableChange"/>
|
||||
</div>
|
||||
|
|
@ -212,6 +241,7 @@
|
|||
</el-form-item>
|
||||
</el-form>
|
||||
<TTSModeParamSettingDialog ref="TTSModeParamSettingDialogRef" @refresh="refreshTTSForm" />
|
||||
<FileUploadSettingDialog ref="FileUploadSettingDialogRef" :node-model="nodeModel" @refresh="refreshFileUploadForm"/>
|
||||
</NodeContainer>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
|
|
@ -229,6 +259,7 @@ import { t } from '@/locales'
|
|||
import TTSModeParamSettingDialog from '@/views/application/component/TTSModeParamSettingDialog.vue'
|
||||
import ApiInputFieldTable from './component/ApiInputFieldTable.vue'
|
||||
import UserInputFieldTable from './component/UserInputFieldTable.vue'
|
||||
import FileUploadSettingDialog from '@/workflow/nodes/base-node/component/FileUploadSettingDialog.vue'
|
||||
|
||||
const { model } = useStore()
|
||||
|
||||
|
|
@ -244,6 +275,7 @@ const providerOptions = ref<Array<Provider>>([])
|
|||
const TTSModeParamSettingDialogRef = ref<InstanceType<typeof TTSModeParamSettingDialog>>()
|
||||
const UserInputFieldTableFef = ref()
|
||||
const ApiInputFieldTableFef = ref()
|
||||
const FileUploadSettingDialogRef = ref<InstanceType<typeof FileUploadSettingDialog>>()
|
||||
|
||||
const form = {
|
||||
name: '',
|
||||
|
|
@ -350,6 +382,14 @@ const refreshTTSForm = (data: any) => {
|
|||
form_data.value.tts_model_params_setting = data
|
||||
}
|
||||
|
||||
const openFileUploadSettingDialog = () => {
|
||||
FileUploadSettingDialogRef.value?.open(form_data.value.file_upload_setting)
|
||||
}
|
||||
|
||||
const refreshFileUploadForm = (data: any) => {
|
||||
form_data.value.file_upload_setting = data
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
set(props.nodeModel, 'validate', validate)
|
||||
if (!props.nodeModel.properties.node_data.tts_type) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
import DocumentExtractNodeVue from './index.vue'
|
||||
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'
|
||||
class RerankerNode extends AppNode {
|
||||
constructor(props: any) {
|
||||
super(props, DocumentExtractNodeVue)
|
||||
}
|
||||
}
|
||||
export default {
|
||||
type: 'document-extract-node',
|
||||
model: AppNodeModel,
|
||||
view: RerankerNode
|
||||
}
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
<template>
|
||||
<NodeContainer :nodeModel="nodeModel">
|
||||
<h5 class="title-decoration-1 mb-8">节点设置</h5>
|
||||
<el-card shadow="never" class="card-never">
|
||||
<el-form
|
||||
@submit.prevent
|
||||
:model="form_data"
|
||||
label-position="top"
|
||||
require-asterisk-position="right"
|
||||
label-width="auto"
|
||||
ref="DatasetNodeFormRef"
|
||||
>
|
||||
<el-form-item label="选择文件" :rules="{
|
||||
type: 'array',
|
||||
required: true,
|
||||
message: '请选择文件',
|
||||
trigger: 'change'
|
||||
}"
|
||||
>
|
||||
<NodeCascader
|
||||
ref="nodeCascaderRef"
|
||||
:nodeModel="nodeModel"
|
||||
class="w-full"
|
||||
placeholder="请选择文件"
|
||||
v-model="form.file_list"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</el-card>
|
||||
</NodeContainer>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import NodeContainer from '@/workflow/common/NodeContainer.vue'
|
||||
import { computed } from 'vue'
|
||||
import { set } from 'lodash'
|
||||
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
||||
|
||||
const props = defineProps<{ nodeModel: any }>()
|
||||
|
||||
const form = {
|
||||
file_list: []
|
||||
}
|
||||
|
||||
|
||||
const form_data = computed({
|
||||
get: () => {
|
||||
if (props.nodeModel.properties.node_data) {
|
||||
return props.nodeModel.properties.node_data
|
||||
} else {
|
||||
set(props.nodeModel.properties, 'node_data', form)
|
||||
}
|
||||
return props.nodeModel.properties.node_data
|
||||
},
|
||||
set: (value) => {
|
||||
set(props.nodeModel.properties, 'node_data', value)
|
||||
}
|
||||
})
|
||||
|
||||
</script>
|
||||
|
||||
<style lang="scss" scoped>
|
||||
|
||||
</style>
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
import ImageUnderstandNodeVue from './index.vue'
|
||||
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'
|
||||
|
||||
class RerankerNode extends AppNode {
|
||||
constructor(props: any) {
|
||||
super(props, ImageUnderstandNodeVue)
|
||||
}
|
||||
}
|
||||
|
||||
export default {
|
||||
type: 'image-understand-node',
|
||||
model: AppNodeModel,
|
||||
view: RerankerNode
|
||||
}
|
||||
|
|
@ -0,0 +1,277 @@
|
|||
<template>
|
||||
<NodeContainer :node-model="nodeModel">
|
||||
<h5 class="title-decoration-1 mb-8">节点设置</h5>
|
||||
<el-card shadow="never" class="card-never">
|
||||
<el-form
|
||||
@submit.prevent
|
||||
:model="form_data"
|
||||
label-position="top"
|
||||
require-asterisk-position="right"
|
||||
class="mb-24"
|
||||
label-width="auto"
|
||||
ref="aiChatNodeFormRef"
|
||||
hide-required-asterisk
|
||||
>
|
||||
<el-form-item
|
||||
label="图片理解模型"
|
||||
prop="model_id"
|
||||
:rules="{
|
||||
required: true,
|
||||
message: '请选择图片理解模型',
|
||||
trigger: 'change'
|
||||
}"
|
||||
>
|
||||
<template #label>
|
||||
<div class="flex-between w-full">
|
||||
<div>
|
||||
<span>图片理解模型<span class="danger">*</span></span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<el-select
|
||||
@change="model_change"
|
||||
@wheel="wheel"
|
||||
:teleported="false"
|
||||
v-model="form_data.model_id"
|
||||
placeholder="请选择图片理解模型"
|
||||
class="w-full"
|
||||
popper-class="select-model"
|
||||
:clearable="true"
|
||||
>
|
||||
<el-option-group
|
||||
v-for="(value, label) in modelOptions"
|
||||
:key="value"
|
||||
:label="relatedObject(providerOptions, label, 'provider')?.name"
|
||||
>
|
||||
<el-option
|
||||
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
|
||||
:key="item.id"
|
||||
:label="item.name"
|
||||
:value="item.id"
|
||||
class="flex-between"
|
||||
>
|
||||
<div class="flex align-center">
|
||||
<span
|
||||
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||
class="model-icon mr-8"
|
||||
></span>
|
||||
<span>{{ item.name }}</span>
|
||||
<el-tag v-if="item.permission_type === 'PUBLIC'" type="info" class="info-tag ml-8"
|
||||
>公用
|
||||
</el-tag>
|
||||
</div>
|
||||
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
|
||||
<Check />
|
||||
</el-icon>
|
||||
</el-option>
|
||||
<!-- 不可用 -->
|
||||
<el-option
|
||||
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
|
||||
:key="item.id"
|
||||
:label="item.name"
|
||||
:value="item.id"
|
||||
class="flex-between"
|
||||
disabled
|
||||
>
|
||||
<div class="flex">
|
||||
<span
|
||||
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||
class="model-icon mr-8"
|
||||
></span>
|
||||
<span>{{ item.name }}</span>
|
||||
<span class="danger">(不可用)</span>
|
||||
</div>
|
||||
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
|
||||
<Check />
|
||||
</el-icon>
|
||||
</el-option>
|
||||
</el-option-group>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
|
||||
<el-form-item label="角色设定">
|
||||
<MdEditorMagnify
|
||||
title="角色设定"
|
||||
v-model="form_data.system"
|
||||
style="height: 100px"
|
||||
@submitDialog="submitSystemDialog"
|
||||
placeholder="角色设定"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item
|
||||
label="提示词"
|
||||
prop="prompt"
|
||||
:rules="{
|
||||
required: true,
|
||||
message: '请输入提示词',
|
||||
trigger: 'blur'
|
||||
}"
|
||||
>
|
||||
<template #label>
|
||||
<div class="flex align-center">
|
||||
<div class="mr-4">
|
||||
<span>提示词<span class="danger">*</span></span>
|
||||
</div>
|
||||
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||
<template #content
|
||||
>通过调整提示词内容,可以引导大模型聊天方向,该提示词会被固定在上下文的开头,可以使用变量。
|
||||
</template>
|
||||
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||
</el-tooltip>
|
||||
</div>
|
||||
</template>
|
||||
<MdEditorMagnify
|
||||
@wheel="wheel"
|
||||
title="提示词"
|
||||
v-model="form_data.prompt"
|
||||
style="height: 150px"
|
||||
@submitDialog="submitDialog"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="历史聊天记录">
|
||||
<el-input-number
|
||||
v-model="form_data.dialogue_number"
|
||||
:min="0"
|
||||
:value-on-clear="0"
|
||||
controls-position="right"
|
||||
class="w-full"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="选择图片" :rules="{
|
||||
type: 'array',
|
||||
required: true,
|
||||
message: '请选择图片',
|
||||
trigger: 'change'
|
||||
}"
|
||||
>
|
||||
<NodeCascader
|
||||
ref="nodeCascaderRef"
|
||||
:nodeModel="nodeModel"
|
||||
class="w-full"
|
||||
placeholder="请选择图片"
|
||||
v-model="form_data.image_list"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="返回内容" @click.prevent>
|
||||
<template #label>
|
||||
<div class="flex align-center">
|
||||
<div class="mr-4">
|
||||
<span>返回内容<span class="danger">*</span></span>
|
||||
</div>
|
||||
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||
<template #content>
|
||||
关闭后该节点的内容则不输出给用户。
|
||||
如果你想让用户看到该节点的输出内容,请打开开关。
|
||||
</template>
|
||||
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||
</el-tooltip>
|
||||
</div>
|
||||
</template>
|
||||
<el-switch size="small" v-model="form_data.is_result" />
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</el-card>
|
||||
</NodeContainer>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import NodeContainer from '@/workflow/common/NodeContainer.vue'
|
||||
import { computed, onMounted, ref } from 'vue'
|
||||
import { groupBy, set } from 'lodash'
|
||||
import { relatedObject } from '@/utils/utils'
|
||||
import type { Provider } from '@/api/type/model'
|
||||
import applicationApi from '@/api/application'
|
||||
import { app } from '@/main'
|
||||
import useStore from '@/stores'
|
||||
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
||||
|
||||
const { model } = useStore()
|
||||
|
||||
const {
|
||||
params: { id }
|
||||
} = app.config.globalProperties.$route as any
|
||||
|
||||
const props = defineProps<{ nodeModel: any }>()
|
||||
const modelOptions = ref<any>(null)
|
||||
const providerOptions = ref<Array<Provider>>([])
|
||||
|
||||
|
||||
const wheel = (e: any) => {
|
||||
if (e.ctrlKey === true) {
|
||||
e.preventDefault()
|
||||
return true
|
||||
} else {
|
||||
e.stopPropagation()
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
const defaultPrompt = `{{开始.question}}`
|
||||
|
||||
const form = {
|
||||
model_id: '',
|
||||
system: '',
|
||||
prompt: defaultPrompt,
|
||||
dialogue_number: 0,
|
||||
is_result: true,
|
||||
temperature: null,
|
||||
max_tokens: null,
|
||||
image_list: ["start-node", "image"]
|
||||
}
|
||||
|
||||
const form_data = computed({
|
||||
get: () => {
|
||||
if (props.nodeModel.properties.node_data) {
|
||||
return props.nodeModel.properties.node_data
|
||||
} else {
|
||||
set(props.nodeModel.properties, 'node_data', form)
|
||||
}
|
||||
return props.nodeModel.properties.node_data
|
||||
},
|
||||
set: (value) => {
|
||||
set(props.nodeModel.properties, 'node_data', value)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
function getModel() {
|
||||
if (id) {
|
||||
applicationApi.getApplicationImageModel(id).then((res: any) => {
|
||||
modelOptions.value = groupBy(res?.data, 'provider')
|
||||
})
|
||||
} else {
|
||||
model.asyncGetModel().then((res: any) => {
|
||||
modelOptions.value = groupBy(res?.data, 'provider')
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function getProvider() {
|
||||
model.asyncGetProvider().then((res: any) => {
|
||||
providerOptions.value = res?.data
|
||||
})
|
||||
}
|
||||
|
||||
const model_change = (model_id?: string) => {
|
||||
|
||||
}
|
||||
|
||||
function submitSystemDialog(val: string) {
|
||||
set(props.nodeModel.properties.node_data, 'system', val)
|
||||
}
|
||||
|
||||
function submitDialog(val: string) {
|
||||
set(props.nodeModel.properties.node_data, 'prompt', val)
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
getModel()
|
||||
getProvider()
|
||||
})
|
||||
|
||||
</script>
|
||||
|
||||
|
||||
<style scoped lang="scss">
|
||||
|
||||
</style>
|
||||
|
|
@ -61,8 +61,31 @@ const refreshFieldList = () => {
|
|||
}
|
||||
props.nodeModel.graphModel.eventCenter.on('refreshFieldList', refreshFieldList)
|
||||
|
||||
const refreshFileUploadConfig = () => {
|
||||
let fields = cloneDeep(props.nodeModel.properties.config.fields)
|
||||
const form_data = props.nodeModel.graphModel.nodes
|
||||
.filter((v: any) => v.id === 'base-node')
|
||||
.map((v: any) => cloneDeep(v.properties.node_data.file_upload_setting))
|
||||
.filter((v: any) => v)
|
||||
if (form_data.length === 0) {
|
||||
return
|
||||
}
|
||||
fields = fields.filter((item: any) => item.value !== 'image' && item.value !== 'document')
|
||||
let fileUploadFields = []
|
||||
if (form_data[0].document) {
|
||||
fileUploadFields.push({ label: '文档', value: 'document' })
|
||||
}
|
||||
if (form_data[0].image) {
|
||||
fileUploadFields.push({ label: '图片', value: 'image' })
|
||||
}
|
||||
|
||||
set(props.nodeModel.properties.config, 'fields', [...fields, ...fileUploadFields])
|
||||
}
|
||||
props.nodeModel.graphModel.eventCenter.on('refreshFileUploadConfig', refreshFileUploadConfig)
|
||||
|
||||
onMounted(() => {
|
||||
refreshFieldList()
|
||||
refreshFileUploadConfig()
|
||||
})
|
||||
</script>
|
||||
<style lang="scss" scoped></style>
|
||||
|
|
|
|||
Loading…
Reference in New Issue