feat: 高级编排支持文件上传(WIP)

This commit is contained in:
CaptainB 2024-11-11 15:48:56 +08:00 committed by 刘瑞斌
parent 72b91bee9a
commit 8a8305e75b
49 changed files with 1352 additions and 238 deletions

View File

@ -16,9 +16,12 @@ from .direct_reply_node import *
from .function_lib_node import *
from .function_node import *
from .reranker_node import *
from .document_extract_node import *
from .image_understand_step_node import *
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode]
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode,
BaseImageUnderstandNode]
def get_node(node_type):

View File

@ -0,0 +1 @@
from .impl import *

View File

@ -0,0 +1,30 @@
# coding=utf-8
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
class DocumentExtractNodeSerializer(serializers.Serializer):
    """Validated parameters of the document-extract node."""
    # List of ids of the uploaded files to extract content from.
    # (Original message said "数据集id列表" — a copy-paste from the search-dataset
    # node; this field holds file ids, not dataset ids.)
    file_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
                                      error_messages=ErrMessage.list("文件id列表"))

    def is_valid(self, *, raise_exception=False):
        # Validation failures must always abort node execution, so the
        # caller-supplied raise_exception flag is deliberately overridden.
        super().is_valid(raise_exception=True)
class IDocumentExtractNode(INode):
    """Contract for the document-extract workflow node."""
    type = 'document-extract-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Serializer used to validate this node's parameters."""
        return DocumentExtractNodeSerializer

    def _run(self):
        # Forward the validated flow parameters to the concrete implementation.
        return self.execute(**self.flow_params_serializer.data)

    def execute(self, file_list, **kwargs) -> NodeResult:
        """Extract content from the given uploaded files; implemented by subclasses."""
        pass

View File

@ -0,0 +1 @@
from .base_document_extract_node import BaseDocumentExtractNode

View File

@ -0,0 +1,11 @@
# coding=utf-8
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
class BaseDocumentExtractNode(IDocumentExtractNode):
    """Document-extract node implementation (WIP — extraction not implemented yet)."""

    def execute(self, file_list, **kwargs):
        # TODO: load each referenced file and extract its text content.
        pass

    def get_details(self, index: int, **kwargs):
        # TODO: report execution details for the run log.
        pass

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .impl import *

View File

@ -0,0 +1,39 @@
# coding=utf-8
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
class ImageUnderstandNodeSerializer(serializers.Serializer):
    """Validated parameters of the image-understand node."""
    # id of the vision model to use
    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
    # optional system / role prompt
    system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                   error_messages=ErrMessage.char("角色设定"))
    # user prompt template
    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
    # number of history dialogue rounds to include
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
    # whether this node's answer is part of the workflow result
    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
    # reference to the uploaded image (only a single image is supported)
    image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张"))
class IImageUnderstandNode(INode):
    """Contract for the image-understand workflow node."""
    type = 'image-understand-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Serializer used to validate this node's parameters."""
        return ImageUnderstandNodeSerializer

    def _run(self):
        # image_list is optional (required=False); the original code indexed
        # [0] unconditionally, raising TypeError when it is absent. Resolve the
        # image reference only when one was supplied; downstream
        # generate_message_list already handles image=None.
        image_list = self.node_params_serializer.data.get('image_list')
        if image_list:
            res = self.workflow_manage.get_reference_field(image_list[0], image_list[1:])
        else:
            res = None
        return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
                chat_record_id,
                image,
                **kwargs) -> NodeResult:
        """Run the vision model over the prompt and image; implemented by subclasses."""
        pass

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .base_image_understand_node import BaseImageUnderstandNode

View File

@ -0,0 +1,147 @@
# coding=utf-8
import base64
import os
import time
from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
from dataset.models import File
from setting.models_provider.tools import get_model_instance_by_model_user_id
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
    """Record token usage, answer, history and timing on the node context, and
    append the answer to the workflow output when this node produces the result."""
    model = node_variable.get('chat_model')
    # NOTE(review): message_tokens reads usage_metadata['output_tokens'];
    # 'input_tokens' would normally be the prompt side — confirm intended.
    if 'usage_metadata' in node_variable:
        prompt_tokens = node_variable['usage_metadata']['output_tokens']
    else:
        prompt_tokens = 0
    node.context['message_tokens'] = prompt_tokens
    node.context['answer_tokens'] = model.get_num_tokens(answer)
    node.context['answer'] = answer
    node.context['history_message'] = node_variable['history_message']
    node.context['question'] = node_variable['question']
    node.context['run_time'] = time.time() - node.context['start_time']
    if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
        workflow.answer += answer
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """Write context data (streaming variant).

    Yields each chunk's content as it arrives, then records the accumulated
    answer on the node context.

    @param node_variable: node data
    @param workflow_variable: global data
    @param node: node instance
    @param workflow: workflow manager
    """
    chunks = node_variable.get('result')
    pieces = []
    for chunk in chunks:
        pieces.append(chunk.content)
        yield chunk.content
    _write_context(node_variable, workflow_variable, node, workflow, ''.join(pieces))
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """Write context data (non-streaming variant).

    @param node_variable: node data
    @param workflow_variable: global data
    @param node: node instance
    @param workflow: workflow manager
    """
    answer = node_variable.get('result').content
    _write_context(node_variable, workflow_variable, node, workflow, answer)
class BaseImageUnderstandNode(IImageUnderstandNode):
    """Image-understand node: sends the rendered prompt (plus an optional
    uploaded image) to a vision model, streaming or invoking as requested."""

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                image,
                **kwargs) -> NodeResult:
        """Run the vision model and return a NodeResult whose context writer
        records the answer.

        :param model_id: id of the vision model to use
        :param system: optional system / role prompt
        :param prompt: user prompt template
        :param dialogue_number: number of history rounds to include
        :param history_chat_record: previous chat records of this conversation
        :param stream: when True, stream the answer chunk by chunk
        :param image: uploaded image reference list (may be None/empty)
        """
        image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question.content
        # TODO: handle uploaded images beyond the first one
        message_list = self.generate_message_list(image_model, system, prompt, history_message, image)
        self.context['message_list'] = message_list
        self.context['image_list'] = image
        if stream:
            r = image_model.stream(message_list)
            # write_context_stream consumes the generator and accumulates the answer.
            return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
                               'history_message': history_message, 'question': question.content}, {},
                              _write_context=write_context_stream)
        else:
            r = image_model.invoke(message_list)
            return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
                               'history_message': history_message, 'question': question.content}, {},
                              _write_context=write_context)

    @staticmethod
    def get_history_message(history_chat_record, dialogue_number):
        """Return the last *dialogue_number* rounds flattened to [human, ai, ...]."""
        start_index = len(history_chat_record) - dialogue_number
        history_message = reduce(lambda x, y: [*x, *y], [
            [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
            for index in
            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
        return history_message

    def generate_prompt_question(self, prompt):
        """Render the prompt template against workflow variables as a HumanMessage."""
        return HumanMessage(self.workflow_manage.generate_prompt(prompt))

    def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
        """Assemble the model message list: optional system prompt, history, then
        the user message — with the first uploaded image inlined as a base64
        data URL when one was supplied."""
        if image is not None and len(image) > 0:
            # Only the first image is used; its bytes come from the File store.
            file_id = image[0]['file_id']
            file = QuerySet(File).filter(id=file_id).first()
            # NOTE(review): data URL hard-codes image/jpeg regardless of the
            # actual upload type — confirm models accept mismatched MIME.
            base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
            messages = [HumanMessage(
                content=[
                    {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
                    {'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
                ])]
        else:
            messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]
        if system is not None and len(system) > 0:
            return [
                SystemMessage(self.workflow_manage.generate_prompt(system)),
                *history_message,
                *messages
            ]
        else:
            return [
                *history_message,
                *messages
            ]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert messages to {'role', 'content'} dicts and append the final answer."""
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                  message
                  in
                  message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """Execution details for the run log / debug panel."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'system': self.node_params.get('system'),
            'history_message': [{'content': message.content, 'role': message.type} for message in
                                (self.context.get('history_message') if self.context.get(
                                    'history_message') is not None else [])],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            'image_list': self.context.get('image_list')
        }

View File

@ -41,7 +41,7 @@ class BaseStartStepNode(IStarNode):
"""
开始节点 初始化全局变量
"""
return NodeResult({'question': question},
return NodeResult({'question': question, 'image': self.workflow_manage.image_list},
workflow_variable)
def get_details(self, index: int, **kwargs):
@ -61,5 +61,6 @@ class BaseStartStepNode(IStarNode):
'type': self.node.type,
'status': self.status,
'err_message': self.err_message,
'image_list': self.context.get('image'),
'global_fields': global_fields
}

View File

@ -240,10 +240,13 @@ class NodeChunk:
class WorkflowManage:
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
base_to_response: BaseToResponse = SystemToResponse(), form_data=None):
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None):
if form_data is None:
form_data = {}
if image_list is None:
image_list = []
self.form_data = form_data
self.image_list = image_list
self.params = params
self.flow = flow
self.lock = threading.Lock()

View File

@ -0,0 +1,23 @@
# Generated by Django 4.2.15 on 2024-11-07 11:22
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add the file-upload switch and settings fields to Application."""

    dependencies = [
        ('application', '0018_workflowversion_name'),
    ]

    operations = [
        migrations.AddField(
            model_name='application',
            name='file_upload_enable',
            field=models.BooleanField(default=False, verbose_name='文件上传是否启用'),
        ),
        migrations.AddField(
            model_name='application',
            name='file_upload_setting',
            # default=dict (a callable), not default={}: a literal dict is a
            # single shared mutable object across instances (Django fields.W010).
            field=models.JSONField(default=dict, verbose_name='文件上传相关设置'),
        ),
    ]

View File

@ -66,6 +66,9 @@ class Application(AppModelMixin):
stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False)
tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER")
clean_time = models.IntegerField(verbose_name="清理时间", default=180)
file_upload_enable = models.BooleanField(verbose_name="文件上传是否启用", default=False)
file_upload_setting = models.JSONField(verbose_name="文件上传相关设置", default={})
@staticmethod
def get_default_model_prompt():

View File

@ -823,6 +823,8 @@ class ApplicationSerializer(serializers.Serializer):
'stt_model_enable': application.stt_model_enable,
'tts_model_enable': application.tts_model_enable,
'tts_type': application.tts_type,
'file_upload_enable': application.file_upload_enable,
'file_upload_setting': application.file_upload_setting,
'work_flow': application.work_flow,
'show_source': application_access_token.show_source,
**application_setting_dict})
@ -876,6 +878,7 @@ class ApplicationSerializer(serializers.Serializer):
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', 'tts_type',
'file_upload_enable', 'file_upload_setting',
'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting',
'problem_optimization_prompt', 'clean_time']
for update_key in update_keys:
@ -941,6 +944,10 @@ class ApplicationSerializer(serializers.Serializer):
instance['tts_type'] = node_data['tts_type']
if 'tts_model_params_setting' in node_data:
instance['tts_model_params_setting'] = node_data['tts_model_params_setting']
if 'file_upload_enable' in node_data:
instance['file_upload_enable'] = node_data['file_upload_enable']
if 'file_upload_setting' in node_data:
instance['file_upload_setting'] = node_data['file_upload_setting']
break
def speech_to_text(self, file, with_valid=True):

View File

@ -222,6 +222,7 @@ class ChatMessageSerializer(serializers.Serializer):
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片仅1张"))
def is_valid_application_workflow(self, *, raise_exception=False):
self.is_valid_intraday_access_num()
@ -299,6 +300,7 @@ class ChatMessageSerializer(serializers.Serializer):
client_id = self.data.get('client_id')
client_type = self.data.get('client_type')
form_data = self.data.get('form_data')
image_list = self.data.get('image_list')
user_id = chat_info.application.user_id
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
{'history_chat_record': chat_info.chat_record_list, 'question': message,
@ -308,7 +310,7 @@ class ChatMessageSerializer(serializers.Serializer):
'client_id': client_id,
'client_type': client_type,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
base_to_response, form_data)
base_to_response, form_data, image_list)
r = work_flow_manage.run()
return r

View File

@ -49,6 +49,7 @@ urlpatterns = [
path('application/<str:application_id>/chat/<int:current_page>/<int:page_size>', views.ChatView.Page.as_view()),
path('application/<str:application_id>/chat/<chat_id>', views.ChatView.Operate.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
path('application/<str:application_id>/chat/<chat_id>/upload_file', views.ChatView.UploadFile.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
views.ChatView.ChatRecord.Page.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<chat_record_id>',

View File

@ -22,6 +22,7 @@ from common.constants.permission_constants import Permission, Group, Operate, \
RoleConstants, ViewPermission, CompareConstants
from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.file_serializers import FileSerializer
class Openai(APIView):
@ -128,6 +129,7 @@ class ChatView(APIView):
'client_id': request.auth.client_id,
'form_data': (request.data.get(
'form_data') if 'form_data' in request.data else {}),
'image_list': request.data.get('image_list') if 'image_list' in request.data else [],
'client_type': request.auth.client_type}).chat()
@action(methods=['GET'], detail=False)
@ -391,3 +393,28 @@ class ChatView(APIView):
data={'chat_id': chat_id, 'chat_record_id': chat_record_id,
'dataset_id': dataset_id, 'document_id': document_id,
'paragraph_id': paragraph_id}).delete())
class UploadFile(APIView):
    """Upload chat attachments; returns their names, urls and file ids."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="上传文件",
                         operation_id="上传文件",
                         manual_parameters=ChatRecordApi.get_request_params_api(),
                         tags=["应用/对话日志"]
                         )
    @has_permissions(
        ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
                        RoleConstants.APPLICATION_ACCESS_TOKEN],
                       [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                       dynamic_tag=keywords.get('application_id'))])
    )
    def post(self, request: Request, application_id: str, chat_id: str):
        """Store each file posted under the 'file' form field.

        :return: list of {'name', 'url', 'file_id'} for the stored files.
        """
        files = request.FILES.getlist('file')
        file_ids = []
        # Tag each stored file with its application/chat so the orphaned-file
        # cleanup job can find and delete it once the chat is gone.
        meta = {'application_id': application_id, 'chat_id': chat_id}
        for file in files:
            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
            # file_id is the trailing path segment of the '/api/file/<id>' url.
            file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]})
        return result.success(file_ids)

View File

@ -6,6 +6,7 @@
@date2024/3/14 11:54
@desc:
"""
from .clean_orphaned_file_job import *
from .client_access_num_job import *
from .clean_chat_job import *
@ -13,3 +14,4 @@ from .clean_chat_job import *
def run():
    """Start every registered background job for this package."""
    client_access_num_job.run()
    clean_chat_job.run()
    clean_orphaned_file_job.run()

View File

@ -0,0 +1,40 @@
# coding=utf-8
import logging
from apscheduler.schedulers.background import BackgroundScheduler
from django.db.models import Q
from django_apscheduler.jobstores import DjangoJobStore
from application.models import Chat
from common.lock.impl.file_lock import FileLock
from dataset.models import File
scheduler = BackgroundScheduler()
scheduler.add_jobstore(DjangoJobStore(), "default")
lock = FileLock()
def clean_debug_file():
    """Delete uploaded File rows whose meta.chat_id no longer matches a Chat.

    Chat ids are UUIDs while File.meta stores them as strings, so the live id
    list is stringified before the JSON-field comparison.
    """
    logging.getLogger("max_kb").info('开始清理没有关联会话的上传文件')
    # Stringify once; the original built a set, re-listed it, and left a stray
    # debugging print() in place — removed here.
    existing_chat_ids = [str(chat_id) for chat_id in Chat.objects.values_list('id', flat=True)]
    # Delete files that reference a chat id no longer present.
    deleted_count, _ = File.objects.filter(~Q(meta__chat_id__in=existing_chat_ids)).delete()
    logging.getLogger("max_kb").info(f'结束清理没有关联会话的上传文件: {deleted_count}')
def run():
    """Start the scheduler and (re)register the daily orphaned-file cleanup.

    Guarded by a file lock so that in multi-process deployments only one
    process registers the job at a time.
    """
    # NOTE(review): lock TTL is 30 * 30 = 900 seconds — confirm intended.
    if lock.try_lock('clean_orphaned_file_job', 30 * 30):
        try:
            scheduler.start()
            # Replace any previously registered job so the schedule stays current.
            clean_orphaned_file = scheduler.get_job(job_id='clean_orphaned_file')
            if clean_orphaned_file is not None:
                clean_orphaned_file.remove()
            # Run the cleanup every day at 02:00:00.
            scheduler.add_job(clean_debug_file, 'cron', hour='2', minute='0', second='0',
                              id='clean_orphaned_file')
        finally:
            lock.un_lock('clean_orphaned_file_job')

View File

@ -0,0 +1,18 @@
# Generated by Django 4.2.15 on 2024-11-07 15:32
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add the meta association field to File."""

    dependencies = [
        ('dataset', '0009_alter_document_status_alter_paragraph_status'),
    ]

    operations = [
        migrations.AddField(
            model_name='file',
            name='meta',
            # default=dict (a callable), not default={}: a literal dict is a
            # single shared mutable object (Django fields.W010) — and the File
            # model itself already declares default=dict for this field.
            field=models.JSONField(default=dict, verbose_name='文件关联数据'),
        ),
    ]

View File

@ -141,6 +141,9 @@ class File(AppModelMixin):
loid = models.IntegerField(verbose_name="loid")
meta = models.JSONField(verbose_name="文件关联数据", default=dict)
class Meta:
db_table = "file"
@ -149,7 +152,6 @@ class File(AppModelMixin):
):
result = select_one("SELECT lo_from_bytea(%s, %s::bytea) as loid", [0, bytea])
self.loid = result['loid']
self.file_name = 'speech.mp3'
super().save()
def get_byte(self):

View File

@ -56,12 +56,13 @@ mime_types = {"html": "text/html", "htm": "text/html", "shtml": "text/html", "cs
class FileSerializer(serializers.Serializer):
    """Validates and persists an uploaded file."""
    # The uploaded binary payload.
    file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件"))
    # Optional association data (e.g. application_id / chat_id) stored on the row.
    meta = serializers.JSONField(required=False)

    def upload(self, with_valid=True):
        """Persist the uploaded file and return its download URL.

        :param with_valid: run serializer validation first.
        :return: relative URL '/api/file/<file_id>' of the stored file.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        file_id = uuid.uuid1()
        # Fall back to {}: .get('meta') returns None when the optional field is
        # absent, which would override the model's default=dict with NULL.
        file = File(id=file_id, file_name=self.data.get('file').name,
                    meta=self.data.get('meta') or {})
        file.save(self.data.get('file').read())
        return f'/api/file/{file_id}'

View File

@ -1,14 +0,0 @@
# coding=utf-8
from abc import abstractmethod
from pydantic import BaseModel
class BaseImage(BaseModel):
@abstractmethod
def check_auth(self):
pass
@abstractmethod
def image_understand(self, image_file, text):
pass

View File

@ -1,6 +1,10 @@
# coding=utf-8
import base64
import os
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
@ -25,7 +29,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
except Exception as e:
if isinstance(e, AppApiException):
raise e

View File

@ -1,12 +1,9 @@
import base64
import os
from typing import Dict
from openai import OpenAI
from langchain_openai.chat_models import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_image import BaseImage
def custom_get_token_ids(text: str):
@ -14,66 +11,15 @@ def custom_get_token_ids(text: str):
return tokenizer.encode(text)
class OpenAIImage(MaxKBBaseModel, BaseImage):
api_base: str
api_key: str
model: str
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
class OpenAIImage(MaxKBBaseModel, ChatOpenAI):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return OpenAIImage(
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
stream_options={"include_usage": True},
**optional_params,
)
def check_auth(self):
client = OpenAI(
base_url=self.api_base,
api_key=self.api_key
)
response_list = client.models.with_raw_response.list()
# print(response_list)
# cwd = os.path.dirname(os.path.abspath(__file__))
# with open(f'{cwd}/img_1.png', 'rb') as f:
# self.image_understand(f, "一句话概述这个图片")
def image_understand(self, image_file, text):
client = OpenAI(
base_url=self.api_base,
api_key=self.api_key
)
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
response = client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": text,
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
},
},
],
}
],
)
return response.choices[0].message.content

View File

@ -0,0 +1,69 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 18:41
@desc:
"""
import base64
import os
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class QwenModelParams(BaseForm):
    """Generation parameters exposed in the Qwen model settings form."""
    # sampling temperature: higher -> more random, lower -> more deterministic
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=1.0,
                                    _min=0.1,
                                    _max=1.9,
                                    _step=0.01,
                                    precision=2)
    # maximum number of tokens the model may generate
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)
class QwenVLModelCredential(BaseForm, BaseModelCredential):
    """Credential form/validator for Qwen-VL vision models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check model-type support and required fields, then probe the model.

        :return: True when valid; False (or raises, per raise_exception) otherwise.
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Probe with a tiny streamed request to verify the credential works.
            model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the api_key masked for storage/display."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_model_params_setting_form(self, model_name):
        """Generation-parameter form shown for this model."""
        return QwenModelParams()

View File

@ -0,0 +1,22 @@
# coding=utf-8
from typing import Dict
from langchain_community.chat_models import ChatOpenAI
from setting.models_provider.base_model_provider import MaxKBBaseModel
class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
    """Qwen-VL vision chat model via DashScope's OpenAI-compatible endpoint."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a QwenVLChatModel from stored credentials and optional params."""
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        # NOTE(review): optional params are passed via model_kwargs= here, while
        # other providers splat them as **optional_params — confirm both forms
        # actually reach the API request.
        chat_tong_yi = QwenVLChatModel(
            model=model_name,
            openai_api_key=model_credential.get('api_key'),
            openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
            stream_options={"include_usage": True},
            model_kwargs=optional_params,
        )
        return chat_tong_yi

View File

@ -11,21 +11,33 @@ import os
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
from smartdoc.conf import PROJECT_DIR
qwen_model_credential = OpenAILLMModelCredential()
qwenvl_model_credential = QwenVLModelCredential()
module_info_list = [
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
ModelInfo('qwen-plus', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
ModelInfo('qwen-max', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)
]
module_info_vl_list = [
ModelInfo('qwen-vl-max', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
]
model_info_manage = ModelInfoManage.builder().append_model_info_list(module_info_list).append_default_model_info(
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel)).build()
model_info_manage = (ModelInfoManage.builder()
.append_model_info_list(module_info_list)
.append_default_model_info(
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel))
.append_model_info_list(module_info_vl_list)
.build())
class QwenModelProvider(IModelProvider):

View File

@ -0,0 +1,69 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 18:41
@desc:
"""
import base64
import os
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class QwenModelParams(BaseForm):
    """Generation parameters exposed in the model settings form.

    NOTE(review): class name says "Qwen" but this lives in the Tencent
    provider — looks like a copy-paste; referenced below, so renaming would
    need both sites updated together.
    """
    # sampling temperature: higher -> more random, lower -> more deterministic
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=1.0,
                                    _min=0.1,
                                    _max=1.9,
                                    _step=0.01,
                                    precision=2)
    # maximum number of tokens the model may generate
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)
class TencentVisionModelCredential(BaseForm, BaseModelCredential):
    """Credential form/validator for Tencent Hunyuan vision models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check model-type support and required fields, then probe the model.

        :return: True when valid; False (or raises, per raise_exception) otherwise.
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Probe with a tiny streamed request to verify the credential works.
            model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the api_key masked for storage/display."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_model_params_setting_form(self, model_name):
        """Generation-parameter form shown for this model."""
        return QwenModelParams()

View File

@ -0,0 +1,25 @@
from typing import Dict
from langchain_openai.chat_models import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
def custom_get_token_ids(text: str):
    """Encode *text* into token ids using the shared tokenizer."""
    return TokenizerManage.get_tokenizer().encode(text)
class TencentVision(MaxKBBaseModel, ChatOpenAI):
    """Tencent Hunyuan vision model via its OpenAI-compatible endpoint."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a TencentVision instance from stored credentials and optional params."""
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return TencentVision(
            model=model_name,
            openai_api_base='https://api.hunyuan.cloud.tencent.com/v1',
            openai_api_key=model_credential.get('api_key'),
            stream_options={"include_usage": True},
            **optional_params,
        )

View File

@ -7,8 +7,10 @@ from setting.models_provider.base_model_provider import (
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
)
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
from setting.models_provider.impl.tencent_model_provider.credential.image import TencentVisionModelCredential
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
from setting.models_provider.impl.tencent_model_provider.model.image import TencentVision
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
from smartdoc.conf import PROJECT_DIR
@ -78,9 +80,18 @@ def _initialize_model_info():
model_info_embedding_list = [tencent_embedding_model_info]
model_info_vision_list = [_create_model_info(
'hunyuan-vision',
'混元视觉模型',
ModelTypeConst.IMAGE,
TencentVisionModelCredential,
TencentVision)]
model_info_manage = ModelInfoManage.builder() \
.append_model_info_list(model_info_list) \
.append_model_info_list(model_info_embedding_list) \
.append_model_info_list(model_info_vision_list) \
.append_default_model_info(model_info_list[0]) \
.build()

View File

@ -1,11 +1,15 @@
# coding=utf-8
import base64
import os
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.xf_model_provider.model.image import ImageMessage
class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
@ -28,7 +32,10 @@ class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
cwd = os.path.dirname(os.path.abspath(__file__))
with open(f'{cwd}/img_1.png', 'rb') as f:
message_list = [ImageMessage(str(base64.b64encode(f.read()), 'utf-8')), HumanMessage('请概述这张图片')]
model.stream(message_list)
except Exception as e:
if isinstance(e, AppApiException):
raise e

View File

@ -1,50 +1,58 @@
# coding=utf-8
import asyncio
import base64
import datetime
import hashlib
import hmac
import json
import os
import ssl
from datetime import datetime, UTC
from typing import Dict
from urllib.parse import urlencode
from urllib.parse import urlparse
from typing import Dict, Any, List, Optional, Iterator
import websockets
from docutils.utils import SystemMessage
from langchain_community.chat_models.sparkllm import ChatSparkLLM, _convert_delta_to_message_chunk
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, ChatMessage, HumanMessage, AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_image import BaseImage
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
class XFSparkImage(MaxKBBaseModel, BaseImage):
class ImageMessage(HumanMessage):
content: str
def convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, ImageMessage):
message_dict = {"role": "user", "content": message.content, "content_type": "image"}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
class XFSparkImage(MaxKBBaseModel, ChatSparkLLM):
spark_app_id: str
spark_api_key: str
spark_api_secret: str
spark_api_url: str
params: dict
# 初始化
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.spark_api_url = kwargs.get('spark_api_url')
self.spark_app_id = kwargs.get('spark_app_id')
self.spark_api_key = kwargs.get('spark_api_key')
self.spark_api_secret = kwargs.get('spark_api_secret')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return XFSparkImage(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
@ -53,118 +61,36 @@ class XFSparkImage(MaxKBBaseModel, BaseImage):
**optional_params
)
def create_url(self):
url = self.spark_api_url
host = urlparse(url).hostname
# 生成RFC1123格式的时间戳
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
date = datetime.now(UTC).strftime(gmt_format)
# 拼接字符串
signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2.1/image " + "HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
# 将请求的鉴权参数组合为字典
v = {
"authorization": authorization,
"date": date,
"host": host
}
# 拼接鉴权参数生成url
url = url + '?' + urlencode(v)
# print("date: ",date)
# print("v: ",v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
# print('websocket url :', url)
return url
def check_auth(self):
cwd = os.path.dirname(os.path.abspath(__file__))
with open(f'{cwd}/img_1.png', 'rb') as f:
self.image_understand(f,"一句话概述这个图片")
def image_understand(self, image_file, question):
async def handle():
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
# 发送 full client request
await self.send(ws, image_file, question)
return await self.handle_message(ws)
return asyncio.run(handle())
# 收到websocket消息的处理
@staticmethod
async def handle_message(ws):
# print(message)
answer = ''
while True:
res = await ws.recv()
data = json.loads(res)
code = data['header']['code']
if code != 0:
return f'请求错误: {code}, {data}'
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
# print(content, end="")
answer += content
# print(1)
if status == 2:
break
return answer
def generate_message(prompt: str, image) -> list[BaseMessage]:
if image is None:
cwd = os.path.dirname(os.path.abspath(__file__))
with open(f'{cwd}/img_1.png', 'rb') as f:
base64_image = base64.b64encode(f.read()).decode("utf-8")
return [ImageMessage(f'data:image/jpeg;base64,{base64_image}'), HumanMessage(prompt)]
return [HumanMessage(prompt)]
async def send(self, ws, image_file, question):
text = [
{"role": "user", "content": str(base64.b64encode(image_file.read()), 'utf-8'), "content_type": "image"},
{"role": "user", "content": question}
]
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
default_chunk_class = AIMessageChunk
data = {
"header": {
"app_id": self.spark_app_id
},
"parameter": {
"chat": {
"domain": "image",
"temperature": 0.5,
"top_k": 4,
"max_tokens": 2028,
"auditing": "default"
}
},
"payload": {
"message": {
"text": text
}
}
}
d = json.dumps(data)
await ws.send(d)
def is_cache_model(self):
return False
@staticmethod
def get_len(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def check_len(self, text):
print("text-content-tokens:", self.get_len(text[1:]))
while (self.get_len(text[1:]) > 8000):
del text[1]
return text
self.client.arun(
[convert_message_to_dict(m) for m in messages],
self.spark_user_id,
self.model_kwargs,
streaming=True,
)
for content in self.client.subscribe(timeout=self.request_timeout):
if "data" not in content:
continue
delta = content["data"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk

View File

@ -37,7 +37,6 @@ model_info_list = [
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
ModelInfo('image', '', ModelTypeConst.IMAGE, image_model_credential, XFSparkImage),
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
]

View File

@ -286,6 +286,14 @@ const getApplicationTTSModel: (
return get(`${prefix}/${application_id}/model`, { model_type: 'TTS' }, loading)
}
const getApplicationImageModel: (
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/model`, { model_type: 'IMAGE' }, loading)
}
/**
*
* @param
@ -350,6 +358,19 @@ const getModelParamsForm: (
return get(`${prefix}/${application_id}/model_params_form/${model_id}`, undefined, loading)
}
/**
*
*/
const uploadFile: (
application_id: String,
chat_id: String,
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (application_id, chat_id, data, loading) => {
return post(`${prefix}/${application_id}/chat/${chat_id}/upload_file`, data, undefined, loading)
}
/**
*
*/
@ -501,6 +522,7 @@ export default {
getApplicationRerankerModel,
getApplicationSTTModel,
getApplicationTTSModel,
getApplicationImageModel,
postSpeechToText,
postTextToSpeech,
getPlatformStatus,
@ -513,5 +535,6 @@ export default {
putWorkFlowVersion,
playDemoText,
getUserList,
getApplicationList
getApplicationList,
uploadFile
}

View File

@ -54,6 +54,18 @@
<div v-for="(f, i) in item.global_fields" :key="i">
{{ f.label }}: {{ f.value }}
</div>
<div v-if="item.document_list?.length > 0">
上传的文档:
<div v-for="(f, i) in item.document_list" :key="i">
{{ f.name }}
</div>
</div>
<div v-if="item.image_list?.length > 0">
上传的图片:
<div v-for="(f, i) in item.image_list" :key="i">
{{ f.name }}
</div>
</div>
</div>
</div>
</template>
@ -96,7 +108,8 @@
v-if="
item.type == WorkflowType.AiChat ||
item.type == WorkflowType.Question ||
item.type == WorkflowType.Application
item.type == WorkflowType.Application ||
item.type == WorkflowType.ImageUnderstandNode
"
>
<div

View File

@ -192,6 +192,19 @@
/>
<div class="operate flex align-center">
<span v-if="props.data.file_upload_enable" class="flex align-center">
<el-upload
action="#"
:auto-upload="false"
:show-file-list="false"
:on-change="(file: any, fileList: any) => uploadFile(file, fileList)"
>
<el-button text>
<el-icon><Paperclip /></el-icon>
</el-button>
</el-upload>
<el-divider direction="vertical" />
</span>
<span v-if="props.data.stt_model_enable" class="flex align-center">
<el-button text v-if="mediaRecorderStatus" @click="startRecording">
<el-icon>
@ -790,7 +803,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
is_stop: false,
record_id: '',
vote_status: '-1',
status: undefined
status: undefined,
})
chatList.value.push(chat)
ChatManagement.addChatRecord(chat, 50, loading)
@ -809,7 +822,8 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
const obj = {
message: chat.problem_text,
re_chat: re_chat || false,
form_data: { ...form_data.value, ...api_form_data.value }
form_data: { ...form_data.value, ...api_form_data.value },
image_list: uploadFileList.value,
}
//
applicationApi
@ -832,6 +846,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
nextTick(() => {
//
scrollDiv.value.setScrollTop(getMaxHeight())
uploadFileList.value = []
})
const reader = response.body.getReader()
//
@ -916,6 +931,42 @@ const handleScroll = () => {
}
}
//
// Files staged for the next question; sent with the message and cleared afterwards.
const uploadFileList = ref<any>([])
// Validate the selected files against the node's upload settings, then upload them
// under the current chat session (opening a new session first if necessary).
// Fixes: removed a stray `console.log`; the loop variable no longer shadows the
// `file` parameter (nor the inner filter callback's name).
const uploadFile = async (file: any, fileList: any) => {
  const { maxFiles, fileLimit } = props.data.file_upload_setting
  if (fileList.length > maxFiles) {
    MsgWarning('最多上传' + maxFiles + '个文件')
    return
  }
  if (fileList.filter((f: any) => f.size > fileLimit * 1024 * 1024).length > 0) { // MB
    MsgWarning('单个文件大小不能超过' + fileLimit + 'MB')
    fileList.splice(0, fileList.length)
    return
  }
  const formData = new FormData()
  for (const item of fileList) {
    formData.append('file', item.raw, item.name)
    uploadFileList.value.push(item)
  }
  if (props.chatId === 'new' || !chartOpenId.value) {
    const res = await applicationApi.getChatOpen(props.data.id as string)
    chartOpenId.value = res.data
  }
  applicationApi.uploadFile(props.data.id as string, chartOpenId.value, formData, loading).then((response) => {
    fileList.splice(0, fileList.length)
    uploadFileList.value.forEach((staged: any) => {
      // Attach the server-side url/file_id to the staged entry, matched by filename.
      const matched = response.data.filter((r: any) => r.name === staged.name)
      if (matched.length > 0) {
        staged.url = matched[0].url
        staged.file_id = matched[0].file_id
      }
    })
  })
}
//
const mediaRecorder = ref<any>(null)

View File

@ -9,5 +9,7 @@ export enum WorkflowType {
FunctionLib = 'function-lib-node',
FunctionLibCustom = 'function-node',
RrerankerNode = 'reranker-node',
Application = 'application-node'
Application = 'application-node',
DocumentExtractNode = 'document-extract-node',
ImageUnderstandNode = 'image-understand-node',
}

View File

@ -168,13 +168,49 @@ export const rerankerNode = {
}
}
}
// Palette/default definition for the document-content-extract workflow node.
// `config.fields` declares the outputs that downstream nodes can reference.
export const documentExtractNode = {
type: WorkflowType.DocumentExtractNode,
text: '提取文档中的内容',
label: '文档内容提取',
height: 252,
properties: {
stepName: '文档内容提取',
config: {
fields: [
{
label: '文件内容',
value: 'content'
},
]
}
}
}
// Palette/default definition for the image-understand workflow node.
// `config.fields` declares the outputs that downstream nodes can reference.
export const imageUnderstandNode = {
type: WorkflowType.ImageUnderstandNode,
text: '识别出图片中的对象、场景等信息回答用户问题',
label: '图片理解',
height: 252,
properties: {
stepName: '图片理解',
config: {
fields: [
{
label: 'AI 回答内容',
value: 'content'
},
]
}
}
}
export const menuNodes = [
aiChatNode,
searchDatasetNode,
questionNode,
conditionNode,
replyNode,
rerankerNode
rerankerNode,
documentExtractNode,
imageUnderstandNode
]
/**
@ -261,7 +297,9 @@ export const nodeDict: any = {
[WorkflowType.FunctionLib]: functionLibNode,
[WorkflowType.FunctionLibCustom]: functionNode,
[WorkflowType.RrerankerNode]: rerankerNode,
[WorkflowType.Application]: applicationNode
[WorkflowType.Application]: applicationNode,
[WorkflowType.DocumentExtractNode]: documentExtractNode,
[WorkflowType.ImageUnderstandNode]: imageUnderstandNode,
}
export function isWorkFlow(type: string | undefined) {
return type === 'WORK_FLOW'

View File

@ -0,0 +1,6 @@
<!-- Palette icon for the document-content-extract workflow node (purple square). -->
<template>
<AppAvatar shape="square" style="background: #7F3BF5">
<img src="@/assets/icon_document.svg" style="width: 65%" alt="" />
</AppAvatar>
</template>
<script setup lang="ts"></script>

View File

@ -0,0 +1,6 @@
<!-- Palette icon (blue square).
     NOTE(review): this reuses icon_document.svg — confirm whether a dedicated
     image icon was intended for this node. -->
<template>
<AppAvatar shape="square" style="background: #14C0FF;">
<img src="@/assets/icon_document.svg" style="width: 65%" alt="" />
</AppAvatar>
</template>
<script setup lang="ts"></script>

View File

@ -0,0 +1,101 @@
<!-- Modal for configuring per-question file-upload limits and accepted file types. -->
<template>
<el-dialog
title="文件上传设置"
v-model="dialogVisible"
:close-on-click-modal="false"
:close-on-press-escape="false"
:destroy-on-close="true"
:before-close="close"
append-to-body
>
<el-form
label-position="top"
ref="fieldFormRef"
:model="form_data"
require-asterisk-position="right">
<el-form-item label="单次上传最多文件数">
<el-slider v-model="form_data.maxFiles" show-input :show-input-controls="false" :min="1" :max="10" />
</el-form-item>
<el-form-item label="每个文件最大MB">
<el-slider v-model="form_data.fileLimit" show-input :show-input-controls="false" :min="1" :max="100" />
</el-form-item>
<el-form-item label="上传的文件类型">
<el-card style="width: 100%" class="mb-8">
<div class="flex-between">
<p>
文档TXTMDDOCXHTMLCSVXLSXXLSPDF
需要与文档内容提取节点配合使用
</p>
<el-checkbox v-model="form_data.document" />
</div>
</el-card>
<el-card style="width: 100%" class="mb-8">
<div class="flex-between">
<p>
图片JPGJPEGPNGGIF
所选模型需要支持接收图片
</p>
<el-checkbox v-model="form_data.image" />
</div>
</el-card>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click.prevent="close"> 取消 </el-button>
<el-button type="primary" @click="submit()" :loading="loading">
确定
</el-button>
</span>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { nextTick, ref } from 'vue'
const emit = defineEmits(['refresh'])
const props = defineProps<{ nodeModel: any }>()
const dialogVisible = ref(false)
const loading = ref(false)
const fieldFormRef = ref()
// Default settings. `audio`/`video` exist in the model but have no checkbox in
// the template above — presumably reserved for later; confirm before relying on them.
const form_data = ref({
maxFiles: 3,
fileLimit: 50,
document: true,
image: false,
audio: false,
video: false
})
// Open the dialog; any settings passed in are merged over the defaults.
function open(data: any) {
dialogVisible.value = true
nextTick(() => {
form_data.value = { ...form_data.value, ...data }
})
}
function close() {
dialogVisible.value = false
}
// Validate, hand the settings back to the parent via the `refresh` event, and
// tell the graph to resync upload-related output fields ('refreshFileUploadConfig').
async function submit() {
const formEl = fieldFormRef.value
if (!formEl) return
await formEl.validate().then(() => {
emit('refresh', form_data.value)
props.nodeModel.graphModel.eventCenter.emit('refreshFileUploadConfig')
dialogVisible.value = false
})
}
defineExpose({
open
})
</script>
<style scoped lang="scss">
</style>

View File

@ -45,6 +45,36 @@
@submitDialog="submitDialog"
/>
</el-form-item>
<el-form-item >
<template #label>
<div class="flex-between">
<div class="flex align-center">
<span class="mr-4">文件上传</span>
<el-tooltip
effect="dark"
content="开启后,问答页面会显示上传文件的按钮。"
placement="right"
>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
<div>
<el-button
v-if="form_data.file_upload_enable"
type="primary"
link
@click="openFileUploadSettingDialog"
class="mr-4"
>
<el-icon class="mr-4">
<Setting />
</el-icon>
</el-button>
<el-switch size="small" v-model="form_data.file_upload_enable"/>
</div>
</div>
</template>
</el-form-item>
<UserInputFieldTable ref="UserInputFieldTableFef" :node-model="nodeModel" />
<ApiInputFieldTable ref="ApiInputFieldTableFef" :node-model="nodeModel" />
<el-form-item>
@ -139,7 +169,6 @@
<el-icon class="mr-4">
<Setting />
</el-icon>
设置
</el-button>
<el-switch size="small" v-model="form_data.tts_model_enable" @change="ttsModelEnableChange"/>
</div>
@ -212,6 +241,7 @@
</el-form-item>
</el-form>
<TTSModeParamSettingDialog ref="TTSModeParamSettingDialogRef" @refresh="refreshTTSForm" />
<FileUploadSettingDialog ref="FileUploadSettingDialogRef" :node-model="nodeModel" @refresh="refreshFileUploadForm"/>
</NodeContainer>
</template>
<script setup lang="ts">
@ -229,6 +259,7 @@ import { t } from '@/locales'
import TTSModeParamSettingDialog from '@/views/application/component/TTSModeParamSettingDialog.vue'
import ApiInputFieldTable from './component/ApiInputFieldTable.vue'
import UserInputFieldTable from './component/UserInputFieldTable.vue'
import FileUploadSettingDialog from '@/workflow/nodes/base-node/component/FileUploadSettingDialog.vue'
const { model } = useStore()
@ -244,6 +275,7 @@ const providerOptions = ref<Array<Provider>>([])
const TTSModeParamSettingDialogRef = ref<InstanceType<typeof TTSModeParamSettingDialog>>()
const UserInputFieldTableFef = ref()
const ApiInputFieldTableFef = ref()
const FileUploadSettingDialogRef = ref<InstanceType<typeof FileUploadSettingDialog>>()
const form = {
name: '',
@ -350,6 +382,14 @@ const refreshTTSForm = (data: any) => {
form_data.value.tts_model_params_setting = data
}
const openFileUploadSettingDialog = () => {
FileUploadSettingDialogRef.value?.open(form_data.value.file_upload_setting)
}
const refreshFileUploadForm = (data: any) => {
form_data.value.file_upload_setting = data
}
onMounted(() => {
set(props.nodeModel, 'validate', validate)
if (!props.nodeModel.properties.node_data.tts_type) {

View File

@ -0,0 +1,12 @@
import DocumentExtractNodeVue from './index.vue'
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'

/**
 * LogicFlow view class for the document-content-extract workflow node.
 * Renamed from the copy-pasted `RerankerNode` to match what it actually renders;
 * the class is file-local, so the exported registration object is unchanged in shape.
 */
class DocumentExtractNode extends AppNode {
  constructor(props: any) {
    super(props, DocumentExtractNodeVue)
  }
}
export default {
  type: 'document-extract-node',
  model: AppNodeModel,
  view: DocumentExtractNode
}

View File

@ -0,0 +1,64 @@
<template>
  <NodeContainer :nodeModel="nodeModel">
    <h5 class="title-decoration-1 mb-8">节点设置</h5>
    <el-card shadow="never" class="card-never">
      <el-form
        @submit.prevent
        :model="form_data"
        label-position="top"
        require-asterisk-position="right"
        label-width="auto"
        ref="DatasetNodeFormRef"
      >
        <!-- `prop` is required for the inline rules to participate in el-form validation -->
        <el-form-item
          label="选择文件"
          prop="file_list"
          :rules="{
            type: 'array',
            required: true,
            message: '请选择文件',
            trigger: 'change'
          }"
        >
          <!-- Fix: bind the reactive node_data proxy (form_data), not the static
               defaults object `form`, so edits persist on the node. -->
          <NodeCascader
            ref="nodeCascaderRef"
            :nodeModel="nodeModel"
            class="w-full"
            placeholder="请选择文件"
            v-model="form_data.file_list"
          />
        </el-form-item>
      </el-form>
    </el-card>
  </NodeContainer>
</template>
<script setup lang="ts">
import NodeContainer from '@/workflow/common/NodeContainer.vue'
import { computed } from 'vue'
import { set } from 'lodash'
import NodeCascader from '@/workflow/common/NodeCascader.vue'

const props = defineProps<{ nodeModel: any }>()

// Default node_data for a freshly dropped node.
const form = {
  file_list: []
}

// Proxies the LogicFlow node's node_data, seeding it with the defaults on first access.
const form_data = computed({
  get: () => {
    if (props.nodeModel.properties.node_data) {
      return props.nodeModel.properties.node_data
    } else {
      set(props.nodeModel.properties, 'node_data', form)
    }
    return props.nodeModel.properties.node_data
  },
  set: (value) => {
    set(props.nodeModel.properties, 'node_data', value)
  }
})
</script>
<style lang="scss" scoped>
</style>

View File

@ -0,0 +1,14 @@
import ImageUnderstandNodeVue from './index.vue'
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'

/**
 * LogicFlow view class for the image-understand workflow node.
 * Renamed from the copy-pasted `RerankerNode` to match what it actually renders;
 * the class is file-local, so the exported registration object is unchanged in shape.
 */
class ImageUnderstandNode extends AppNode {
  constructor(props: any) {
    super(props, ImageUnderstandNodeVue)
  }
}
export default {
  type: 'image-understand-node',
  model: AppNodeModel,
  view: ImageUnderstandNode
}

View File

@ -0,0 +1,277 @@
<!-- Image-understand workflow node settings: pick a vision (IMAGE-type) model,
     set role/prompt, history depth, and which upstream variable supplies the image. -->
<template>
<NodeContainer :node-model="nodeModel">
<h5 class="title-decoration-1 mb-8">节点设置</h5>
<el-card shadow="never" class="card-never">
<el-form
@submit.prevent
:model="form_data"
label-position="top"
require-asterisk-position="right"
class="mb-24"
label-width="auto"
ref="aiChatNodeFormRef"
hide-required-asterisk
>
<el-form-item
label="图片理解模型"
prop="model_id"
:rules="{
required: true,
message: '请选择图片理解模型',
trigger: 'change'
}"
>
<template #label>
<div class="flex-between w-full">
<div>
<span>图片理解模型<span class="danger">*</span></span>
</div>
</div>
</template>
<el-select
@change="model_change"
@wheel="wheel"
:teleported="false"
v-model="form_data.model_id"
placeholder="请选择图片理解模型"
class="w-full"
popper-class="select-model"
:clearable="true"
>
<!-- NOTE(review): :key="value" keys the group by the model ARRAY itself;
     keying by `label` (the provider id) looks intended — confirm. -->
<el-option-group
v-for="(value, label) in modelOptions"
:key="value"
:label="relatedObject(providerOptions, label, 'provider')?.name"
>
<el-option
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
>
<div class="flex align-center">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
<el-tag v-if="item.permission_type === 'PUBLIC'" type="info" class="info-tag ml-8"
>公用
</el-tag>
</div>
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
<Check />
</el-icon>
</el-option>
<!-- Unavailable models, rendered disabled -->
<el-option
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
disabled
>
<div class="flex">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
<span class="danger">不可用</span>
</div>
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
<Check />
</el-icon>
</el-option>
</el-option-group>
</el-select>
</el-form-item>
<el-form-item label="角色设定">
<MdEditorMagnify
title="角色设定"
v-model="form_data.system"
style="height: 100px"
@submitDialog="submitSystemDialog"
placeholder="角色设定"
/>
</el-form-item>
<el-form-item
label="提示词"
prop="prompt"
:rules="{
required: true,
message: '请输入提示词',
trigger: 'blur'
}"
>
<template #label>
<div class="flex align-center">
<div class="mr-4">
<span>提示词<span class="danger">*</span></span>
</div>
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
<template #content
>通过调整提示词内容可以引导大模型聊天方向该提示词会被固定在上下文的开头可以使用变量
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<MdEditorMagnify
@wheel="wheel"
title="提示词"
v-model="form_data.prompt"
style="height: 150px"
@submitDialog="submitDialog"
/>
</el-form-item>
<el-form-item label="历史聊天记录">
<el-input-number
v-model="form_data.dialogue_number"
:min="0"
:value-on-clear="0"
controls-position="right"
class="w-full"
/>
</el-form-item>
<!-- NOTE(review): inline :rules without a `prop` attribute will not be picked
     up by el-form validation — confirm `prop="image_list"` was intended. -->
<el-form-item label="选择图片" :rules="{
type: 'array',
required: true,
message: '请选择图片',
trigger: 'change'
}"
>
<NodeCascader
ref="nodeCascaderRef"
:nodeModel="nodeModel"
class="w-full"
placeholder="请选择图片"
v-model="form_data.image_list"
/>
</el-form-item>
<el-form-item label="返回内容" @click.prevent>
<template #label>
<div class="flex align-center">
<div class="mr-4">
<span>返回内容<span class="danger">*</span></span>
</div>
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
<template #content>
关闭后该节点的内容则不输出给用户
如果你想让用户看到该节点的输出内容请打开开关
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-switch size="small" v-model="form_data.is_result" />
</el-form-item>
</el-form>
</el-card>
</NodeContainer>
</template>
<script setup lang="ts">
import NodeContainer from '@/workflow/common/NodeContainer.vue'
import { computed, onMounted, ref } from 'vue'
import { groupBy, set } from 'lodash'
import { relatedObject } from '@/utils/utils'
import type { Provider } from '@/api/type/model'
import applicationApi from '@/api/application'
import { app } from '@/main'
import useStore from '@/stores'
import NodeCascader from '@/workflow/common/NodeCascader.vue'
const { model } = useStore()
// Application id from the current route (absent when editing outside an application).
const {
params: { id }
} = app.config.globalProperties.$route as any
const props = defineProps<{ nodeModel: any }>()
const modelOptions = ref<any>(null)
const providerOptions = ref<Array<Provider>>([])
// With Ctrl held, let the browser/canvas handle the wheel (preventDefault);
// otherwise stop propagation so only this panel scrolls.
const wheel = (e: any) => {
if (e.ctrlKey === true) {
e.preventDefault()
return true
} else {
e.stopPropagation()
return true
}
}
const defaultPrompt = `{{开始.question}}`
// Default node_data for a freshly dropped node; `image_list` presumably holds a
// [nodeId, field] reference to the start node's uploaded images — confirm.
const form = {
model_id: '',
system: '',
prompt: defaultPrompt,
dialogue_number: 0,
is_result: true,
temperature: null,
max_tokens: null,
image_list: ["start-node", "image"]
}
// Proxies the LogicFlow node's node_data, seeding it with the defaults on first access.
const form_data = computed({
get: () => {
if (props.nodeModel.properties.node_data) {
return props.nodeModel.properties.node_data
} else {
set(props.nodeModel.properties, 'node_data', form)
}
return props.nodeModel.properties.node_data
},
set: (value) => {
set(props.nodeModel.properties, 'node_data', value)
}
})
// Fetch IMAGE-type models: scoped to the application when a route id exists,
// otherwise the user's full model list.
function getModel() {
if (id) {
applicationApi.getApplicationImageModel(id).then((res: any) => {
modelOptions.value = groupBy(res?.data, 'provider')
})
} else {
model.asyncGetModel().then((res: any) => {
modelOptions.value = groupBy(res?.data, 'provider')
})
}
}
function getProvider() {
model.asyncGetProvider().then((res: any) => {
providerOptions.value = res?.data
})
}
// NOTE(review): intentionally empty? Nothing is re-synced when the model changes.
const model_change = (model_id?: string) => {
}
function submitSystemDialog(val: string) {
set(props.nodeModel.properties.node_data, 'system', val)
}
function submitDialog(val: string) {
set(props.nodeModel.properties.node_data, 'prompt', val)
}
onMounted(() => {
getModel()
getProvider()
})
</script>
<style scoped lang="scss">
</style>

View File

@ -61,8 +61,31 @@ const refreshFieldList = () => {
}
props.nodeModel.graphModel.eventCenter.on('refreshFieldList', refreshFieldList)
// Resync this node's output fields with the base node's file-upload settings:
// drop any stale 'document'/'image' fields, then re-append the ones enabled.
const refreshFileUploadConfig = () => {
  const uploadSettings = props.nodeModel.graphModel.nodes
    .filter((node: any) => node.id === 'base-node')
    .map((node: any) => cloneDeep(node.properties.node_data.file_upload_setting))
    .filter((setting: any) => setting)
  if (!uploadSettings.length) {
    return
  }
  const setting = uploadSettings[0]
  const baseFields = cloneDeep(props.nodeModel.properties.config.fields).filter(
    (field: any) => field.value !== 'image' && field.value !== 'document'
  )
  const uploadFields = []
  if (setting.document) {
    uploadFields.push({ label: '文档', value: 'document' })
  }
  if (setting.image) {
    uploadFields.push({ label: '图片', value: 'image' })
  }
  set(props.nodeModel.properties.config, 'fields', [...baseFields, ...uploadFields])
}
props.nodeModel.graphModel.eventCenter.on('refreshFileUploadConfig', refreshFileUploadConfig)
onMounted(() => {
refreshFieldList()
refreshFileUploadConfig()
})
</script>
<style lang="scss" scoped></style>