diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py
new file mode 100644
index 000000000..a35bdc39c
--- /dev/null
+++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py
@@ -0,0 +1,157 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: I_base_chat_pipeline.py
+ @date:2024/1/9 17:25
+ @desc:
+"""
+import time
+from abc import abstractmethod
+from typing import Type
+
+from rest_framework import serializers
+
+from dataset.models import Paragraph
+
+
+class ParagraphPipelineModel:
+
+ def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
+ is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
+ hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
+ self.id = _id
+ self.document_id = document_id
+ self.dataset_id = dataset_id
+ self.content = content
+ self.title = title
+        self.status = status
+ self.is_active = is_active
+ self.comprehensive_score = comprehensive_score
+ self.similarity = similarity
+ self.dataset_name = dataset_name
+ self.document_name = document_name
+ self.hit_handling_method = hit_handling_method
+ self.directly_return_similarity = directly_return_similarity
+ self.meta = meta
+
+ def to_dict(self):
+ return {
+ 'id': self.id,
+ 'document_id': self.document_id,
+ 'dataset_id': self.dataset_id,
+ 'content': self.content,
+ 'title': self.title,
+ 'status': self.status,
+ 'is_active': self.is_active,
+ 'comprehensive_score': self.comprehensive_score,
+ 'similarity': self.similarity,
+ 'dataset_name': self.dataset_name,
+ 'document_name': self.document_name,
+ 'meta': self.meta,
+ }
+
+ class builder:
+ def __init__(self):
+ self.similarity = None
+ self.paragraph = {}
+ self.comprehensive_score = None
+ self.document_name = None
+ self.dataset_name = None
+ self.hit_handling_method = None
+ self.directly_return_similarity = 0.9
+ self.meta = {}
+
+ def add_paragraph(self, paragraph):
+ if isinstance(paragraph, Paragraph):
+ self.paragraph = {'id': paragraph.id,
+ 'document_id': paragraph.document_id,
+ 'dataset_id': paragraph.dataset_id,
+ 'content': paragraph.content,
+ 'title': paragraph.title,
+ 'status': paragraph.status,
+ 'is_active': paragraph.is_active,
+ }
+ else:
+ self.paragraph = paragraph
+ return self
+
+ def add_dataset_name(self, dataset_name):
+ self.dataset_name = dataset_name
+ return self
+
+ def add_document_name(self, document_name):
+ self.document_name = document_name
+ return self
+
+ def add_hit_handling_method(self, hit_handling_method):
+ self.hit_handling_method = hit_handling_method
+ return self
+
+ def add_directly_return_similarity(self, directly_return_similarity):
+ self.directly_return_similarity = directly_return_similarity
+ return self
+
+ def add_comprehensive_score(self, comprehensive_score: float):
+ self.comprehensive_score = comprehensive_score
+ return self
+
+ def add_similarity(self, similarity: float):
+ self.similarity = similarity
+ return self
+
+ def add_meta(self, meta: dict):
+ self.meta = meta
+ return self
+
+ def build(self):
+ return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
+ str(self.paragraph.get('dataset_id')),
+ self.paragraph.get('content'), self.paragraph.get('title'),
+ self.paragraph.get('status'),
+ self.paragraph.get('is_active'),
+ self.comprehensive_score, self.similarity, self.dataset_name,
+ self.document_name, self.hit_handling_method, self.directly_return_similarity,
+ self.meta)
+
+
+class IBaseChatPipelineStep:
+ def __init__(self):
+        # Context of the current step, used to store this step's information
+ self.context = {}
+
+ @abstractmethod
+ def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
+ pass
+
+ def valid_args(self, manage):
+ step_serializer_clazz = self.get_step_serializer(manage)
+ step_serializer = step_serializer_clazz(data=manage.context)
+ step_serializer.is_valid(raise_exception=True)
+ self.context['step_args'] = step_serializer.data
+
+ def run(self, manage):
+ """
+
+ :param manage: 步骤管理器
+ :return: 执行结果
+ """
+ start_time = time.time()
+ self.context['start_time'] = start_time
+        # Validate parameters
+ self.valid_args(manage)
+ self._run(manage)
+ self.context['run_time'] = time.time() - start_time
+
+ def _run(self, manage):
+ pass
+
+ def execute(self, **kwargs):
+ pass
+
+ def get_details(self, manage, **kwargs):
+ """
+ 运行详情
+ :return: 步骤详情
+ """
+ return None
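A minimal sketch of how the builder above is meant to be chained; all values here are hypothetical stand-ins (a real `paragraph` dict or `Paragraph` record would come from a dataset query):

```python
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel

# Hypothetical paragraph data for illustration only.
paragraph = {'id': 'p1', 'document_id': 'doc1', 'dataset_id': 'ds1',
             'content': 'MaxKB supports ...', 'title': 'Intro',
             'status': 'ok', 'is_active': True}
model = (ParagraphPipelineModel.builder()
         .add_paragraph(paragraph)
         .add_dataset_name('product-docs')
         .add_document_name('intro.md')
         .add_hit_handling_method('optimization')
         .add_similarity(0.82)
         .add_comprehensive_score(0.82)
         .add_meta({})
         .build())
assert model.to_dict()['similarity'] == 0.82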
diff --git a/apps/application/chat_pipeline/__init__.py b/apps/application/chat_pipeline/__init__.py
new file mode 100644
index 000000000..719a7e29c
--- /dev/null
+++ b/apps/application/chat_pipeline/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/1/9 17:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/pipeline_manage.py b/apps/application/chat_pipeline/pipeline_manage.py
new file mode 100644
index 000000000..7c4acb3a3
--- /dev/null
+++ b/apps/application/chat_pipeline/pipeline_manage.py
@@ -0,0 +1,57 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: pipeline_manage.py
+ @date:2024/1/9 17:40
+ @desc:
+"""
+import time
+from functools import reduce
+from typing import List, Type, Dict
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from common.handle.base_to_response import BaseToResponse
+from common.handle.impl.response.system_to_response import SystemToResponse
+
+
+class PipelineManage:
+ def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
+ base_to_response: BaseToResponse = SystemToResponse()):
+        # Step executors
+ self.step_list = [step() for step in step_list]
+        # Context
+ self.context = {'message_tokens': 0, 'answer_tokens': 0}
+ self.base_to_response = base_to_response
+
+ def run(self, context: Dict = None):
+ self.context['start_time'] = time.time()
+ if context is not None:
+ for key, value in context.items():
+ self.context[key] = value
+ for step in self.step_list:
+ step.run(self)
+
+ def get_details(self):
+ return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in
+ filter(lambda r: r is not None,
+ [row.get_details(self) for row in self.step_list])], {})
+
+ def get_base_to_response(self):
+ return self.base_to_response
+
+ class builder:
+ def __init__(self):
+ self.step_list: List[Type[IBaseChatPipelineStep]] = []
+ self.base_to_response = SystemToResponse()
+
+ def append_step(self, step: Type[IBaseChatPipelineStep]):
+ self.step_list.append(step)
+ return self
+
+ def add_base_to_response(self, base_to_response: BaseToResponse):
+ self.base_to_response = base_to_response
+ return self
+
+ def build(self):
+ return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
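Putting the pieces together: a pipeline is assembled from step classes (not instances) and run against a shared context dict. A sketch using the step implementations added later in this diff; the context keys are examples, not a complete list:

```python
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
    BaseGenerateHumanMessageStep
from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep

pipeline = (PipelineManage.builder()
            .append_step(BaseResetProblemStep)
            .append_step(BaseSearchDatasetStep)
            .append_step(BaseGenerateHumanMessageStep)
            .append_step(BaseChatStep)
            .build())
# run() merges the caller's context into the shared dict, then each step
# validates its own arguments from that context and writes results back:
# pipeline.run({'problem_text': '...'})
# details = pipeline.get_details()
```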
diff --git a/apps/application/chat_pipeline/step/__init__.py b/apps/application/chat_pipeline/step/__init__.py
new file mode 100644
index 000000000..5d9549cdc
--- /dev/null
+++ b/apps/application/chat_pipeline/step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/1/9 18:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/chat_step/__init__.py b/apps/application/chat_pipeline/step/chat_step/__init__.py
new file mode 100644
index 000000000..5d9549cdc
--- /dev/null
+++ b/apps/application/chat_pipeline/step/chat_step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/1/9 18:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
new file mode 100644
index 000000000..2673c6b7b
--- /dev/null
+++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -0,0 +1,110 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_chat_step.py
+ @date:2024/1/9 18:17
+ @desc: Chat
+"""
+from abc import abstractmethod
+from typing import Type, List
+
+from django.utils.translation import gettext_lazy as _
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import BaseMessage
+from rest_framework import serializers
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
+from application.chat_pipeline.pipeline_manage import PipelineManage
+from application.serializers.application_serializers import NoReferencesSetting
+from common.field.common import InstanceField
+from common.util.field_message import ErrMessage
+
+
+class ModelField(serializers.Field):
+ def to_internal_value(self, data):
+ if not isinstance(data, BaseChatModel):
+ self.fail(_('Model type error'), value=data)
+ return data
+
+ def to_representation(self, value):
+ return value
+
+
+class MessageField(serializers.Field):
+ def to_internal_value(self, data):
+ if not isinstance(data, BaseMessage):
+ self.fail(_('Message type error'), value=data)
+ return data
+
+ def to_representation(self, value):
+ return value
+
+
+class PostResponseHandler:
+ @abstractmethod
+ def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
+ answer_text,
+ manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
+ pass
+
+
+class IChatStep(IBaseChatPipelineStep):
+ class InstanceSerializer(serializers.Serializer):
+        # Message list
+ message_list = serializers.ListField(required=True, child=MessageField(required=True),
+ error_messages=ErrMessage.list(_("Conversation list")))
+ model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
+        # Paragraph list
+ paragraph_list = serializers.ListField(error_messages=ErrMessage.list(_("Paragraph List")))
+        # Conversation id
+ chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("Conversation ID")))
+        # User question
+        problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("User Questions")))
+        # Post-response handler
+ post_response_handler = InstanceField(model_type=PostResponseHandler,
+ error_messages=ErrMessage.base(_("Post-processor")))
+        # Completed (padded) question
+ padding_problem_text = serializers.CharField(required=False,
+ error_messages=ErrMessage.base(_("Completion Question")))
+        # Whether to output as a stream
+ stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
+ client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
+ client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
+        # Settings for when no reference segments are found
+ no_references_setting = NoReferencesSetting(required=True,
+ error_messages=ErrMessage.base(_("No reference segment settings")))
+
+ user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
+
+ model_setting = serializers.DictField(required=True, allow_null=True,
+ error_messages=ErrMessage.dict(_("Model settings")))
+
+ model_params_setting = serializers.DictField(required=False, allow_null=True,
+ error_messages=ErrMessage.dict(_("Model parameter settings")))
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+ message_list: List = self.initial_data.get('message_list')
+ for message in message_list:
+ if not isinstance(message, BaseMessage):
+ raise Exception(_("message type error"))
+
+ def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
+ return self.InstanceSerializer
+
+ def _run(self, manage: PipelineManage):
+ chat_result = self.execute(**self.context['step_args'], manage=manage)
+ manage.context['chat_result'] = chat_result
+
+ @abstractmethod
+ def execute(self, message_list: List[BaseMessage],
+ chat_id, problem_text,
+ post_response_handler: PostResponseHandler,
+ model_id: str = None,
+ user_id: str = None,
+ paragraph_list=None,
+ manage: PipelineManage = None,
+ padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
+ no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
+ pass
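The serializer above validates plain Python objects (langchain messages, a handler instance) rather than JSON, which is why the custom `ModelField`/`MessageField` exist. A minimal, hypothetical `PostResponseHandler` matching the abstract signature; a real one would persist a `ChatRecord`:

```python
from typing import List

from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler


class LoggingPostResponseHandler(PostResponseHandler):
    """Hypothetical handler for illustration only."""

    def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel],
                problem_text: str, answer_text, manage, step,
                padding_problem_text: str = None, client_id=None, **kwargs):
        # Log instead of writing a chat record.
        print(f'chat={chat_id} record={chat_record_id} answer={answer_text[:80]}')
```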
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
new file mode 100644
index 000000000..b03f06d80
--- /dev/null
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -0,0 +1,334 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_chat_step.py
+ @date:2024/1/9 18:25
+ @desc: Base implementation of the chat step
+"""
+import logging
+import time
+import traceback
+import uuid
+from typing import List
+
+from django.db.models import QuerySet
+from django.http import StreamingHttpResponse
+from django.utils.translation import gettext as _
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import BaseMessage
+from langchain.schema.messages import HumanMessage, AIMessage
+from langchain_core.messages import AIMessageChunk
+from rest_framework import status
+
+from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
+from application.chat_pipeline.pipeline_manage import PipelineManage
+from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
+from application.flow.tools import Reasoning
+from application.models.api_key_model import ApplicationPublicAccessClient
+from common.constants.authentication_type import AuthenticationType
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+
+def add_access_num(client_id=None, client_type=None, application_id=None):
+ if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None:
+ application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id,
+ application_id=application_id)
+ .first())
+ if application_public_access_client is not None:
+ application_public_access_client.access_num = application_public_access_client.access_num + 1
+ application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
+ application_public_access_client.save()
+
+
+def write_context(step, manage, request_token, response_token, all_text):
+ step.context['message_tokens'] = request_token
+ step.context['answer_tokens'] = response_token
+ current_time = time.time()
+ step.context['answer_text'] = all_text
+ step.context['run_time'] = current_time - step.context['start_time']
+ manage.context['run_time'] = current_time - manage.context['start_time']
+ manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
+ manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
+
+
+def event_content(response,
+ chat_id,
+ chat_record_id,
+ paragraph_list: List[ParagraphPipelineModel],
+ post_response_handler: PostResponseHandler,
+ manage,
+ step,
+ chat_model,
+ message_list: List[BaseMessage],
+ problem_text: str,
+ padding_problem_text: str = None,
+ client_id=None, client_type=None,
+ is_ai_chat: bool = None,
+ model_setting=None):
+ if model_setting is None:
+ model_setting = {}
+ reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
+ reasoning_content_start = model_setting.get('reasoning_content_start', '')
+ reasoning_content_end = model_setting.get('reasoning_content_end', '')
+ reasoning = Reasoning(reasoning_content_start,
+ reasoning_content_end)
+ all_text = ''
+ reasoning_content = ''
+ try:
+ response_reasoning_content = False
+ for chunk in response:
+ reasoning_chunk = reasoning.get_reasoning_content(chunk)
+ content_chunk = reasoning_chunk.get('content')
+ if 'reasoning_content' in chunk.additional_kwargs:
+ response_reasoning_content = True
+ reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
+ else:
+ reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
+ all_text += content_chunk
+ if reasoning_content_chunk is None:
+ reasoning_content_chunk = ''
+ reasoning_content += reasoning_content_chunk
+ yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
+ [], content_chunk,
+ False,
+ 0, 0, {'node_is_end': False,
+ 'view_type': 'many_view',
+ 'node_type': 'ai-chat-node',
+ 'real_node_id': 'ai-chat-node',
+ 'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''})
+ reasoning_chunk = reasoning.get_end_reasoning_content()
+ all_text += reasoning_chunk.get('content')
+ reasoning_content_chunk = ""
+ if not response_reasoning_content:
+ reasoning_content_chunk = reasoning_chunk.get(
+ 'reasoning_content')
+ yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
+ [], reasoning_chunk.get('content'),
+ False,
+ 0, 0, {'node_is_end': False,
+ 'view_type': 'many_view',
+ 'node_type': 'ai-chat-node',
+ 'real_node_id': 'ai-chat-node',
+                                                                             'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''})
+        # Count tokens
+ if is_ai_chat:
+ try:
+ request_token = chat_model.get_num_tokens_from_messages(message_list)
+ response_token = chat_model.get_num_tokens(all_text)
+ except Exception as e:
+ request_token = 0
+ response_token = 0
+ else:
+ request_token = 0
+ response_token = 0
+ write_context(step, manage, request_token, response_token, all_text)
+ asker = manage.context.get('form_data', {}).get('asker', None)
+ post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
+ all_text, manage, step, padding_problem_text, client_id,
+                                      reasoning_content=reasoning_content if reasoning_content_enable else '',
+                                      asker=asker)
+ yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
+ [], '', True,
+ request_token, response_token,
+ {'node_is_end': True, 'view_type': 'many_view',
+ 'node_type': 'ai-chat-node'})
+ add_access_num(client_id, client_type, manage.context.get('application_id'))
+ except Exception as e:
+ logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
+ all_text = 'Exception:' + str(e)
+ write_context(step, manage, 0, 0, all_text)
+ asker = manage.context.get('form_data', {}).get('asker', None)
+ post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
+ all_text, manage, step, padding_problem_text, client_id, reasoning_content='',
+ asker=asker)
+ add_access_num(client_id, client_type, manage.context.get('application_id'))
+ yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
+ [], all_text,
+ False,
+ 0, 0, {'node_is_end': False,
+ 'view_type': 'many_view',
+ 'node_type': 'ai-chat-node',
+ 'real_node_id': 'ai-chat-node',
+ 'reasoning_content': ''})
+
+
+class BaseChatStep(IChatStep):
+ def execute(self, message_list: List[BaseMessage],
+ chat_id,
+ problem_text,
+ post_response_handler: PostResponseHandler,
+ model_id: str = None,
+ user_id: str = None,
+ paragraph_list=None,
+ manage: PipelineManage = None,
+ padding_problem_text: str = None,
+ stream: bool = True,
+ client_id=None, client_type=None,
+ no_references_setting=None,
+ model_params_setting=None,
+ model_setting=None,
+ **kwargs):
+ chat_model = get_model_instance_by_model_user_id(model_id, user_id,
+ **model_params_setting) if model_id is not None else None
+ if stream:
+ return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
+ paragraph_list,
+ manage, padding_problem_text, client_id, client_type, no_references_setting,
+ model_setting)
+ else:
+ return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
+ paragraph_list,
+ manage, padding_problem_text, client_id, client_type, no_references_setting,
+ model_setting)
+
+ def get_details(self, manage, **kwargs):
+ return {
+ 'step_type': 'chat_step',
+ 'run_time': self.context['run_time'],
+ 'model_id': str(manage.context['model_id']),
+ 'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
+ self.context['answer_text']),
+ 'message_tokens': self.context['message_tokens'],
+ 'answer_tokens': self.context['answer_tokens'],
+ 'cost': 0,
+ }
+
+ @staticmethod
+ def reset_message_list(message_list: List[BaseMessage], answer_text):
+ result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+ message
+ in
+ message_list]
+ result.append({'role': 'ai', 'content': answer_text})
+ return result
+
+ @staticmethod
+ def get_stream_result(message_list: List[BaseMessage],
+ chat_model: BaseChatModel = None,
+ paragraph_list=None,
+ no_references_setting=None,
+ problem_text=None):
+ if paragraph_list is None:
+ paragraph_list = []
+ directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
+ for paragraph in paragraph_list if (
+ paragraph.hit_handling_method == 'directly_return' and paragraph.similarity >= paragraph.directly_return_similarity)]
+ if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
+ return iter(directly_return_chunk_list), False
+ elif len(paragraph_list) == 0 and no_references_setting.get(
+ 'status') == 'designated_answer':
+ return iter(
+ [AIMessageChunk(content=no_references_setting.get('value').replace('{question}', problem_text))]), False
+ if chat_model is None:
+ return iter([AIMessageChunk(
+ _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
+ else:
+ return chat_model.stream(message_list), True
+
+ def execute_stream(self, message_list: List[BaseMessage],
+ chat_id,
+ problem_text,
+ post_response_handler: PostResponseHandler,
+ chat_model: BaseChatModel = None,
+ paragraph_list=None,
+ manage: PipelineManage = None,
+ padding_problem_text: str = None,
+ client_id=None, client_type=None,
+ no_references_setting=None,
+ model_setting=None):
+ chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
+ no_references_setting, problem_text)
+ chat_record_id = uuid.uuid1()
+ r = StreamingHttpResponse(
+ streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
+ post_response_handler, manage, self, chat_model, message_list, problem_text,
+ padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
+ content_type='text/event-stream;charset=utf-8')
+
+ r['Cache-Control'] = 'no-cache'
+ return r
+
+ @staticmethod
+ def get_block_result(message_list: List[BaseMessage],
+ chat_model: BaseChatModel = None,
+ paragraph_list=None,
+ no_references_setting=None,
+ problem_text=None):
+ if paragraph_list is None:
+ paragraph_list = []
+ directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
+ for paragraph in paragraph_list if (
+ paragraph.hit_handling_method == 'directly_return' and paragraph.similarity >= paragraph.directly_return_similarity)]
+ if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
+ return directly_return_chunk_list[0], False
+ elif len(paragraph_list) == 0 and no_references_setting.get(
+ 'status') == 'designated_answer':
+ return AIMessage(no_references_setting.get('value').replace('{question}', problem_text)), False
+ if chat_model is None:
+ return AIMessage(
+ _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
+ else:
+ return chat_model.invoke(message_list), True
+
+ def execute_block(self, message_list: List[BaseMessage],
+ chat_id,
+ problem_text,
+ post_response_handler: PostResponseHandler,
+ chat_model: BaseChatModel = None,
+ paragraph_list=None,
+ manage: PipelineManage = None,
+ padding_problem_text: str = None,
+ client_id=None, client_type=None, no_references_setting=None,
+ model_setting=None):
+ reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
+ reasoning_content_start = model_setting.get('reasoning_content_start', '')
+ reasoning_content_end = model_setting.get('reasoning_content_end', '')
+ reasoning = Reasoning(reasoning_content_start,
+ reasoning_content_end)
+ chat_record_id = uuid.uuid1()
+        # Invoke the model
+ try:
+ chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
+ no_references_setting, problem_text)
+ if is_ai_chat:
+ request_token = chat_model.get_num_tokens_from_messages(message_list)
+ response_token = chat_model.get_num_tokens(chat_result.content)
+ else:
+ request_token = 0
+ response_token = 0
+ write_context(self, manage, request_token, response_token, chat_result.content)
+ reasoning_result = reasoning.get_reasoning_content(chat_result)
+ reasoning_result_end = reasoning.get_end_reasoning_content()
+ content = reasoning_result.get('content') + reasoning_result_end.get('content')
+ if 'reasoning_content' in chat_result.response_metadata:
+ reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
+ else:
+ reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get(
+ 'reasoning_content')
+ asker = manage.context.get('form_data', {}).get('asker', None)
+ post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
+ content, manage, self, padding_problem_text, client_id,
+ reasoning_content=reasoning_content if reasoning_content_enable else '',
+ asker=asker)
+ add_access_num(client_id, client_type, manage.context.get('application_id'))
+ return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
+ content, True,
+ request_token, response_token,
+ {
+ 'reasoning_content': reasoning_content if reasoning_content_enable else '',
+ 'answer_list': [{
+ 'content': content,
+ 'reasoning_content': reasoning_content if reasoning_content_enable else ''
+ }]})
+ except Exception as e:
+ all_text = 'Exception:' + str(e)
+ write_context(self, manage, 0, 0, all_text)
+ asker = manage.context.get('form_data', {}).get('asker', None)
+ post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
+ all_text, manage, self, padding_problem_text, client_id, reasoning_content='',
+ asker=asker)
+ add_access_num(client_id, client_type, manage.context.get('application_id'))
+ return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0,
+ 0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)
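Both `get_stream_result` and `get_block_result` share the same short-circuit order: direct-return paragraphs first, then the designated answer for an empty result set, then the missing-model hint, and only then the LLM. A standalone sketch of that decision, with a hypothetical stand-in paragraph type so it runs without Django:

```python
from dataclasses import dataclass


@dataclass
class Para:  # hypothetical stand-in for ParagraphPipelineModel
    content: str
    similarity: float
    hit_handling_method: str = 'optimization'
    directly_return_similarity: float = 0.9


def pick_answer_source(paragraph_list, no_references_setting, chat_model, problem_text):
    # Mirrors the branch order of get_block_result above.
    direct = [p for p in paragraph_list
              if p.hit_handling_method == 'directly_return'
              and p.similarity >= p.directly_return_similarity]
    if direct:
        return 'directly_return', direct[0].content
    if not paragraph_list and no_references_setting.get('status') == 'designated_answer':
        return 'designated_answer', no_references_setting.get('value').replace('{question}', problem_text)
    if chat_model is None:
        return 'no_model', 'Sorry, the AI model is not configured...'
    return 'llm', None


print(pick_answer_source([Para('MaxKB is ...', 0.95, 'directly_return')],
                         {'status': 'designated_answer', 'value': ''},
                         None, 'What is MaxKB?'))
# ('directly_return', 'MaxKB is ...')
```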
diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py b/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py
new file mode 100644
index 000000000..5d9549cdc
--- /dev/null
+++ b/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/1/9 18:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py
new file mode 100644
index 000000000..9e23f2d6c
--- /dev/null
+++ b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py
@@ -0,0 +1,81 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_generate_human_message_step.py
+ @date:2024/1/9 18:15
+ @desc: Generate the conversation messages from the template
+"""
+from abc import abstractmethod
+from typing import Type, List
+
+from django.utils.translation import gettext_lazy as _
+from langchain.schema import BaseMessage
+from rest_framework import serializers
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
+from application.chat_pipeline.pipeline_manage import PipelineManage
+from application.models import ChatRecord
+from application.serializers.application_serializers import NoReferencesSetting
+from common.field.common import InstanceField
+from common.util.field_message import ErrMessage
+
+
+class IGenerateHumanMessageStep(IBaseChatPipelineStep):
+ class InstanceSerializer(serializers.Serializer):
+        # Question
+ problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
+        # Paragraph list
+ paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True),
+ error_messages=ErrMessage.list(_("Paragraph List")))
+        # Historical Q&A records
+ history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
+ error_messages=ErrMessage.list(_("History Questions")))
+        # Number of multi-round conversations
+ dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
+        # Maximum character length of knowledge base paragraphs to carry
+ max_paragraph_char_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
+ _("Maximum length of the knowledge base paragraph")))
+        # Prompt template
+ prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
+ system = serializers.CharField(required=False, allow_null=True, allow_blank=True,
+ error_messages=ErrMessage.char(_("System prompt words (role)")))
+        # Completed (padded) question
+ padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Completion problem")))
+        # Settings for when no reference segments are found
+ no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings")))
+
+ def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
+ return self.InstanceSerializer
+
+ def _run(self, manage: PipelineManage):
+ message_list = self.execute(**self.context['step_args'])
+ manage.context['message_list'] = message_list
+
+ @abstractmethod
+ def execute(self,
+ problem_text: str,
+ paragraph_list: List[ParagraphPipelineModel],
+ history_chat_record: List[ChatRecord],
+ dialogue_number: int,
+ max_paragraph_char_number: int,
+ prompt: str,
+ padding_problem_text: str = None,
+ no_references_setting=None,
+ system=None,
+ **kwargs) -> List[BaseMessage]:
+ """
+
+ :param problem_text: 原始问题文本
+ :param paragraph_list: 段落列表
+ :param history_chat_record: 历史对话记录
+ :param dialogue_number: 多轮对话数量
+ :param max_paragraph_char_number: 最大段落长度
+ :param prompt: 模板
+ :param padding_problem_text 用户修改文本
+ :param kwargs: 其他参数
+ :param no_references_setting: 无引用分段设置
+ :param system 系统提示称
+ :return:
+ """
+ pass
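The contract here is mainly about message ordering: an optional system message, then the flattened history pairs, then one templated human message. A quick sketch with plain langchain messages:

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

history = [[HumanMessage('hi'), AIMessage('hello')]]  # one historical round
flat = [message for pair in history for message in pair]
messages = [SystemMessage('You are a helpful assistant.'), *flat,
            HumanMessage('Known information:\n...\nQuestion:\nWhat is MaxKB?')]
print([type(m).__name__ for m in messages])
# ['SystemMessage', 'HumanMessage', 'AIMessage', 'HumanMessage']
```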
diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
new file mode 100644
index 000000000..68cfbbcb9
--- /dev/null
+++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
@@ -0,0 +1,73 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_generate_human_message_step.py
+ @date:2024/1/10 17:50
+ @desc:
+"""
+from typing import List, Dict
+
+from langchain.schema import BaseMessage, HumanMessage
+from langchain_core.messages import SystemMessage
+
+from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
+from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
+ IGenerateHumanMessageStep
+from application.models import ChatRecord
+from common.util.split_model import flat_map
+
+
+class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
+
+ def execute(self, problem_text: str,
+ paragraph_list: List[ParagraphPipelineModel],
+ history_chat_record: List[ChatRecord],
+ dialogue_number: int,
+ max_paragraph_char_number: int,
+ prompt: str,
+ padding_problem_text: str = None,
+ no_references_setting=None,
+ system=None,
+ **kwargs) -> List[BaseMessage]:
+ prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get(
+ 'value')
+ exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
+ start_index = len(history_chat_record) - dialogue_number
+ history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))]
+ if system is not None and len(system) > 0:
+ return [SystemMessage(system), *flat_map(history_message),
+ self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
+ no_references_setting)]
+
+ return [*flat_map(history_message),
+ self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
+ no_references_setting)]
+
+ @staticmethod
+ def to_human_message(prompt: str,
+ problem: str,
+ max_paragraph_char_number: int,
+ paragraph_list: List[ParagraphPipelineModel],
+ no_references_setting: Dict):
+ if paragraph_list is None or len(paragraph_list) == 0:
+ if no_references_setting.get('status') == 'ai_questioning':
+ return HumanMessage(
+ content=no_references_setting.get('value').replace('{question}', problem))
+ else:
+ return HumanMessage(content=prompt.replace('{data}', "").replace('{question}', problem))
+ temp_data = ""
+ data_list = []
+ for p in paragraph_list:
+ content = f"{p.title}:{p.content}"
+ temp_data += content
+ if len(temp_data) > max_paragraph_char_number:
+ row_data = content[0:max_paragraph_char_number - len(temp_data)]
+ data_list.append(f"{row_data}")
+ break
+ else:
+ data_list.append(f"{content}")
+ data = "\n".join(data_list)
+ return HumanMessage(content=prompt.replace('{data}', data).replace('{question}', problem))
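The subtle part of `to_human_message` is the truncation: `temp_data` already includes the current paragraph when the limit check fires, so `max_paragraph_char_number - len(temp_data)` is negative and the slice trims exactly the overflow. A worked example:

```python
max_paragraph_char_number = 10
temp_data = ""
data_list = []
for content in ["abcdef", "ghijkl"]:  # 12 characters in total
    temp_data += content
    if len(temp_data) > max_paragraph_char_number:
        # 10 - 12 == -2, so content[0:-2] keeps 'ghij' and drops the overflow
        data_list.append(content[0:max_paragraph_char_number - len(temp_data)])
        break
    data_list.append(content)
print("\n".join(data_list))  # abcdef\nghij -> exactly 10 paragraph characters kept
```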
diff --git a/apps/application/chat_pipeline/step/reset_problem_step/__init__.py b/apps/application/chat_pipeline/step/reset_problem_step/__init__.py
new file mode 100644
index 000000000..5d9549cdc
--- /dev/null
+++ b/apps/application/chat_pipeline/step/reset_problem_step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/1/9 18:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py
new file mode 100644
index 000000000..f48f5c804
--- /dev/null
+++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py
@@ -0,0 +1,57 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_reset_problem_step.py
+ @date:2024/1/9 18:12
+ @desc: Rewrite the user question
+"""
+from abc import abstractmethod
+from typing import Type, List
+
+from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from application.chat_pipeline.pipeline_manage import PipelineManage
+from application.models import ChatRecord
+from common.field.common import InstanceField
+from common.util.field_message import ErrMessage
+
+
+class IResetProblemStep(IBaseChatPipelineStep):
+ class InstanceSerializer(serializers.Serializer):
+        # Question text
+        problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
+        # Historical Q&A records
+ history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
+ error_messages=ErrMessage.list(_("History Questions")))
+        # Large language model
+ model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
+ user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
+ problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
+ error_messages=ErrMessage.char(
+ _("Question completion prompt")))
+
+ def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
+ return self.InstanceSerializer
+
+ def _run(self, manage: PipelineManage):
+ padding_problem = self.execute(**self.context.get('step_args'))
+        # The question as entered by the user
+ source_problem_text = self.context.get('step_args').get('problem_text')
+ self.context['problem_text'] = source_problem_text
+ self.context['padding_problem_text'] = padding_problem
+ manage.context['problem_text'] = source_problem_text
+ manage.context['padding_problem_text'] = padding_problem
+        # Accumulate tokens
+ manage.context['message_tokens'] = manage.context.get('message_tokens', 0) + self.context.get('message_tokens',
+ 0)
+ manage.context['answer_tokens'] = manage.context.get('answer_tokens', 0) + self.context.get('answer_tokens', 0)
+
+ @abstractmethod
+ def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
+ problem_optimization_prompt=None,
+ user_id=None,
+ **kwargs):
+ pass
diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py
new file mode 100644
index 000000000..ec01daa34
--- /dev/null
+++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py
@@ -0,0 +1,68 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_reset_problem_step.py
+ @date:2024/1/10 14:35
+ @desc:
+"""
+from typing import List
+
+from django.utils.translation import gettext as _
+from langchain.schema import HumanMessage
+
+from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
+from application.models import ChatRecord
+from common.util.split_model import flat_map
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+prompt = _(
+ "() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the tag")
+
+
+class BaseResetProblemStep(IResetProblemStep):
+ def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
+ problem_optimization_prompt=None,
+ user_id=None,
+ **kwargs) -> str:
+ chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
+ if chat_model is None:
+ return problem_text
+ start_index = len(history_chat_record) - 3
+ history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))]
+ reset_prompt = problem_optimization_prompt if problem_optimization_prompt else prompt
+ message_list = [*flat_map(history_message),
+ HumanMessage(content=reset_prompt.replace('{question}', problem_text))]
+ response = chat_model.invoke(message_list)
+ padding_problem = problem_text
+        if response.content.__contains__("<data>") and response.content.__contains__('</data>'):
+            padding_problem_data = response.content[
+                                   response.content.index('<data>') + 6:response.content.index('</data>')]
+ if padding_problem_data is not None and len(padding_problem_data.strip()) > 0:
+ padding_problem = padding_problem_data
+ elif len(response.content) > 0:
+ padding_problem = response.content
+
+ try:
+ request_token = chat_model.get_num_tokens_from_messages(message_list)
+ response_token = chat_model.get_num_tokens(padding_problem)
+ except Exception as e:
+ request_token = 0
+ response_token = 0
+ self.context['message_tokens'] = request_token
+ self.context['answer_tokens'] = response_token
+ return padding_problem
+
+ def get_details(self, manage, **kwargs):
+ return {
+ 'step_type': 'problem_padding',
+ 'run_time': self.context['run_time'],
+ 'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
+ 'message_tokens': self.context.get('message_tokens', 0),
+ 'answer_tokens': self.context.get('answer_tokens', 0),
+ 'cost': 0,
+ 'padding_problem_text': self.context.get('padding_problem_text'),
+ 'problem_text': self.context.get("step_args").get('problem_text'),
+ }
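The model is asked to wrap the rewritten question in `<data></data>`; the `+ 6` in the slice above is `len('<data>')`. A quick check of the extraction logic in isolation, with a made-up model response:

```python
content = "Some reasoning... <data>What ports does MaxKB listen on?</data>"
if '<data>' in content and '</data>' in content:
    padded = content[content.index('<data>') + 6:content.index('</data>')]
    print(padded)  # What ports does MaxKB listen on?
```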
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/__init__.py b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py
new file mode 100644
index 000000000..023c4bc38
--- /dev/null
+++ b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/1/9 18:24
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
new file mode 100644
index 000000000..7b222cbc2
--- /dev/null
+++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
@@ -0,0 +1,77 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_search_dataset_step.py
+ @date:2024/1/9 18:10
+ @desc: Search the knowledge base
+"""
+import re
+from abc import abstractmethod
+from typing import List, Type
+
+from django.core import validators
+from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
+from application.chat_pipeline.pipeline_manage import PipelineManage
+from common.util.field_message import ErrMessage
+
+
+class ISearchDatasetStep(IBaseChatPipelineStep):
+ class InstanceSerializer(serializers.Serializer):
+        # Original question text
+ problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
+        # System-completed question text
+ padding_problem_text = serializers.CharField(required=False,
+ error_messages=ErrMessage.char(_("System completes question text")))
+        # List of dataset ids to query
+ dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
+ error_messages=ErrMessage.list(_("Dataset id list")))
+        # Document ids to exclude
+ exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
+ error_messages=ErrMessage.list(_("List of document ids to exclude")))
+        # Paragraph (vector) ids to exclude
+ exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
+ error_messages=ErrMessage.list(_("List of exclusion vector ids")))
+        # Number of results to retrieve
+ top_n = serializers.IntegerField(required=True,
+ error_messages=ErrMessage.integer(_("Reference segment number")))
+        # Similarity, between 0 and 1
+ similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
+ error_messages=ErrMessage.float(_("Similarity")))
+ search_mode = serializers.CharField(required=True, validators=[
+            validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
+ message=_("The type only supports embedding|keywords|blend"), code=500)
+ ], error_messages=ErrMessage.char(_("Retrieval Mode")))
+ user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
+
+ def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
+ return self.InstanceSerializer
+
+ def _run(self, manage: PipelineManage):
+ paragraph_list = self.execute(**self.context['step_args'])
+ manage.context['paragraph_list'] = paragraph_list
+ self.context['paragraph_list'] = paragraph_list
+
+ @abstractmethod
+ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
+ exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
+ search_mode: str = None,
+ user_id=None,
+ **kwargs) -> List[ParagraphPipelineModel]:
+ """
+ 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
+ :param similarity: 相关性
+ :param top_n: 查询多少条
+ :param problem_text: 用户问题
+ :param dataset_id_list: 需要查询的数据集id列表
+ :param exclude_document_id_list: 需要排除的文档id
+ :param exclude_paragraph_id_list: 需要排除段落id
+ :param padding_problem_text 补全问题
+ :param search_mode 检索模式
+ :param user_id 用户id
+ :return: 段落列表
+ """
+ pass
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
new file mode 100644
index 000000000..6591f6d24
--- /dev/null
+++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
@@ -0,0 +1,138 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_search_dataset_step.py
+ @date:2024/1/10 10:33
+ @desc:
+"""
+import os
+from typing import List, Dict
+
+from django.db.models import QuerySet
+from django.utils.translation import gettext_lazy as _
+from rest_framework.utils.formatting import lazy_format
+
+from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
+from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
+from common.config.embedding_config import VectorStore, ModelManage
+from common.db.search import native_search
+from common.util.file_util import get_file_content
+from dataset.models import Paragraph, DataSet
+from embedding.models import SearchMode
+from setting.models import Model
+from setting.models_provider import get_model
+from smartdoc.conf import PROJECT_DIR
+
+
+def get_model_by_id(_id, user_id):
+ model = QuerySet(Model).filter(id=_id).first()
+ if model is None:
+ raise Exception(_("Model does not exist"))
+ if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
+ message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name)
+ raise Exception(message)
+ return model
+
+
+def get_embedding_id(dataset_id_list):
+ dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+ if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+ raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
+ if len(dataset_list) == 0:
+ raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
+ return dataset_list[0].embedding_mode_id
+
+
+class BaseSearchDatasetStep(ISearchDatasetStep):
+
+ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
+ exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
+ search_mode: str = None,
+ user_id=None,
+ **kwargs) -> List[ParagraphPipelineModel]:
+ if len(dataset_id_list) == 0:
+ return []
+ exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
+ model_id = get_embedding_id(dataset_id_list)
+ model = get_model_by_id(model_id, user_id)
+ self.context['model_name'] = model.name
+ embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
+ embedding_value = embedding_model.embed_query(exec_problem_text)
+ vector = VectorStore.get_embedding_vector()
+ embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
+ exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
+ if embedding_list is None:
+ return []
+ paragraph_list = self.list_paragraph(embedding_list, vector)
+ result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
+ return result
+
+ @staticmethod
+ def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel:
+ filter_embedding_list = [embedding for embedding in embedding_list if
+ str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
+ if filter_embedding_list is not None and len(filter_embedding_list) > 0:
+ find_embedding = filter_embedding_list[-1]
+ return (ParagraphPipelineModel.builder()
+ .add_paragraph(paragraph)
+ .add_similarity(find_embedding.get('similarity'))
+ .add_comprehensive_score(find_embedding.get('comprehensive_score'))
+ .add_dataset_name(paragraph.get('dataset_name'))
+ .add_document_name(paragraph.get('document_name'))
+ .add_hit_handling_method(paragraph.get('hit_handling_method'))
+ .add_directly_return_similarity(paragraph.get('directly_return_similarity'))
+ .add_meta(paragraph.get('meta'))
+ .build())
+
+ @staticmethod
+ def get_similarity(paragraph, embedding_list: List):
+ filter_embedding_list = [embedding for embedding in embedding_list if
+ str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
+ if filter_embedding_list is not None and len(filter_embedding_list) > 0:
+ find_embedding = filter_embedding_list[-1]
+ return find_embedding.get('comprehensive_score')
+ return 0
+
+ @staticmethod
+ def list_paragraph(embedding_list: List, vector):
+ paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
+ if paragraph_id_list is None or len(paragraph_id_list) == 0:
+ return []
+ paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
+ get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "application", 'sql',
+ 'list_dataset_paragraph_by_paragraph_id.sql')),
+ with_table_name=True)
+        # If the vector store contains stale entries, delete them directly
+ if len(paragraph_list) != len(paragraph_id_list):
+ exist_paragraph_list = [row.get('id') for row in paragraph_list]
+ for paragraph_id in paragraph_id_list:
+ if not exist_paragraph_list.__contains__(paragraph_id):
+ vector.delete_by_paragraph_id(paragraph_id)
+        # If any paragraph qualifies for direct return, keep only the directly-returned paragraph
+ hit_handling_method_paragraph = [paragraph for paragraph in paragraph_list if
+ (paragraph.get(
+ 'hit_handling_method') == 'directly_return' and BaseSearchDatasetStep.get_similarity(
+ paragraph, embedding_list) >= paragraph.get(
+ 'directly_return_similarity'))]
+ if len(hit_handling_method_paragraph) > 0:
+            # Pick the highest-scoring one
+ return [sorted(hit_handling_method_paragraph,
+ key=lambda p: BaseSearchDatasetStep.get_similarity(p, embedding_list))[-1]]
+ return paragraph_list
+
+ def get_details(self, manage, **kwargs):
+ step_args = self.context['step_args']
+
+ return {
+ 'step_type': 'search_step',
+ 'paragraph_list': [row.to_dict() for row in self.context['paragraph_list']],
+ 'run_time': self.context['run_time'],
+ 'problem_text': step_args.get(
+ 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
+ 'model_name': self.context.get('model_name'),
+ 'message_tokens': 0,
+ 'answer_tokens': 0,
+ 'cost': 0
+ }
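`list_paragraph`'s direct-return selection is easiest to see with plain dicts: among hits flagged `directly_return` whose comprehensive score clears the paragraph's own threshold, only the highest-scoring one survives. A sketch with hypothetical hits:

```python
embedding_list = [
    {'paragraph_id': 'p1', 'comprehensive_score': 0.91},
    {'paragraph_id': 'p2', 'comprehensive_score': 0.97},
]
paragraph_list = [
    {'id': 'p1', 'hit_handling_method': 'directly_return', 'directly_return_similarity': 0.9},
    {'id': 'p2', 'hit_handling_method': 'directly_return', 'directly_return_similarity': 0.9},
]


def get_similarity(paragraph):
    # Last matching hit wins, as in BaseSearchDatasetStep.get_similarity.
    hits = [e for e in embedding_list if e['paragraph_id'] == paragraph['id']]
    return hits[-1]['comprehensive_score'] if hits else 0


direct = [p for p in paragraph_list
          if p['hit_handling_method'] == 'directly_return'
          and get_similarity(p) >= p['directly_return_similarity']]
if direct:
    print(sorted(direct, key=get_similarity)[-1]['id'])  # p2
```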
diff --git a/apps/application/flow/__init__.py b/apps/application/flow/__init__.py
new file mode 100644
index 000000000..328e8f8ec
--- /dev/null
+++ b/apps/application/flow/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/7 14:43
+ @desc:
+"""
diff --git a/apps/application/flow/common.py b/apps/application/flow/common.py
new file mode 100644
index 000000000..f5d4cb9b0
--- /dev/null
+++ b/apps/application/flow/common.py
@@ -0,0 +1,44 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: common.py
+ @date:2024/12/11 17:57
+ @desc:
+"""
+
+
+class Answer:
+ def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
+ reasoning_content):
+ self.view_type = view_type
+ self.content = content
+ self.reasoning_content = reasoning_content
+ self.runtime_node_id = runtime_node_id
+ self.chat_record_id = chat_record_id
+ self.child_node = child_node
+ self.real_node_id = real_node_id
+
+ def to_dict(self):
+ return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
+ 'chat_record_id': self.chat_record_id,
+ 'child_node': self.child_node,
+ 'reasoning_content': self.reasoning_content,
+ 'real_node_id': self.real_node_id}
+
+
+class NodeChunk:
+ def __init__(self):
+ self.status = 0
+ self.chunk_list = []
+
+ def add_chunk(self, chunk):
+ self.chunk_list.append(chunk)
+
+ def end(self, chunk=None):
+ if chunk is not None:
+ self.add_chunk(chunk)
+ self.status = 200
+
+ def is_end(self):
+ return self.status == 200
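`NodeChunk` is a small accumulator for streamed node output: chunks are appended until `end()` marks the node finished. A usage sketch:

```python
from application.flow.common import NodeChunk

chunk = NodeChunk()
chunk.add_chunk('Hello, ')
chunk.end('world')  # appends the final chunk and sets status to 200
assert chunk.is_end()
assert chunk.chunk_list == ['Hello, ', 'world']
```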
diff --git a/apps/application/flow/default_workflow.json b/apps/application/flow/default_workflow.json
new file mode 100644
index 000000000..48ac23c4d
--- /dev/null
+++ b/apps/application/flow/default_workflow.json
@@ -0,0 +1,451 @@
+{
+ "nodes": [
+ {
+ "id": "base-node",
+ "type": "base-node",
+ "x": 360,
+ "y": 2810,
+ "properties": {
+ "config": {
+
+ },
+ "height": 825.6,
+ "stepName": "基本信息",
+ "node_data": {
+ "desc": "",
+ "name": "maxkbapplication",
+ "prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?"
+ },
+ "input_field_list": [
+
+ ]
+ }
+ },
+ {
+ "id": "start-node",
+ "type": "start-node",
+ "x": 430,
+ "y": 3660,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "globalFields": [
+ {
+ "label": "当前时间",
+ "value": "time"
+ }
+ ]
+ },
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "height": 276,
+ "stepName": "开始",
+ "globalFields": [
+ {
+ "label": "当前时间",
+ "value": "time"
+ }
+ ]
+ }
+ },
+ {
+ "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "type": "search-dataset-node",
+ "x": 840,
+ "y": 3210,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "检索结果的分段列表",
+ "value": "paragraph_list"
+ },
+ {
+ "label": "满足直接回答的分段列表",
+ "value": "is_hit_handling_method_list"
+ },
+ {
+ "label": "检索结果",
+ "value": "data"
+ },
+ {
+ "label": "满足直接回答的分段内容",
+ "value": "directly_return"
+ }
+ ]
+ },
+ "height": 794,
+ "stepName": "知识库检索",
+ "node_data": {
+ "dataset_id_list": [
+
+ ],
+ "dataset_setting": {
+ "top_n": 3,
+ "similarity": 0.6,
+ "search_mode": "embedding",
+ "max_paragraph_char_number": 5000
+ },
+ "question_reference_address": [
+ "start-node",
+ "question"
+ ],
+ "source_dataset_id_list": [
+
+ ]
+ }
+ }
+ },
+ {
+ "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "type": "condition-node",
+ "x": 1490,
+ "y": 3210,
+ "properties": {
+ "width": 600,
+ "config": {
+ "fields": [
+ {
+ "label": "分支名称",
+ "value": "branch_name"
+ }
+ ]
+ },
+ "height": 543.675,
+ "stepName": "判断器",
+ "node_data": {
+ "branch": [
+ {
+ "id": "1009",
+ "type": "IF",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "is_hit_handling_method_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "4908",
+ "type": "ELSE IF 1",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "paragraph_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "161",
+ "type": "ELSE",
+ "condition": "and",
+ "conditions": [
+
+ ]
+ }
+ ]
+ },
+ "branch_condition_list": [
+ {
+ "index": 0,
+ "height": 121.225,
+ "id": "1009"
+ },
+ {
+ "index": 1,
+ "height": 121.225,
+ "id": "4908"
+ },
+ {
+ "index": 2,
+ "height": 44,
+ "id": "161"
+ }
+ ]
+ }
+ },
+ {
+ "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "type": "reply-node",
+ "x": 2170,
+ "y": 2480,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 378,
+ "stepName": "指定回复",
+ "node_data": {
+ "fields": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "directly_return"
+ ],
+ "content": "",
+ "reply_type": "referencing",
+ "is_result": true
+ }
+ }
+ },
+ {
+ "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "type": "ai-chat-node",
+ "x": 2160,
+ "y": 3200,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 763,
+ "stepName": "AI 对话",
+ "node_data": {
+ "prompt": "已知信息:\n{{知识库检索.data}}\n问题:\n{{开始.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0,
+ "is_result": true
+ }
+ }
+ },
+ {
+ "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "type": "ai-chat-node",
+ "x": 2160,
+ "y": 3970,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 763,
+ "stepName": "AI 对话1",
+ "node_data": {
+ "prompt": "{{开始.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0,
+ "is_result": true
+ }
+ }
+ }
+ ],
+ "edges": [
+ {
+ "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
+ "type": "app-edge",
+ "sourceNodeId": "start-node",
+ "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "startPoint": {
+ "x": 590,
+ "y": 3660
+ },
+ "endPoint": {
+ "x": 680,
+ "y": 3210
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 590,
+ "y": 3660
+ },
+ {
+ "x": 700,
+ "y": 3660
+ },
+ {
+ "x": 570,
+ "y": 3210
+ },
+ {
+ "x": 680,
+ "y": 3210
+ }
+ ],
+ "sourceAnchorId": "start-node_right",
+ "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
+ },
+ {
+ "id": "35cb86dd-f328-429e-a973-12fd7218b696",
+ "type": "app-edge",
+ "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "startPoint": {
+ "x": 1000,
+ "y": 3210
+ },
+ "endPoint": {
+ "x": 1200,
+ "y": 3210
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1000,
+ "y": 3210
+ },
+ {
+ "x": 1110,
+ "y": 3210
+ },
+ {
+ "x": 1090,
+ "y": 3210
+ },
+ {
+ "x": 1200,
+ "y": 3210
+ }
+ ],
+ "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
+ "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
+ },
+ {
+ "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "startPoint": {
+ "x": 1780,
+ "y": 3073.775
+ },
+ "endPoint": {
+ "x": 2010,
+ "y": 2480
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3073.775
+ },
+ {
+ "x": 1890,
+ "y": 3073.775
+ },
+ {
+ "x": 1900,
+ "y": 2480
+ },
+ {
+ "x": 2010,
+ "y": 2480
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
+ "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
+ },
+ {
+ "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "startPoint": {
+ "x": 1780,
+ "y": 3203
+ },
+ "endPoint": {
+ "x": 2000,
+ "y": 3200
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3203
+ },
+ {
+ "x": 1890,
+ "y": 3203
+ },
+ {
+ "x": 1890,
+ "y": 3200
+ },
+ {
+ "x": 2000,
+ "y": 3200
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
+ "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
+ },
+ {
+ "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "startPoint": {
+ "x": 1780,
+ "y": 3293.6124999999997
+ },
+ "endPoint": {
+ "x": 2000,
+ "y": 3970
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3293.6124999999997
+ },
+ {
+ "x": 1890,
+ "y": 3293.6124999999997
+ },
+ {
+ "x": 1890,
+ "y": 3970
+ },
+ {
+ "x": 2000,
+ "y": 3970
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
+ "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/apps/application/flow/default_workflow_en.json b/apps/application/flow/default_workflow_en.json
new file mode 100644
index 000000000..7c0194be6
--- /dev/null
+++ b/apps/application/flow/default_workflow_en.json
@@ -0,0 +1,451 @@
+{
+ "nodes": [
+ {
+ "id": "base-node",
+ "type": "base-node",
+ "x": 360,
+ "y": 2810,
+ "properties": {
+ "config": {
+
+ },
+ "height": 825.6,
+ "stepName": "Base",
+ "node_data": {
+ "desc": "",
+ "name": "maxkbapplication",
+ "prologue": "Hello, I am the MaxKB assistant. You can ask me about MaxKB usage issues.\n-What are the main functions of MaxKB?\n-What major language models does MaxKB support?\n-What document types does MaxKB support?"
+ },
+ "input_field_list": [
+
+ ]
+ }
+ },
+ {
+ "id": "start-node",
+ "type": "start-node",
+ "x": 430,
+ "y": 3660,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "globalFields": [
+ {
+ "label": "当前时间",
+ "value": "time"
+ }
+ ]
+ },
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "height": 276,
+ "stepName": "Start",
+ "globalFields": [
+ {
+ "label": "当前时间",
+ "value": "time"
+ }
+ ]
+ }
+ },
+ {
+ "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "type": "search-dataset-node",
+ "x": 840,
+ "y": 3210,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "检索结果的分段列表",
+ "value": "paragraph_list"
+ },
+ {
+ "label": "满足直接回答的分段列表",
+ "value": "is_hit_handling_method_list"
+ },
+ {
+ "label": "检索结果",
+ "value": "data"
+ },
+ {
+ "label": "满足直接回答的分段内容",
+ "value": "directly_return"
+ }
+ ]
+ },
+ "height": 794,
+ "stepName": "Knowledge Search",
+ "node_data": {
+ "dataset_id_list": [
+
+ ],
+ "dataset_setting": {
+ "top_n": 3,
+ "similarity": 0.6,
+ "search_mode": "embedding",
+ "max_paragraph_char_number": 5000
+ },
+ "question_reference_address": [
+ "start-node",
+ "question"
+ ],
+ "source_dataset_id_list": [
+
+ ]
+ }
+ }
+ },
+ {
+ "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "type": "condition-node",
+ "x": 1490,
+ "y": 3210,
+ "properties": {
+ "width": 600,
+ "config": {
+ "fields": [
+ {
+ "label": "分支名称",
+ "value": "branch_name"
+ }
+ ]
+ },
+ "height": 543.675,
+ "stepName": "Conditional Branch",
+ "node_data": {
+ "branch": [
+ {
+ "id": "1009",
+ "type": "IF",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "is_hit_handling_method_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "4908",
+ "type": "ELSE IF 1",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "paragraph_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "161",
+ "type": "ELSE",
+ "condition": "and",
+ "conditions": [
+
+ ]
+ }
+ ]
+ },
+ "branch_condition_list": [
+ {
+ "index": 0,
+ "height": 121.225,
+ "id": "1009"
+ },
+ {
+ "index": 1,
+ "height": 121.225,
+ "id": "4908"
+ },
+ {
+ "index": 2,
+ "height": 44,
+ "id": "161"
+ }
+ ]
+ }
+ },
+ {
+ "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "type": "reply-node",
+ "x": 2170,
+ "y": 2480,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 378,
+ "stepName": "Specified Reply",
+ "node_data": {
+ "fields": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "directly_return"
+ ],
+ "content": "",
+ "reply_type": "referencing",
+ "is_result": true
+ }
+ }
+ },
+ {
+ "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "type": "ai-chat-node",
+ "x": 2160,
+ "y": 3200,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 763,
+ "stepName": "AI Chat",
+ "node_data": {
+ "prompt": "Known information:\n{{Knowledge Search.data}}\nQuestion:\n{{Start.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0,
+ "is_result": true
+ }
+ }
+ },
+ {
+ "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "type": "ai-chat-node",
+ "x": 2160,
+ "y": 3970,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 763,
+ "stepName": "AI Chat1",
+ "node_data": {
+ "prompt": "{{Start.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0,
+ "is_result": true
+ }
+ }
+ }
+ ],
+ "edges": [
+ {
+ "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
+ "type": "app-edge",
+ "sourceNodeId": "start-node",
+ "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "startPoint": {
+ "x": 590,
+ "y": 3660
+ },
+ "endPoint": {
+ "x": 680,
+ "y": 3210
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 590,
+ "y": 3660
+ },
+ {
+ "x": 700,
+ "y": 3660
+ },
+ {
+ "x": 570,
+ "y": 3210
+ },
+ {
+ "x": 680,
+ "y": 3210
+ }
+ ],
+ "sourceAnchorId": "start-node_right",
+ "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
+ },
+ {
+ "id": "35cb86dd-f328-429e-a973-12fd7218b696",
+ "type": "app-edge",
+ "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "startPoint": {
+ "x": 1000,
+ "y": 3210
+ },
+ "endPoint": {
+ "x": 1200,
+ "y": 3210
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1000,
+ "y": 3210
+ },
+ {
+ "x": 1110,
+ "y": 3210
+ },
+ {
+ "x": 1090,
+ "y": 3210
+ },
+ {
+ "x": 1200,
+ "y": 3210
+ }
+ ],
+ "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
+ "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
+ },
+ {
+ "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "startPoint": {
+ "x": 1780,
+ "y": 3073.775
+ },
+ "endPoint": {
+ "x": 2010,
+ "y": 2480
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3073.775
+ },
+ {
+ "x": 1890,
+ "y": 3073.775
+ },
+ {
+ "x": 1900,
+ "y": 2480
+ },
+ {
+ "x": 2010,
+ "y": 2480
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
+ "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
+ },
+ {
+ "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "startPoint": {
+ "x": 1780,
+ "y": 3203
+ },
+ "endPoint": {
+ "x": 2000,
+ "y": 3200
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3203
+ },
+ {
+ "x": 1890,
+ "y": 3203
+ },
+ {
+ "x": 1890,
+ "y": 3200
+ },
+ {
+ "x": 2000,
+ "y": 3200
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
+ "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
+ },
+ {
+ "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "startPoint": {
+ "x": 1780,
+ "y": 3293.6124999999997
+ },
+ "endPoint": {
+ "x": 2000,
+ "y": 3970
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3293.6124999999997
+ },
+ {
+ "x": 1890,
+ "y": 3293.6124999999997
+ },
+ {
+ "x": 1890,
+ "y": 3970
+ },
+ {
+ "x": 2000,
+ "y": 3970
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
+ "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/apps/application/flow/default_workflow_zh.json b/apps/application/flow/default_workflow_zh.json
new file mode 100644
index 000000000..48ac23c4d
--- /dev/null
+++ b/apps/application/flow/default_workflow_zh.json
@@ -0,0 +1,451 @@
+{
+ "nodes": [
+ {
+ "id": "base-node",
+ "type": "base-node",
+ "x": 360,
+ "y": 2810,
+ "properties": {
+ "config": {
+
+ },
+ "height": 825.6,
+ "stepName": "基本信息",
+ "node_data": {
+ "desc": "",
+ "name": "maxkbapplication",
+ "prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?"
+ },
+ "input_field_list": [
+
+ ]
+ }
+ },
+ {
+ "id": "start-node",
+ "type": "start-node",
+ "x": 430,
+ "y": 3660,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "globalFields": [
+ {
+ "label": "当前时间",
+ "value": "time"
+ }
+ ]
+ },
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "height": 276,
+ "stepName": "开始",
+ "globalFields": [
+ {
+ "label": "当前时间",
+ "value": "time"
+ }
+ ]
+ }
+ },
+ {
+ "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "type": "search-dataset-node",
+ "x": 840,
+ "y": 3210,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "检索结果的分段列表",
+ "value": "paragraph_list"
+ },
+ {
+ "label": "满足直接回答的分段列表",
+ "value": "is_hit_handling_method_list"
+ },
+ {
+ "label": "检索结果",
+ "value": "data"
+ },
+ {
+ "label": "满足直接回答的分段内容",
+ "value": "directly_return"
+ }
+ ]
+ },
+ "height": 794,
+ "stepName": "知识库检索",
+ "node_data": {
+ "dataset_id_list": [
+
+ ],
+ "dataset_setting": {
+ "top_n": 3,
+ "similarity": 0.6,
+ "search_mode": "embedding",
+ "max_paragraph_char_number": 5000
+ },
+ "question_reference_address": [
+ "start-node",
+ "question"
+ ],
+ "source_dataset_id_list": [
+
+ ]
+ }
+ }
+ },
+ {
+ "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "type": "condition-node",
+ "x": 1490,
+ "y": 3210,
+ "properties": {
+ "width": 600,
+ "config": {
+ "fields": [
+ {
+ "label": "分支名称",
+ "value": "branch_name"
+ }
+ ]
+ },
+ "height": 543.675,
+ "stepName": "判断器",
+ "node_data": {
+ "branch": [
+ {
+ "id": "1009",
+ "type": "IF",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "is_hit_handling_method_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "4908",
+ "type": "ELSE IF 1",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "paragraph_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "161",
+ "type": "ELSE",
+ "condition": "and",
+ "conditions": [
+
+ ]
+ }
+ ]
+ },
+ "branch_condition_list": [
+ {
+ "index": 0,
+ "height": 121.225,
+ "id": "1009"
+ },
+ {
+ "index": 1,
+ "height": 121.225,
+ "id": "4908"
+ },
+ {
+ "index": 2,
+ "height": 44,
+ "id": "161"
+ }
+ ]
+ }
+ },
+ {
+ "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "type": "reply-node",
+ "x": 2170,
+ "y": 2480,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 378,
+ "stepName": "指定回复",
+ "node_data": {
+ "fields": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "directly_return"
+ ],
+ "content": "",
+ "reply_type": "referencing",
+ "is_result": true
+ }
+ }
+ },
+ {
+ "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "type": "ai-chat-node",
+ "x": 2160,
+ "y": 3200,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 763,
+ "stepName": "AI 对话",
+ "node_data": {
+ "prompt": "已知信息:\n{{知识库检索.data}}\n问题:\n{{开始.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0,
+ "is_result": true
+ }
+ }
+ },
+ {
+ "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "type": "ai-chat-node",
+ "x": 2160,
+ "y": 3970,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 763,
+ "stepName": "AI 对话1",
+ "node_data": {
+ "prompt": "{{开始.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0,
+ "is_result": true
+ }
+ }
+ }
+ ],
+ "edges": [
+ {
+ "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
+ "type": "app-edge",
+ "sourceNodeId": "start-node",
+ "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "startPoint": {
+ "x": 590,
+ "y": 3660
+ },
+ "endPoint": {
+ "x": 680,
+ "y": 3210
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 590,
+ "y": 3660
+ },
+ {
+ "x": 700,
+ "y": 3660
+ },
+ {
+ "x": 570,
+ "y": 3210
+ },
+ {
+ "x": 680,
+ "y": 3210
+ }
+ ],
+ "sourceAnchorId": "start-node_right",
+ "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
+ },
+ {
+ "id": "35cb86dd-f328-429e-a973-12fd7218b696",
+ "type": "app-edge",
+ "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "startPoint": {
+ "x": 1000,
+ "y": 3210
+ },
+ "endPoint": {
+ "x": 1200,
+ "y": 3210
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1000,
+ "y": 3210
+ },
+ {
+ "x": 1110,
+ "y": 3210
+ },
+ {
+ "x": 1090,
+ "y": 3210
+ },
+ {
+ "x": 1200,
+ "y": 3210
+ }
+ ],
+ "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
+ "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
+ },
+ {
+ "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "startPoint": {
+ "x": 1780,
+ "y": 3073.775
+ },
+ "endPoint": {
+ "x": 2010,
+ "y": 2480
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3073.775
+ },
+ {
+ "x": 1890,
+ "y": 3073.775
+ },
+ {
+ "x": 1900,
+ "y": 2480
+ },
+ {
+ "x": 2010,
+ "y": 2480
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
+ "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
+ },
+ {
+ "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "startPoint": {
+ "x": 1780,
+ "y": 3203
+ },
+ "endPoint": {
+ "x": 2000,
+ "y": 3200
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3203
+ },
+ {
+ "x": 1890,
+ "y": 3203
+ },
+ {
+ "x": 1890,
+ "y": 3200
+ },
+ {
+ "x": 2000,
+ "y": 3200
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
+ "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
+ },
+ {
+ "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "startPoint": {
+ "x": 1780,
+ "y": 3293.6124999999997
+ },
+ "endPoint": {
+ "x": 2000,
+ "y": 3970
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3293.6124999999997
+ },
+ {
+ "x": 1890,
+ "y": 3293.6124999999997
+ },
+ {
+ "x": 1890,
+ "y": 3970
+ },
+ {
+ "x": 2000,
+ "y": 3970
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
+ "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/apps/application/flow/default_workflow_zh_Hant.json b/apps/application/flow/default_workflow_zh_Hant.json
new file mode 100644
index 000000000..b06301533
--- /dev/null
+++ b/apps/application/flow/default_workflow_zh_Hant.json
@@ -0,0 +1,451 @@
+{
+ "nodes": [
+ {
+ "id": "base-node",
+ "type": "base-node",
+ "x": 360,
+ "y": 2810,
+ "properties": {
+ "config": {
+
+ },
+ "height": 825.6,
+ "stepName": "基本資訊",
+ "node_data": {
+ "desc": "",
+ "name": "maxkbapplication",
+ "prologue": "您好,我是MaxKB小助手,您可以向我提出MaxKB使用問題。\n- MaxKB主要功能有什麼?\n- MaxKB支持哪些大語言模型?\n- MaxKB支持哪些文檔類型?"
+ },
+ "input_field_list": [
+
+ ]
+ }
+ },
+ {
+ "id": "start-node",
+ "type": "start-node",
+ "x": 430,
+ "y": 3660,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "globalFields": [
+ {
+ "label": "当前时间",
+ "value": "time"
+ }
+ ]
+ },
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "height": 276,
+ "stepName": "開始",
+ "globalFields": [
+ {
+ "label": "当前时间",
+ "value": "time"
+ }
+ ]
+ }
+ },
+ {
+ "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "type": "search-dataset-node",
+ "x": 840,
+ "y": 3210,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "检索结果的分段列表",
+ "value": "paragraph_list"
+ },
+ {
+ "label": "满足直接回答的分段列表",
+ "value": "is_hit_handling_method_list"
+ },
+ {
+ "label": "检索结果",
+ "value": "data"
+ },
+ {
+ "label": "满足直接回答的分段内容",
+ "value": "directly_return"
+ }
+ ]
+ },
+ "height": 794,
+ "stepName": "知識庫檢索",
+ "node_data": {
+ "dataset_id_list": [
+
+ ],
+ "dataset_setting": {
+ "top_n": 3,
+ "similarity": 0.6,
+ "search_mode": "embedding",
+ "max_paragraph_char_number": 5000
+ },
+ "question_reference_address": [
+ "start-node",
+ "question"
+ ],
+ "source_dataset_id_list": [
+
+ ]
+ }
+ }
+ },
+ {
+ "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "type": "condition-node",
+ "x": 1490,
+ "y": 3210,
+ "properties": {
+ "width": 600,
+ "config": {
+ "fields": [
+ {
+ "label": "分支名称",
+ "value": "branch_name"
+ }
+ ]
+ },
+ "height": 543.675,
+ "stepName": "判斷器",
+ "node_data": {
+ "branch": [
+ {
+ "id": "1009",
+ "type": "IF",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "is_hit_handling_method_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "4908",
+ "type": "ELSE IF 1",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "paragraph_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "161",
+ "type": "ELSE",
+ "condition": "and",
+ "conditions": [
+
+ ]
+ }
+ ]
+ },
+ "branch_condition_list": [
+ {
+ "index": 0,
+ "height": 121.225,
+ "id": "1009"
+ },
+ {
+ "index": 1,
+ "height": 121.225,
+ "id": "4908"
+ },
+ {
+ "index": 2,
+ "height": 44,
+ "id": "161"
+ }
+ ]
+ }
+ },
+ {
+ "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "type": "reply-node",
+ "x": 2170,
+ "y": 2480,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 378,
+ "stepName": "指定回覆",
+ "node_data": {
+ "fields": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "directly_return"
+ ],
+ "content": "",
+ "reply_type": "referencing",
+ "is_result": true
+ }
+ }
+ },
+ {
+ "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "type": "ai-chat-node",
+ "x": 2160,
+ "y": 3200,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 763,
+ "stepName": "AI 對話",
+ "node_data": {
+ "prompt": "已知資訊:\n{{知識庫檢索.data}}\n問題:\n{{開始.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0,
+ "is_result": true
+ }
+ }
+ },
+ {
+ "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "type": "ai-chat-node",
+ "x": 2160,
+ "y": 3970,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 763,
+ "stepName": "AI 對話1",
+ "node_data": {
+ "prompt": "{{開始.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0,
+ "is_result": true
+ }
+ }
+ }
+ ],
+ "edges": [
+ {
+ "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
+ "type": "app-edge",
+ "sourceNodeId": "start-node",
+ "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "startPoint": {
+ "x": 590,
+ "y": 3660
+ },
+ "endPoint": {
+ "x": 680,
+ "y": 3210
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 590,
+ "y": 3660
+ },
+ {
+ "x": 700,
+ "y": 3660
+ },
+ {
+ "x": 570,
+ "y": 3210
+ },
+ {
+ "x": 680,
+ "y": 3210
+ }
+ ],
+ "sourceAnchorId": "start-node_right",
+ "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
+ },
+ {
+ "id": "35cb86dd-f328-429e-a973-12fd7218b696",
+ "type": "app-edge",
+ "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "startPoint": {
+ "x": 1000,
+ "y": 3210
+ },
+ "endPoint": {
+ "x": 1200,
+ "y": 3210
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1000,
+ "y": 3210
+ },
+ {
+ "x": 1110,
+ "y": 3210
+ },
+ {
+ "x": 1090,
+ "y": 3210
+ },
+ {
+ "x": 1200,
+ "y": 3210
+ }
+ ],
+ "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
+ "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
+ },
+ {
+ "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "startPoint": {
+ "x": 1780,
+ "y": 3073.775
+ },
+ "endPoint": {
+ "x": 2010,
+ "y": 2480
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3073.775
+ },
+ {
+ "x": 1890,
+ "y": 3073.775
+ },
+ {
+ "x": 1900,
+ "y": 2480
+ },
+ {
+ "x": 2010,
+ "y": 2480
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
+ "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
+ },
+ {
+ "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "startPoint": {
+ "x": 1780,
+ "y": 3203
+ },
+ "endPoint": {
+ "x": 2000,
+ "y": 3200
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3203
+ },
+ {
+ "x": 1890,
+ "y": 3203
+ },
+ {
+ "x": 1890,
+ "y": 3200
+ },
+ {
+ "x": 2000,
+ "y": 3200
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
+ "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
+ },
+ {
+ "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "startPoint": {
+ "x": 1780,
+ "y": 3293.6124999999997
+ },
+ "endPoint": {
+ "x": 2000,
+ "y": 3970
+ },
+ "properties": {
+
+ },
+ "pointsList": [
+ {
+ "x": 1780,
+ "y": 3293.6124999999997
+ },
+ {
+ "x": 1890,
+ "y": 3293.6124999999997
+ },
+ {
+ "x": 1890,
+ "y": 3970
+ },
+ {
+ "x": 2000,
+ "y": 3970
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
+ "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py
new file mode 100644
index 000000000..fcead7a40
--- /dev/null
+++ b/apps/application/flow/i_step_node.py
@@ -0,0 +1,256 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_step_node.py
+ @date:2024/6/3 14:57
+ @desc:
+"""
+import time
+import uuid
+from abc import abstractmethod
+from hashlib import sha1
+from typing import Type, Dict, List
+
+from django.core import cache
+from django.db.models import QuerySet
+from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
+from rest_framework.exceptions import ValidationError, ErrorDetail
+
+from application.flow.common import Answer, NodeChunk
+from application.models import ChatRecord
+from application.models.api_key_model import ApplicationPublicAccessClient
+from common.constants.authentication_type import AuthenticationType
+from common.field.common import InstanceField
+from common.util.field_message import ErrMessage
+
+chat_cache = cache.caches['chat_cache']
+
+
+def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
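+    # Generator: copy each step variable into the node context; when the node is a
+    # result node, stream the answer out and record it on the node, then promote
+    # global variables to the workflow context and stamp the run time.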
+ if step_variable is not None:
+ for key in step_variable:
+ node.context[key] = step_variable[key]
+ if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable:
+ answer = step_variable['answer']
+ yield answer
+ node.answer_text = answer
+ if global_variable is not None:
+ for key in global_variable:
+ workflow.context[key] = global_variable[key]
+ node.context['run_time'] = time.time() - node.context['start_time']
+
+
+def is_interrupt(node, step_variable: Dict, global_variable: Dict):
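+    # Only an unsubmitted form node interrupts the run, so the user can fill it in.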
+ return node.type == 'form-node' and not node.context.get('is_submit', False)
+
+
+class WorkFlowPostHandler:
+ def __init__(self, chat_info, client_id, client_type):
+ self.chat_info = chat_info
+ self.client_id = client_id
+ self.client_type = client_type
+
+ def handler(self, chat_id,
+ chat_record_id,
+ answer,
+ workflow):
+ question = workflow.params['question']
+ details = workflow.get_runtime_details()
+ message_tokens = sum([row.get('message_tokens') for row in details.values() if
+ 'message_tokens' in row and row.get('message_tokens') is not None])
+ answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
+ 'answer_tokens' in row and row.get('answer_tokens') is not None])
+ answer_text_list = workflow.get_answer_text_list()
+ answer_text = '\n\n'.join(
+ '\n\n'.join([a.get('content') for a in answer]) for answer in
+ answer_text_list)
+ if workflow.chat_record is not None:
+ chat_record = workflow.chat_record
+ chat_record.answer_text = answer_text
+ chat_record.details = details
+ chat_record.message_tokens = message_tokens
+ chat_record.answer_tokens = answer_tokens
+ chat_record.answer_text_list = answer_text_list
+ chat_record.run_time = time.time() - workflow.context['start_time']
+ else:
+ chat_record = ChatRecord(id=chat_record_id,
+ chat_id=chat_id,
+ problem_text=question,
+ answer_text=answer_text,
+ details=details,
+ message_tokens=message_tokens,
+ answer_tokens=answer_tokens,
+ answer_text_list=answer_text_list,
+ run_time=time.time() - workflow.context['start_time'],
+ index=0)
+ asker = workflow.context.get('asker', None)
+ self.chat_info.append_chat_record(chat_record, self.client_id, asker)
+        # Refresh the cached chat info so subsequent requests see the new record
+ chat_cache.set(chat_id,
+ self.chat_info, timeout=60 * 30)
+ if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
+ application_public_access_client = (QuerySet(ApplicationPublicAccessClient)
+ .filter(client_id=self.client_id,
+ application_id=self.chat_info.application.id).first())
+ if application_public_access_client is not None:
+ application_public_access_client.access_num = application_public_access_client.access_num + 1
+ application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
+ application_public_access_client.save()
+
+
+class NodeResult:
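+    # Carries the variables produced by one node, plus pluggable strategies for
+    # writing them back into the context and for deciding whether to interrupt.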
+ def __init__(self, node_variable: Dict, workflow_variable: Dict,
+ _write_context=write_context, _is_interrupt=is_interrupt):
+ self._write_context = _write_context
+ self.node_variable = node_variable
+ self.workflow_variable = workflow_variable
+ self._is_interrupt = _is_interrupt
+
+ def write_context(self, node, workflow):
+ return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
+
+ def is_assertion_result(self):
+ return 'branch_id' in self.node_variable
+
+ def is_interrupt_exec(self, current_node):
+ """
+ 是否中断执行
+ @param current_node:
+ @return:
+ """
+ return self._is_interrupt(current_node, self.node_variable, self.workflow_variable)
+
+
+class ReferenceAddressSerializer(serializers.Serializer):
+    node_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Node id")))
+    fields = serializers.ListField(
+        child=serializers.CharField(required=True, error_messages=ErrMessage.char(_("Node field"))), required=True,
+        error_messages=ErrMessage.list(_("Node field list")))
+
+
+class FlowParamsSerializer(serializers.Serializer):
+    # Chat history
+    history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
+                                                error_messages=ErrMessage.list(_("Chat history")))
+
+    question = serializers.CharField(required=True, error_messages=ErrMessage.char(_("User question")))
+
+    chat_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Chat id")))
+
+    chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Chat record id")))
+
+    stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean(_("Streaming output")))
+
+    client_id = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Client id")))
+
+    client_type = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Client type")))
+
+    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User id")))
+    re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean(_("Regenerate answer")))
+
+
+class INode:
+ view_type = 'many_view'
+
+ @abstractmethod
+ def save_context(self, details, workflow_manage):
+ pass
+
+ def get_answer_list(self) -> List[Answer] | None:
+ if self.answer_text is None:
+ return None
+ reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False)
+ return [
+ Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
+ self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]
+
+ def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
+ get_node_params=lambda node: node.properties.get('node_data')):
+        # Context of the current step, used to store this step's runtime information
+ self.status = 200
+ self.err_message = ''
+ self.node = node
+ self.node_params = get_node_params(node)
+ self.workflow_params = workflow_params
+ self.workflow_manage = workflow_manage
+ self.node_params_serializer = None
+ self.flow_params_serializer = None
+ self.context = {}
+ self.answer_text = None
+ self.id = node.id
+ if up_node_id_list is None:
+ up_node_id_list = []
+ self.up_node_id_list = up_node_id_list
+ self.node_chunk = NodeChunk()
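+        # Derive a stable runtime id from the sorted chain of upstream node ids,
+        # so the same node reached along the same path always hashes to the same id.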
+ self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
+ "".join([*sorted(up_node_id_list),
+ node.id]))),
+ "utf-8")).hexdigest()
+
+ def valid_args(self, node_params, flow_params):
+ flow_params_serializer_class = self.get_flow_params_serializer_class()
+ node_params_serializer_class = self.get_node_params_serializer_class()
+ if flow_params_serializer_class is not None and flow_params is not None:
+ self.flow_params_serializer = flow_params_serializer_class(data=flow_params)
+ self.flow_params_serializer.is_valid(raise_exception=True)
+ if node_params_serializer_class is not None:
+ self.node_params_serializer = node_params_serializer_class(data=node_params)
+ self.node_params_serializer.is_valid(raise_exception=True)
+ if self.node.properties.get('status', 200) != 200:
+            raise ValidationError(ErrorDetail(f'Node {self.node.properties.get("stepName")} is unavailable'))
+
+ def get_reference_field(self, fields: List[str]):
+ return self.get_field(self.context, fields)
+
+ @staticmethod
+ def get_field(obj, fields: List[str]):
+ for field in fields:
+ value = obj.get(field)
+ if value is None:
+ return None
+ else:
+ obj = value
+ return obj
+
+ @abstractmethod
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ pass
+
+ def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return FlowParamsSerializer
+
+ def get_write_error_context(self, e):
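+        # Record the failure on the node, then hand back a no-op writer so the
+        # caller keeps a uniform write-context call path.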
+ self.status = 500
+ self.answer_text = str(e)
+ self.err_message = str(e)
+ self.context['run_time'] = time.time() - self.context['start_time']
+
+ def write_error_context(answer, status=200):
+ pass
+
+ return write_error_context
+
+ def run(self) -> NodeResult:
+ """
+ :return: 执行结果
+ """
+ start_time = time.time()
+ self.context['start_time'] = start_time
+ result = self._run()
+ self.context['run_time'] = time.time() - start_time
+ return result
+
+ def _run(self):
+ result = self.execute()
+ return result
+
+ def execute(self, **kwargs) -> NodeResult:
+ pass
+
+ def get_details(self, index: int, **kwargs):
+ """
+ 运行详情
+ :return: 步骤详情
+ """
+ return {}
diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py
new file mode 100644
index 000000000..0ce1d5fed
--- /dev/null
+++ b/apps/application/flow/step_node/__init__.py
@@ -0,0 +1,42 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/7 14:43
+ @desc:
+"""
+from .ai_chat_step_node import *
+from .application_node import BaseApplicationNode
+from .condition_node import *
+from .direct_reply_node import *
+from .form_node import *
+from .function_lib_node import *
+from .function_node import *
+from .question_node import *
+from .reranker_node import *
+
+from .document_extract_node import *
+from .image_understand_step_node import *
+from .image_generate_step_node import *
+
+from .search_dataset_node import *
+from .speech_to_text_step_node import BaseSpeechToTextNode
+from .start_node import *
+from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode
+from .variable_assign_node import BaseVariableAssignNode
+from .mcp_node import BaseMcpNode
+
+node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode,
+ BaseConditionNode, BaseReplyNode,
+ BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
+ BaseDocumentExtractNode,
+ BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
+ BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode]
+
+
+def get_node(node_type):
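+    # Resolve the node implementation class registered for a workflow node type.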
+ find_list = [node for node in node_list if node.type == node_type]
+ if len(find_list) > 0:
+ return find_list[0]
+ return None
diff --git a/apps/application/flow/step_node/ai_chat_step_node/__init__.py b/apps/application/flow/step_node/ai_chat_step_node/__init__.py
new file mode 100644
index 000000000..1929ae2af
--- /dev/null
+++ b/apps/application/flow/step_node/ai_chat_step_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:29
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
new file mode 100644
index 000000000..a83d2ef57
--- /dev/null
+++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
@@ -0,0 +1,58 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_chat_node.py
+ @date:2024/6/4 13:58
+ @desc:
+"""
+from typing import Type
+
+from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+
+
+class ChatNodeSerializer(serializers.Serializer):
+ model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
+ system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+ error_messages=ErrMessage.char(_("Role Setting")))
+ prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
+    # Number of rounds of dialogue history carried as context
+ dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
+ _("Number of multi-round conversations")))
+
+ is_result = serializers.BooleanField(required=False,
+ error_messages=ErrMessage.boolean(_('Whether to return content')))
+
+ model_params_setting = serializers.DictField(required=False,
+ error_messages=ErrMessage.dict(_("Model parameter settings")))
+ model_setting = serializers.DictField(required=False,
+                                          error_messages=ErrMessage.dict(_('Model settings')))
+ dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+ error_messages=ErrMessage.char(_("Context Type")))
+ mcp_enable = serializers.BooleanField(required=False,
+ error_messages=ErrMessage.boolean(_("Whether to enable MCP")))
+ mcp_servers = serializers.JSONField(required=False, error_messages=ErrMessage.list(_("MCP Server")))
+
+
+class IChatNode(INode):
+ type = 'ai-chat-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return ChatNodeSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
+ chat_record_id,
+ model_params_setting=None,
+ dialogue_type=None,
+ model_setting=None,
+ mcp_enable=False,
+ mcp_servers=None,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py b/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py
new file mode 100644
index 000000000..79051a999
--- /dev/null
+++ b/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:34
+ @desc:
+"""
+from .base_chat_node import BaseChatNode
diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
new file mode 100644
index 000000000..8d576d416
--- /dev/null
+++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
@@ -0,0 +1,288 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_chat_node.py
+ @date:2024/6/4 14:30
+ @desc:
+"""
+import asyncio
+import json
+import re
+import time
+from functools import reduce
+from types import AsyncGeneratorType
+from typing import List, Dict
+
+from django.db.models import QuerySet
+from langchain.schema import HumanMessage, SystemMessage
+from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage
+from langchain_mcp_adapters.client import MultiServerMCPClient
+from langgraph.prebuilt import create_react_agent
+
+from application.flow.i_step_node import NodeResult, INode
+from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
+from application.flow.tools import Reasoning
+from setting.models import Model
+from setting.models_provider import get_model_credential
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+tool_message_template = """
+
+
+ Called MCP Tool: %s
+
+
+```json
+%s
+```
+
+
+"""
+
+
+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
+ reasoning_content: str):
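+    # Count tokens for the request and the answer, then record the conversation
+    # state on the node context; mark the answer on the node if it is a result node.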
+ chat_model = node_variable.get('chat_model')
+ message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+ answer_tokens = chat_model.get_num_tokens(answer)
+ node.context['message_tokens'] = message_tokens
+ node.context['answer_tokens'] = answer_tokens
+ node.context['answer'] = answer
+ node.context['history_message'] = node_variable['history_message']
+ node.context['question'] = node_variable['question']
+ node.context['run_time'] = time.time() - node.context['start_time']
+ node.context['reasoning_content'] = reasoning_content
+ if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
+ node.answer_text = answer
+
+
+def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据 (流式)
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result')
+ answer = ''
+ reasoning_content = ''
+ model_setting = node.context.get('model_setting',
+ {'reasoning_content_enable': False, 'reasoning_content_end': '',
+ 'reasoning_content_start': ''})
+ reasoning = Reasoning(model_setting.get('reasoning_content_start', ''),
+ model_setting.get('reasoning_content_end', ''))
+ response_reasoning_content = False
+
+ for chunk in response:
+ reasoning_chunk = reasoning.get_reasoning_content(chunk)
+ content_chunk = reasoning_chunk.get('content')
+ if 'reasoning_content' in chunk.additional_kwargs:
+ response_reasoning_content = True
+ reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
+ else:
+ reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
+ answer += content_chunk
+ if reasoning_content_chunk is None:
+ reasoning_content_chunk = ''
+ reasoning_content += reasoning_content_chunk
+ yield {'content': content_chunk,
+ 'reasoning_content': reasoning_content_chunk if model_setting.get('reasoning_content_enable',
+ False) else ''}
+
+ reasoning_chunk = reasoning.get_end_reasoning_content()
+ answer += reasoning_chunk.get('content')
+ reasoning_content_chunk = ""
+ if not response_reasoning_content:
+ reasoning_content_chunk = reasoning_chunk.get(
+ 'reasoning_content')
+ yield {'content': reasoning_chunk.get('content'),
+ 'reasoning_content': reasoning_content_chunk if model_setting.get('reasoning_content_enable',
+ False) else ''}
+ _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
+
+
+async def _yield_mcp_response(chat_model, message_list, mcp_servers):
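+    # Run a ReAct agent over the MCP tools; tool results are rendered through
+    # tool_message_template before being yielded alongside the AI message chunks.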
+ async with MultiServerMCPClient(json.loads(mcp_servers)) as client:
+ agent = create_react_agent(chat_model, client.get_tools())
+ response = agent.astream({"messages": message_list}, stream_mode='messages')
+ async for chunk in response:
+ if isinstance(chunk[0], ToolMessage):
+ content = tool_message_template % (chunk[0].name, chunk[0].content)
+ chunk[0].content = content
+ yield chunk[0]
+ if isinstance(chunk[0], AIMessageChunk):
+ yield chunk[0]
+
+
+def mcp_response_generator(chat_model, message_list, mcp_servers):
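+    # Bridge the async MCP stream into a synchronous generator by driving a
+    # dedicated event loop one chunk at a time.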
+ loop = asyncio.new_event_loop()
+ try:
+ async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
+ while True:
+ try:
+ chunk = loop.run_until_complete(anext_async(async_gen))
+ yield chunk
+ except StopAsyncIteration:
+ break
+ except Exception as e:
+ print(f'exception: {e}')
+ finally:
+ loop.close()
+
+
+async def anext_async(agen):
+ return await agen.__anext__()
+
+
+def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点实例对象
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result')
+ model_setting = node.context.get('model_setting',
+ {'reasoning_content_enable': False, 'reasoning_content_end': '',
+ 'reasoning_content_start': ''})
+ reasoning = Reasoning(model_setting.get('reasoning_content_start'), model_setting.get('reasoning_content_end'))
+ reasoning_result = reasoning.get_reasoning_content(response)
+ reasoning_result_end = reasoning.get_end_reasoning_content()
+ content = reasoning_result.get('content') + reasoning_result_end.get('content')
+ if 'reasoning_content' in response.response_metadata:
+ reasoning_content = response.response_metadata.get('reasoning_content', '')
+ else:
+ reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
+ _write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
+
+
+def get_default_model_params_setting(model_id):
+ model = QuerySet(Model).filter(id=model_id).first()
+ credential = get_model_credential(model.provider, model.model_type, model.model_name)
+ model_params_setting = credential.get_model_params_setting_form(
+ model.model_name).get_default_form_data()
+ return model_params_setting
+
+
+def get_node_message(chat_record, runtime_node_id):
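+    # NODE-scoped history: only the question/answer pair recorded by this runtime node.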
+ node_details = chat_record.get_node_details_runtime_node_id(runtime_node_id)
+ if node_details is None:
+ return []
+ return [HumanMessage(node_details.get('question')), AIMessage(node_details.get('answer'))]
+
+
+def get_workflow_message(chat_record):
+ return [chat_record.get_human_message(), chat_record.get_ai_message()]
+
+
+def get_message(chat_record, dialogue_type, runtime_node_id):
+ return get_node_message(chat_record, runtime_node_id) if dialogue_type == 'NODE' else get_workflow_message(
+ chat_record)
+
+
+class BaseChatNode(IChatNode):
+ def save_context(self, details, workflow_manage):
+ self.context['answer'] = details.get('answer')
+ self.context['question'] = details.get('question')
+ self.context['reasoning_content'] = details.get('reasoning_content')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('answer')
+
+ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
+ model_params_setting=None,
+ dialogue_type=None,
+ model_setting=None,
+ mcp_enable=False,
+ mcp_servers=None,
+ **kwargs) -> NodeResult:
+ if dialogue_type is None:
+ dialogue_type = 'WORKFLOW'
+
+ if model_params_setting is None:
+ model_params_setting = get_default_model_params_setting(model_id)
+ if model_setting is None:
+ model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '',
+ 'reasoning_content_start': ''}
+ self.context['model_setting'] = model_setting
+ chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
+ **model_params_setting)
+ history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
+ self.runtime_node_id)
+ self.context['history_message'] = history_message
+ question = self.generate_prompt_question(prompt)
+ self.context['question'] = question.content
+ system = self.workflow_manage.generate_prompt(system)
+ self.context['system'] = system
+ message_list = self.generate_message_list(system, prompt, history_message)
+ self.context['message_list'] = message_list
+
+ if mcp_enable and mcp_servers is not None:
+ r = mcp_response_generator(chat_model, message_list, mcp_servers)
+ return NodeResult(
+ {'result': r, 'chat_model': chat_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context_stream)
+
+ if stream:
+ r = chat_model.stream(message_list)
+ return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context_stream)
+ else:
+ r = chat_model.invoke(message_list)
+ return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context)
+
+ @staticmethod
+ def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
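+        # Flatten the last `dialogue_number` records into a message list and
+        # strip any rendered form blocks from the content.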
+ start_index = len(history_chat_record) - dialogue_number
+ history_message = reduce(lambda x, y: [*x, *y], [
+ get_message(history_chat_record[index], dialogue_type, runtime_node_id)
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
+ for message in history_message:
+ if isinstance(message.content, str):
+                message.content = re.sub('<form_rander>[\d\D]*?<\/form_rander>', '', message.content)
+ return history_message
+
+ def generate_prompt_question(self, prompt):
+ return HumanMessage(self.workflow_manage.generate_prompt(prompt))
+
+ def generate_message_list(self, system: str, prompt: str, history_message):
+ if system is not None and len(system) > 0:
+ return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
+ HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+ else:
+ return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+
+ @staticmethod
+ def reset_message_list(message_list: List[BaseMessage], answer_text):
+ result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+ message
+ in
+ message_list]
+ result.append({'role': 'ai', 'content': answer_text})
+ return result
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'system': self.context.get('system'),
+ 'history_message': [{'content': message.content, 'role': message.type} for message in
+ (self.context.get('history_message') if self.context.get(
+ 'history_message') is not None else [])],
+ 'question': self.context.get('question'),
+ 'answer': self.context.get('answer'),
+ 'reasoning_content': self.context.get('reasoning_content'),
+ 'type': self.node.type,
+ 'message_tokens': self.context.get('message_tokens'),
+ 'answer_tokens': self.context.get('answer_tokens'),
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
diff --git a/apps/application/flow/step_node/application_node/__init__.py b/apps/application/flow/step_node/application_node/__init__.py
new file mode 100644
index 000000000..d1ea91ca7
--- /dev/null
+++ b/apps/application/flow/step_node/application_node/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+from .impl import *
diff --git a/apps/application/flow/step_node/application_node/i_application_node.py b/apps/application/flow/step_node/application_node/i_application_node.py
new file mode 100644
index 000000000..6394fa49c
--- /dev/null
+++ b/apps/application/flow/step_node/application_node/i_application_node.py
@@ -0,0 +1,86 @@
+# coding=utf-8
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+
+from django.utils.translation import gettext_lazy as _
+
+
+class ApplicationNodeSerializer(serializers.Serializer):
+ application_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Application ID")))
+ question_reference_address = serializers.ListField(required=True,
+ error_messages=ErrMessage.list(_("User Questions")))
+ api_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("API Input Fields")))
+    user_input_field_list = serializers.ListField(required=False,
+                                                  error_messages=ErrMessage.list(_("User Input Fields")))
+ image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))
+ document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
+ audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Audio")))
+ child_node = serializers.DictField(required=False, allow_null=True,
+ error_messages=ErrMessage.dict(_("Child Nodes")))
+ node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
+
+
+class IApplicationNode(INode):
+ type = 'application-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return ApplicationNodeSerializer
+
+ def _run(self):
+ question = self.workflow_manage.get_reference_field(
+ self.node_params_serializer.data.get('question_reference_address')[0],
+ self.node_params_serializer.data.get('question_reference_address')[1:])
+ kwargs = {}
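+        # Resolve each configured input field: the first element of `value` is the
+        # source node id, the rest is the field path inside that node.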
+ for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []):
+ value = api_input_field.get('value', [''])[0] if api_input_field.get('value') else ''
+ kwargs[api_input_field['variable']] = self.workflow_manage.get_reference_field(value,
+ api_input_field['value'][
+ 1:]) if value != '' else ''
+
+ for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []):
+ value = user_input_field.get('value', [''])[0] if user_input_field.get('value') else ''
+ kwargs[user_input_field['field']] = self.workflow_manage.get_reference_field(value,
+ user_input_field['value'][
+ 1:]) if value != '' else ''
+        # Check whether the optional attachment lists are present and resolve them
+ app_document_list = self.node_params_serializer.data.get('document_list', [])
+ if app_document_list and len(app_document_list) > 0:
+ app_document_list = self.workflow_manage.get_reference_field(
+ app_document_list[0],
+ app_document_list[1:])
+ for document in app_document_list:
+ if 'file_id' not in document:
+ raise ValueError(
+ _("Parameter value error: The uploaded document lacks file_id, and the document upload fails"))
+ app_image_list = self.node_params_serializer.data.get('image_list', [])
+ if app_image_list and len(app_image_list) > 0:
+ app_image_list = self.workflow_manage.get_reference_field(
+ app_image_list[0],
+ app_image_list[1:])
+ for image in app_image_list:
+ if 'file_id' not in image:
+ raise ValueError(
+ _("Parameter value error: The uploaded image lacks file_id, and the image upload fails"))
+
+ app_audio_list = self.node_params_serializer.data.get('audio_list', [])
+ if app_audio_list and len(app_audio_list) > 0:
+ app_audio_list = self.workflow_manage.get_reference_field(
+ app_audio_list[0],
+ app_audio_list[1:])
+ for audio in app_audio_list:
+ if 'file_id' not in audio:
+ raise ValueError(
+ _("Parameter value error: The uploaded audio lacks file_id, and the audio upload fails."))
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
+ app_document_list=app_document_list, app_image_list=app_image_list,
+ app_audio_list=app_audio_list,
+ message=str(question), **kwargs)
+
+ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
+ app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/application_node/impl/__init__.py b/apps/application/flow/step_node/application_node/impl/__init__.py
new file mode 100644
index 000000000..e31a8d885
--- /dev/null
+++ b/apps/application/flow/step_node/application_node/impl/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+from .base_application_node import BaseApplicationNode
diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py
new file mode 100644
index 000000000..95445f456
--- /dev/null
+++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py
@@ -0,0 +1,267 @@
+# coding=utf-8
+import json
+import re
+import time
+import uuid
+from typing import Dict, List
+
+from application.flow.common import Answer
+from application.flow.i_step_node import NodeResult, INode
+from application.flow.step_node.application_node.i_application_node import IApplicationNode
+from application.models import Chat
+
+
+def string_to_uuid(input_str):
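+    # Deterministic UUIDv5, so the same input string always maps to the same id.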
+ return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str))
+
+
+def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict):
+ return node_variable.get('is_interrupt_exec', False)
+
+
+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
+ reasoning_content: str):
+ result = node_variable.get('result')
+ node.context['application_node_dict'] = node_variable.get('application_node_dict')
+ node.context['node_dict'] = node_variable.get('node_dict', {})
+ node.context['is_interrupt_exec'] = node_variable.get('is_interrupt_exec')
+ node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0)
+ node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
+ node.context['answer'] = answer
+ node.context['result'] = answer
+ node.context['reasoning_content'] = reasoning_content
+ node.context['question'] = node_variable['question']
+ node.context['run_time'] = time.time() - node.context['start_time']
+ if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
+ node.answer_text = answer
+
+
+def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据 (流式)
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result')
+ answer = ''
+ reasoning_content = ''
+ usage = {}
+ node_child_node = {}
+ application_node_dict = node.context.get('application_node_dict', {})
+ is_interrupt_exec = False
+ for chunk in response:
+        # Each SSE frame starts with "data: "; strip the prefix and parse the JSON payload
+ response_content = chunk.decode('utf-8')[6:]
+ response_content = json.loads(response_content)
+ content = response_content.get('content', '')
+ runtime_node_id = response_content.get('runtime_node_id', '')
+ chat_record_id = response_content.get('chat_record_id', '')
+ child_node = response_content.get('child_node')
+ view_type = response_content.get('view_type')
+ node_type = response_content.get('node_type')
+ real_node_id = response_content.get('real_node_id')
+ node_is_end = response_content.get('node_is_end', False)
+ _reasoning_content = response_content.get('reasoning_content', '')
+ if node_type == 'form-node':
+ is_interrupt_exec = True
+ answer += content
+ reasoning_content += _reasoning_content
+ node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
+ 'child_node': child_node}
+
+ if real_node_id is not None:
+ application_node = application_node_dict.get(real_node_id, None)
+ if application_node is None:
+ application_node_dict[real_node_id] = {'content': content,
+ 'runtime_node_id': runtime_node_id,
+ 'chat_record_id': chat_record_id,
+ 'child_node': child_node,
+ 'index': len(application_node_dict),
+ 'view_type': view_type,
+ 'reasoning_content': _reasoning_content}
+ else:
+ application_node['content'] += content
+ application_node['reasoning_content'] += _reasoning_content
+
+ yield {'content': content,
+ 'node_type': node_type,
+ 'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
+ 'reasoning_content': _reasoning_content,
+ 'child_node': child_node,
+ 'real_node_id': real_node_id,
+ 'node_is_end': node_is_end,
+ 'view_type': view_type}
+ usage = response_content.get('usage', {})
+ node_variable['result'] = {'usage': usage}
+ node_variable['is_interrupt_exec'] = is_interrupt_exec
+ node_variable['child_node'] = node_child_node
+ node_variable['application_node_dict'] = application_node_dict
+ _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
+
+
+def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点实例对象
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result', {}).get('data', {})
+ node_variable['result'] = {'usage': {'completion_tokens': response.get('completion_tokens'),
+ 'prompt_tokens': response.get('prompt_tokens')}}
+    answer = response.get('content', '') or "Sorry, no relevant content was found. Please rephrase your question or provide more information."
+ reasoning_content = response.get('reasoning_content', '')
+ answer_list = response.get('answer_list', [])
+    node_variable['application_node_dict'] = {answer.get('real_node_id'): {**answer, 'index': index}
+                                              for index, answer in enumerate(answer_list)}
+ _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
+
+
+def reset_application_node_dict(application_node_dict, runtime_node_id, node_data):
+ try:
+ if application_node_dict is None:
+ return
+ for key in application_node_dict:
+ application_node = application_node_dict[key]
+ if application_node.get('runtime_node_id') == runtime_node_id:
+ content: str = application_node.get('content')
+                match = re.search('<form_rander>.*?</form_rander>', content)
+                if match:
+                    form_setting_str = match.group().replace('<form_rander>', '').replace('</form_rander>', '')
+                    form_setting = json.loads(form_setting_str)
+                    form_setting['is_submit'] = True
+                    form_setting['form_data'] = node_data
+                    value = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
+                    res = re.sub('<form_rander>.*?</form_rander>',
+                                 '${value}', content)
+                    application_node['content'] = res.replace('${value}', value)
+ except Exception as e:
+ pass
+
+
+class BaseApplicationNode(IApplicationNode):
+ def get_answer_list(self) -> List[Answer] | None:
+ if self.answer_text is None:
+ return None
+ application_node_dict = self.context.get('application_node_dict')
+ if application_node_dict is None or len(application_node_dict) == 0:
+ return [
+ Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'],
+ self.context.get('child_node'), self.runtime_node_id, '')]
+ else:
+            return [Answer(n.get('content'), n.get('view_type'), self.runtime_node_id,
+                           self.workflow_params['chat_record_id'],
+                           {'runtime_node_id': n.get('runtime_node_id'),
+                            'chat_record_id': n.get('chat_record_id'),
+                            'child_node': n.get('child_node')},
+                           n.get('real_node_id'), n.get('reasoning_content', ''))
+                    for n in sorted(application_node_dict.values(), key=lambda item: item.get('index'))]
+
+ def save_context(self, details, workflow_manage):
+ self.context['answer'] = details.get('answer')
+ self.context['result'] = details.get('answer')
+ self.context['question'] = details.get('question')
+ self.context['type'] = details.get('type')
+ self.context['reasoning_content'] = details.get('reasoning_content')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('answer')
+
+ def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
+ app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
+ **kwargs) -> NodeResult:
+ from application.serializers.chat_message_serializers import ChatMessageSerializer
+        # Derive a deterministic chat_id for the embedded application
+ current_chat_id = string_to_uuid(chat_id + application_id)
+ Chat.objects.get_or_create(id=current_chat_id, defaults={
+ 'application_id': application_id,
+ 'abstract': message[0:1024],
+ 'client_id': client_id,
+ })
+ if app_document_list is None:
+ app_document_list = []
+ if app_image_list is None:
+ app_image_list = []
+ if app_audio_list is None:
+ app_audio_list = []
+ runtime_node_id = None
+ record_id = None
+ child_node_value = None
+ if child_node is not None:
+ runtime_node_id = child_node.get('runtime_node_id')
+ record_id = child_node.get('chat_record_id')
+ child_node_value = child_node.get('child_node')
+ application_node_dict = self.context.get('application_node_dict')
+ reset_application_node_dict(application_node_dict, runtime_node_id, node_data)
+
+ response = ChatMessageSerializer(
+ data={'chat_id': current_chat_id, 'message': message,
+ 're_chat': re_chat,
+ 'stream': stream,
+ 'application_id': application_id,
+ 'client_id': client_id,
+ 'client_type': client_type,
+ 'document_list': app_document_list,
+ 'image_list': app_image_list,
+ 'audio_list': app_audio_list,
+ 'runtime_node_id': runtime_node_id,
+ 'chat_record_id': record_id,
+ 'child_node': child_node_value,
+ 'node_data': node_data,
+ 'form_data': kwargs}).chat()
+ if response.status_code == 200:
+ if stream:
+ content_generator = response.streaming_content
+ return NodeResult({'result': content_generator, 'question': message}, {},
+ _write_context=write_context_stream, _is_interrupt=_is_interrupt_exec)
+ else:
+ data = json.loads(response.content)
+ return NodeResult({'result': data, 'question': message}, {},
+ _write_context=write_context, _is_interrupt=_is_interrupt_exec)
+
+ def get_details(self, index: int, **kwargs):
+ global_fields = []
+ for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []):
+ value = api_input_field.get('value', [''])[0] if api_input_field.get('value') else ''
+ global_fields.append({
+ 'label': api_input_field['variable'],
+ 'key': api_input_field['variable'],
+ 'value': self.workflow_manage.get_reference_field(
+ value,
+ api_input_field['value'][1:]
+ ) if value != '' else ''
+ })
+
+ for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []):
+ value = user_input_field.get('value', [''])[0] if user_input_field.get('value') else ''
+ global_fields.append({
+ 'label': user_input_field['label'],
+ 'key': user_input_field['field'],
+ 'value': self.workflow_manage.get_reference_field(
+ value,
+ user_input_field['value'][1:]
+ ) if value != '' else ''
+ })
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ "info": self.node.properties.get('node_data'),
+ 'run_time': self.context.get('run_time'),
+ 'question': self.context.get('question'),
+ 'answer': self.context.get('answer'),
+ 'reasoning_content': self.context.get('reasoning_content'),
+ 'type': self.node.type,
+ 'message_tokens': self.context.get('message_tokens'),
+ 'answer_tokens': self.context.get('answer_tokens'),
+ 'status': self.status,
+ 'err_message': self.err_message,
+ 'global_fields': global_fields,
+ 'document_list': self.workflow_manage.document_list,
+ 'image_list': self.workflow_manage.image_list,
+ 'audio_list': self.workflow_manage.audio_list,
+ 'application_node_dict': self.context.get('application_node_dict')
+ }
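
A note on the embedded-application chat id: `string_to_uuid` derives a deterministic UUIDv5 from `chat_id + application_id`, so `Chat.objects.get_or_create` reuses the same child chat across turns of the parent conversation. A minimal standalone sketch of that property (the ids below are hypothetical):

```python
import uuid

def string_to_uuid(input_str):
    # Same namespace and input always yield the same UUID.
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str))

chat_id = 'parent-chat-1'          # hypothetical parent chat id
application_id = 'embedded-app-1'  # hypothetical embedded application id

first = string_to_uuid(chat_id + application_id)
second = string_to_uuid(chat_id + application_id)
assert first == second  # get_or_create therefore hits the same Chat row every turn
```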
diff --git a/apps/application/flow/step_node/condition_node/__init__.py b/apps/application/flow/step_node/condition_node/__init__.py
new file mode 100644
index 000000000..57638504c
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/7 14:43
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/condition_node/compare/__init__.py b/apps/application/flow/step_node/condition_node/compare/__init__.py
new file mode 100644
index 000000000..c015f6fea
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/__init__.py
@@ -0,0 +1,30 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/7 14:43
+ @desc:
+"""
+
+from .contain_compare import *
+from .equal_compare import *
+from .ge_compare import *
+from .gt_compare import *
+from .is_not_null_compare import *
+from .is_not_true import IsNotTrueCompare
+from .is_null_compare import *
+from .is_true import IsTrueCompare
+from .le_compare import *
+from .len_equal_compare import *
+from .len_ge_compare import *
+from .len_gt_compare import *
+from .len_le_compare import *
+from .len_lt_compare import *
+from .lt_compare import *
+from .not_contain_compare import *
+
+compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
+ LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare(),
+ IsNullCompare(),
+ IsNotNullCompare(), NotContainCompare(), IsTrueCompare(), IsNotTrueCompare()]
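
`compare_handle_list` acts as a small chain of responsibility: the condition node asks each handler whether it `support`s the operator string and delegates to the first match. A minimal self-contained sketch of the same dispatch pattern (standalone stand-ins, not the imports above):

```python
from typing import List

class GTCompare:
    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'gt'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) > float(target_value)
        except Exception:
            return False

handlers = [GTCompare()]

def evaluate(node_id, fields, source_value, compare, target_value):
    # First handler whose support() is truthy wins; unknown operators fall through to None.
    for handler in handlers:
        if handler.support(node_id, fields, source_value, compare, target_value):
            return handler.compare(source_value, compare, target_value)

print(evaluate('n1', ['n1', 'answer'], '3', 'gt', '2'))  # True
```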
diff --git a/apps/application/flow/step_node/condition_node/compare/compare.py b/apps/application/flow/step_node/condition_node/compare/compare.py
new file mode 100644
index 000000000..6cbb4af07
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/compare.py
@@ -0,0 +1,20 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: compare.py
+ @date:2024/6/7 14:37
+ @desc:
+"""
+from abc import abstractmethod
+from typing import List
+
+
+class Compare:
+ @abstractmethod
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ pass
+
+ @abstractmethod
+ def compare(self, source_value, compare, target_value):
+ pass
diff --git a/apps/application/flow/step_node/condition_node/compare/contain_compare.py b/apps/application/flow/step_node/condition_node/compare/contain_compare.py
new file mode 100644
index 000000000..6073131a5
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/contain_compare.py
@@ -0,0 +1,23 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: contain_compare.py
+ @date:2024/6/11 10:02
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class ContainCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'contain':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ if isinstance(source_value, str):
+ return str(target_value) in source_value
+ return any([str(item) == str(target_value) for item in source_value])
diff --git a/apps/application/flow/step_node/condition_node/compare/equal_compare.py b/apps/application/flow/step_node/condition_node/compare/equal_compare.py
new file mode 100644
index 000000000..0061a82f6
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/equal_compare.py
@@ -0,0 +1,21 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: equal_compare.py
+ @date:2024/6/7 14:44
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class EqualCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'eq':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ return str(source_value) == str(target_value)
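
Note that `EqualCompare` compares the string forms of both sides, so `1` and `'1'` count as equal while `1.0` and `'1'` do not; a quick illustration of that behavior:

```python
# EqualCompare reduces 'eq' to a comparison of str() forms.
assert str(1) == str('1')    # 1 and '1' are treated as equal
assert str(1.0) != str('1')  # 1.0 stringifies to '1.0', so it is not equal to '1'
```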
diff --git a/apps/application/flow/step_node/condition_node/compare/ge_compare.py b/apps/application/flow/step_node/condition_node/compare/ge_compare.py
new file mode 100644
index 000000000..d4e22cbd6
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/ge_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: ge_compare.py
+ @date:2024/6/11 9:52
+ @desc: Greater-than-or-equal comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class GECompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'ge':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return float(source_value) >= float(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/gt_compare.py b/apps/application/flow/step_node/condition_node/compare/gt_compare.py
new file mode 100644
index 000000000..80942abb2
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/gt_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: gt_compare.py
+ @date:2024/6/11 9:52
+ @desc: Greater-than comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class GTCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'gt':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return float(source_value) > float(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py b/apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py
new file mode 100644
index 000000000..5dec26713
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py
@@ -0,0 +1,21 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: is_not_null_compare.py
+ @date:2024/6/28 10:45
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare import Compare
+
+
+class IsNotNullCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'is_not_null':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ return source_value is not None and len(source_value) > 0
diff --git a/apps/application/flow/step_node/condition_node/compare/is_not_true.py b/apps/application/flow/step_node/condition_node/compare/is_not_true.py
new file mode 100644
index 000000000..f8a29f5a1
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/is_not_true.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: is_not_true.py
+ @date:2025/4/7 13:44
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare import Compare
+
+
+class IsNotTrueCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'is_not_true':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return source_value is False
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/is_null_compare.py b/apps/application/flow/step_node/condition_node/compare/is_null_compare.py
new file mode 100644
index 000000000..c463f3fda
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/is_null_compare.py
@@ -0,0 +1,21 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: is_null_compare.py
+ @date:2024/6/28 10:45
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare import Compare
+
+
+class IsNullCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'is_null':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ return source_value is None or len(source_value) == 0
diff --git a/apps/application/flow/step_node/condition_node/compare/is_true.py b/apps/application/flow/step_node/condition_node/compare/is_true.py
new file mode 100644
index 000000000..166e0993a
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/is_true.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: is_true.py
+ @date:2025/4/7 13:38
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare import Compare
+
+
+class IsTrueCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'is_true':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return source_value is True
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/le_compare.py b/apps/application/flow/step_node/condition_node/compare/le_compare.py
new file mode 100644
index 000000000..77a0bca0f
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/le_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: le_compare.py
+ @date:2024/6/11 9:52
+ @desc: Less-than-or-equal comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LECompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'le':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return float(source_value) <= float(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_equal_compare.py b/apps/application/flow/step_node/condition_node/compare/len_equal_compare.py
new file mode 100644
index 000000000..f2b0764c5
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_equal_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_equal_compare.py
+ @date:2024/6/7 14:44
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenEqualCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_eq':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) == int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_ge_compare.py b/apps/application/flow/step_node/condition_node/compare/len_ge_compare.py
new file mode 100644
index 000000000..87f11eb2c
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_ge_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_ge_compare.py
+ @date:2024/6/11 9:52
+ @desc: Length greater-than-or-equal comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenGECompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_ge':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) >= int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_gt_compare.py b/apps/application/flow/step_node/condition_node/compare/len_gt_compare.py
new file mode 100644
index 000000000..0532d353d
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_gt_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_gt_compare.py
+ @date:2024/6/11 9:52
+ @desc: Length greater-than comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenGTCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_gt':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) > int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_le_compare.py b/apps/application/flow/step_node/condition_node/compare/len_le_compare.py
new file mode 100644
index 000000000..d315a754a
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_le_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_le_compare.py
+ @date:2024/6/11 9:52
+ @desc: Length less-than-or-equal comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenLECompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_le':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) <= int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_lt_compare.py b/apps/application/flow/step_node/condition_node/compare/len_lt_compare.py
new file mode 100644
index 000000000..c89638cd7
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_lt_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_lt_compare.py
+ @date:2024/6/11 9:52
+ @desc: Length less-than comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenLTCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_lt':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) < int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/lt_compare.py b/apps/application/flow/step_node/condition_node/compare/lt_compare.py
new file mode 100644
index 000000000..d2d5be748
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/lt_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: lt_compare.py
+ @date:2024/6/11 9:52
+ @desc: Less-than comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LTCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'lt':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return float(source_value) < float(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/not_contain_compare.py b/apps/application/flow/step_node/condition_node/compare/not_contain_compare.py
new file mode 100644
index 000000000..f95b237dd
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/not_contain_compare.py
@@ -0,0 +1,23 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: not_contain_compare.py
+ @date:2024/6/11 10:02
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class NotContainCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'not_contain':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ if isinstance(source_value, str):
+ return str(target_value) not in source_value
+ return not any([str(item) == str(target_value) for item in source_value])
diff --git a/apps/application/flow/step_node/condition_node/i_condition_node.py b/apps/application/flow/step_node/condition_node/i_condition_node.py
new file mode 100644
index 000000000..a0e9814ff
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/i_condition_node.py
@@ -0,0 +1,39 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_condition_node.py
+ @date:2024/6/7 9:54
+ @desc:
+"""
+from typing import Type
+
+from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode
+from common.util.field_message import ErrMessage
+
+
+class ConditionSerializer(serializers.Serializer):
+ compare = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Comparator")))
+ value = serializers.CharField(required=True, error_messages=ErrMessage.char(_("value")))
+ field = serializers.ListField(required=True, error_messages=ErrMessage.char(_("Fields")))
+
+
+class ConditionBranchSerializer(serializers.Serializer):
+ id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch id")))
+ type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch Type")))
+ condition = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Condition or|and")))
+ conditions = ConditionSerializer(many=True)
+
+
+class ConditionNodeParamsSerializer(serializers.Serializer):
+ branch = ConditionBranchSerializer(many=True)
+
+
+class IConditionNode(INode):
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return ConditionNodeParamsSerializer
+
+ type = 'condition-node'
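
For reference, a hypothetical `node_data` payload that validates against `ConditionNodeParamsSerializer`: one branch whose two conditions are AND-ed together (field names follow the serializers above):

```python
# Hypothetical condition-node params: a single branch with two AND-ed conditions.
node_params = {
    'branch': [
        {
            'id': 'branch-1',
            'type': 'IF',
            'condition': 'and',  # 'and' => all conditions must hold; 'or' => any one suffices
            'conditions': [
                {'compare': 'eq', 'value': 'yes', 'field': ['start-node', 'question']},
                {'compare': 'len_gt', 'value': '3', 'field': ['start-node', 'question']},
            ],
        },
    ],
}
```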
diff --git a/apps/application/flow/step_node/condition_node/impl/__init__.py b/apps/application/flow/step_node/condition_node/impl/__init__.py
new file mode 100644
index 000000000..c21cd3ebb
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:35
+ @desc:
+"""
+from .base_condition_node import BaseConditionNode
diff --git a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py
new file mode 100644
index 000000000..109029be2
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py
@@ -0,0 +1,62 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_condition_node.py
+ @date:2024/6/7 11:29
+ @desc:
+"""
+from typing import List
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.condition_node.compare import compare_handle_list
+from application.flow.step_node.condition_node.i_condition_node import IConditionNode
+
+
+class BaseConditionNode(IConditionNode):
+ def save_context(self, details, workflow_manage):
+ self.context['branch_id'] = details.get('branch_id')
+ self.context['branch_name'] = details.get('branch_name')
+
+ def execute(self, **kwargs) -> NodeResult:
+ branch_list = self.node_params_serializer.data['branch']
+ branch = self._execute(branch_list)
+        return NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {})
+
+ def _execute(self, branch_list: List):
+ for branch in branch_list:
+ if self.branch_assertion(branch):
+ return branch
+
+ def branch_assertion(self, branch):
+ condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
+ branch.get('conditions')]
+ condition = branch.get('condition')
+ return all(condition_list) if condition == 'and' else any(condition_list)
+
+ def assertion(self, field_list: List[str], compare: str, value):
+ try:
+ value = self.workflow_manage.generate_prompt(value)
+ except Exception as e:
+ pass
+ field_value = None
+ try:
+ field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
+ except Exception as e:
+ pass
+ for compare_handler in compare_handle_list:
+ if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value):
+ return compare_handler.compare(field_value, compare, value)
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'branch_id': self.context.get('branch_id'),
+ 'branch_name': self.context.get('branch_name'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
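
Branch evaluation reduces to `all()` for `and` and `any()` for `or`; a standalone sketch of `branch_assertion` with a stubbed `assertion` callback:

```python
def branch_assertion(branch, assertion):
    # Evaluate every condition, then aggregate according to the branch's connective.
    condition_list = [assertion(row.get('field'), row.get('compare'), row.get('value'))
                      for row in branch.get('conditions')]
    return all(condition_list) if branch.get('condition') == 'and' else any(condition_list)

branch = {'condition': 'or',
          'conditions': [{'field': ['a'], 'compare': 'eq', 'value': '1'},
                         {'field': ['b'], 'compare': 'eq', 'value': '2'}]}

# Stub assertion: only the field path ['b'] evaluates to True.
print(branch_assertion(branch, lambda f, c, v: f == ['b']))  # True, since the connective is 'or'
```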
diff --git a/apps/application/flow/step_node/direct_reply_node/__init__.py b/apps/application/flow/step_node/direct_reply_node/__init__.py
new file mode 100644
index 000000000..cf360f956
--- /dev/null
+++ b/apps/application/flow/step_node/direct_reply_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 17:50
+ @desc:
+"""
+from .impl import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/direct_reply_node/i_reply_node.py b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py
new file mode 100644
index 000000000..d60541b18
--- /dev/null
+++ b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py
@@ -0,0 +1,48 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_reply_node.py
+ @date:2024/6/11 16:25
+ @desc:
+"""
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.exception.app_exception import AppApiException
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class ReplyNodeParamsSerializer(serializers.Serializer):
+ reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Response Type")))
+ fields = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Reference Field")))
+ content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+ error_messages=ErrMessage.char(_("Direct answer content")))
+ is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+ if self.data.get('reply_type') == 'referencing':
+ if 'fields' not in self.data:
+ raise AppApiException(500, _("Reference field cannot be empty"))
+ if len(self.data.get('fields')) < 2:
+ raise AppApiException(500, _("Reference field error"))
+ else:
+ if 'content' not in self.data or self.data.get('content') is None:
+ raise AppApiException(500, _("Content cannot be empty"))
+
+
+class IReplyNode(INode):
+ type = 'reply-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return ReplyNodeParamsSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
+ pass
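
The overridden `is_valid` adds a conditional schema on top of DRF's field validation: `referencing` replies need a `fields` path of at least two segments (node id plus field name), while any other `reply_type` needs non-empty `content`. Two hypothetical payloads, one per mode:

```python
# Hypothetical params for the two reply modes.
referencing_params = {
    'reply_type': 'referencing',
    'fields': ['ai-chat-node-1', 'answer'],  # [node_id, field]; fewer than 2 entries raises
    'is_result': True,
}
direct_params = {
    'reply_type': 'custom',   # anything other than 'referencing' takes the content path
    'content': 'Hello, how can I help you?',
    'is_result': True,
}
```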
diff --git a/apps/application/flow/step_node/direct_reply_node/impl/__init__.py b/apps/application/flow/step_node/direct_reply_node/impl/__init__.py
new file mode 100644
index 000000000..3307e9089
--- /dev/null
+++ b/apps/application/flow/step_node/direct_reply_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 17:49
+ @desc:
+"""
+from .base_reply_node import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py
new file mode 100644
index 000000000..1d3115e4c
--- /dev/null
+++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_reply_node.py
+ @date:2024/6/11 17:25
+ @desc:
+"""
+from typing import List
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
+
+
+class BaseReplyNode(IReplyNode):
+ def save_context(self, details, workflow_manage):
+ self.context['answer'] = details.get('answer')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('answer')
+
+ def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
+ if reply_type == 'referencing':
+ result = self.get_reference_content(fields)
+ else:
+ result = self.generate_reply_content(content)
+ return NodeResult({'answer': result}, {})
+
+ def generate_reply_content(self, prompt):
+ return self.workflow_manage.generate_prompt(prompt)
+
+ def get_reference_content(self, fields: List[str]):
+ return str(self.workflow_manage.get_reference_field(
+ fields[0],
+ fields[1:]))
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'answer': self.context.get('answer'),
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
diff --git a/apps/application/flow/step_node/document_extract_node/__init__.py b/apps/application/flow/step_node/document_extract_node/__init__.py
new file mode 100644
index 000000000..ce8f10f3e
--- /dev/null
+++ b/apps/application/flow/step_node/document_extract_node/__init__.py
@@ -0,0 +1 @@
+from .impl import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py
new file mode 100644
index 000000000..93d2b5b98
--- /dev/null
+++ b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py
@@ -0,0 +1,28 @@
+# coding=utf-8
+
+from typing import Type
+
+from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+
+
+class DocumentExtractNodeSerializer(serializers.Serializer):
+ document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
+
+
+class IDocumentExtractNode(INode):
+ type = 'document-extract-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return DocumentExtractNodeSerializer
+
+ def _run(self):
+ res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('document_list')[0],
+ self.node_params_serializer.data.get('document_list')[1:])
+ return self.execute(document=res, **self.flow_params_serializer.data)
+
+ def execute(self, document, chat_id, **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/document_extract_node/impl/__init__.py b/apps/application/flow/step_node/document_extract_node/impl/__init__.py
new file mode 100644
index 000000000..cf9d55ecd
--- /dev/null
+++ b/apps/application/flow/step_node/document_extract_node/impl/__init__.py
@@ -0,0 +1 @@
+from .base_document_extract_node import BaseDocumentExtractNode
diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py
new file mode 100644
index 000000000..6ddcb6e2f
--- /dev/null
+++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+import io
+import mimetypes
+
+from django.core.files.uploadedfile import InMemoryUploadedFile
+from django.db.models import QuerySet
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
+from dataset.models import File
+from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle
+from dataset.serializers.file_serializers import FileSerializer
+
+
+def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
+ content_type, _ = mimetypes.guess_type(file_name)
+ if content_type is None:
+        # Fall back to a generic binary type when the extension is not recognized
+        content_type = "application/octet-stream"
+    # Wrap the raw bytes in an in-memory stream
+    file_stream = io.BytesIO(file_bytes)
+
+    # File size in bytes
+    file_size = len(file_bytes)
+
+    # Build the InMemoryUploadedFile wrapper
+ uploaded_file = InMemoryUploadedFile(
+ file=file_stream,
+ field_name=None,
+ name=file_name,
+ content_type=content_type,
+ size=file_size,
+ charset=None,
+ )
+ return uploaded_file
+
+
+splitter = '\n`-----------------------------------`\n'
+
+class BaseDocumentExtractNode(IDocumentExtractNode):
+ def save_context(self, details, workflow_manage):
+ self.context['content'] = details.get('content')
+
+
+ def execute(self, document, chat_id, **kwargs):
+ get_buffer = FileBufferHandle().get_buffer
+
+ self.context['document_list'] = document
+ content = []
+ if document is None or not isinstance(document, list):
+ return NodeResult({'content': ''}, {})
+
+ application = self.workflow_manage.work_flow_post_handler.chat_info.application
+
+        # Persist any images extracted from the document
+ def save_image(image_list):
+ for image in image_list:
+ meta = {
+                'debug': not application.id,
+ 'chat_id': chat_id,
+ 'application_id': str(application.id) if application.id else None,
+ 'file_id': str(image.id)
+ }
+ file = bytes_to_uploaded_file(image.image, image.image_name)
+ FileSerializer(data={'file': file, 'meta': meta}).upload()
+
+ for doc in document:
+ file = QuerySet(File).filter(id=doc['file_id']).first()
+ buffer = io.BytesIO(file.get_byte().tobytes())
+            buffer.name = doc['name']  # split handles detect the file type from the buffer name
+
+ for split_handle in (parse_table_handle_list + split_handles):
+ if split_handle.support(buffer, get_buffer):
+                # Rewind to the start of the buffer
+ buffer.seek(0)
+ file_content = split_handle.get_content(buffer, save_image)
+ content.append('### ' + doc['name'] + '\n' + file_content)
+ break
+
+ return NodeResult({'content': splitter.join(content)}, {})
+
+ def get_details(self, index: int, **kwargs):
+ content = self.context.get('content', '').split(splitter)
+        # Do not store the full content in the details, since it can be very large
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'content': [file_content[:500] for file_content in content],
+ 'status': self.status,
+ 'err_message': self.err_message,
+ 'document_list': self.context.get('document_list')
+ }
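
`bytes_to_uploaded_file` is a small adapter from raw bytes to Django's upload interface, which lets the extracted images be fed straight into `FileSerializer`. A usage sketch, assuming a configured Django environment so the module import succeeds:

```python
# Hypothetical usage inside a configured MaxKB Django shell.
from application.flow.step_node.document_extract_node.impl.base_document_extract_node import (
    bytes_to_uploaded_file)

uploaded = bytes_to_uploaded_file(b'hello world', 'notes.txt')
print(uploaded.content_type)  # 'text/plain', guessed from the .txt extension
print(uploaded.size)          # 11
```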
diff --git a/apps/application/flow/step_node/form_node/__init__.py b/apps/application/flow/step_node/form_node/__init__.py
new file mode 100644
index 000000000..ce04b64ae
--- /dev/null
+++ b/apps/application/flow/step_node/form_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: __init__.py
+ @date:2024/11/4 14:48
+ @desc:
+"""
+from .impl import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/form_node/i_form_node.py b/apps/application/flow/step_node/form_node/i_form_node.py
new file mode 100644
index 000000000..7e8249429
--- /dev/null
+++ b/apps/application/flow/step_node/form_node/i_form_node.py
@@ -0,0 +1,35 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: i_form_node.py
+ @date:2024/11/4 14:48
+ @desc:
+"""
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class FormNodeParamsSerializer(serializers.Serializer):
+ form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("Form Configuration")))
+ form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Form output content')))
+ form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
+
+
+class IFormNode(INode):
+ type = 'form-node'
+ view_type = 'single_view'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return FormNodeParamsSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/form_node/impl/__init__.py b/apps/application/flow/step_node/form_node/impl/__init__.py
new file mode 100644
index 000000000..4cea85e1d
--- /dev/null
+++ b/apps/application/flow/step_node/form_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: __init__.py
+ @date:2024/11/4 14:49
+ @desc:
+"""
+from .base_form_node import BaseFormNode
diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py
new file mode 100644
index 000000000..dcf35dd3c
--- /dev/null
+++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py
@@ -0,0 +1,107 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: base_form_node.py
+ @date:2024/11/4 14:52
+ @desc:
+"""
+import json
+import time
+from typing import Dict, List
+
+from langchain_core.prompts import PromptTemplate
+
+from application.flow.common import Answer
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.form_node.i_form_node import IFormNode
+
+
+def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
+ if step_variable is not None:
+ for key in step_variable:
+ node.context[key] = step_variable[key]
+ if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
+ result = step_variable['result']
+ yield result
+ node.answer_text = result
+ node.context['run_time'] = time.time() - node.context['start_time']
+
+
+class BaseFormNode(IFormNode):
+ def save_context(self, details, workflow_manage):
+ form_data = details.get('form_data', None)
+ self.context['result'] = details.get('result')
+ self.context['form_content_format'] = details.get('form_content_format')
+ self.context['form_field_list'] = details.get('form_field_list')
+ self.context['run_time'] = details.get('run_time')
+ self.context['start_time'] = details.get('start_time')
+ self.context['form_data'] = form_data
+ self.context['is_submit'] = details.get('is_submit')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('result')
+ if form_data is not None:
+ for key in form_data:
+ self.context[key] = form_data[key]
+
+ def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
+ if form_data is not None:
+ self.context['is_submit'] = True
+ self.context['form_data'] = form_data
+ for key in form_data:
+ self.context[key] = form_data.get(key)
+ else:
+ self.context['is_submit'] = False
+ form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
+ "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
+ "is_submit": self.context.get("is_submit", False)}
+        form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
+ context = self.workflow_manage.get_workflow_content()
+ form_content_format = self.workflow_manage.reset_prompt(form_content_format)
+ prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
+ value = prompt_template.format(form=form, context=context)
+ return NodeResult(
+ {'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {},
+ _write_context=write_context)
+
+ def get_answer_list(self) -> List[Answer] | None:
+ form_content_format = self.context.get('form_content_format')
+ form_field_list = self.context.get('form_field_list')
+ form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
+ "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
+ 'form_data': self.context.get('form_data', {}),
+ "is_submit": self.context.get("is_submit", False)}
+        form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
+ context = self.workflow_manage.get_workflow_content()
+ form_content_format = self.workflow_manage.reset_prompt(form_content_format)
+ prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
+ value = prompt_template.format(form=form, context=context)
+ return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None,
+ self.runtime_node_id, '')]
+
+ def get_details(self, index: int, **kwargs):
+ form_content_format = self.context.get('form_content_format')
+ form_field_list = self.context.get('form_field_list')
+ form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
+ "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
+ 'form_data': self.context.get('form_data', {}),
+ "is_submit": self.context.get("is_submit", False)}
+        form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
+ context = self.workflow_manage.get_workflow_content()
+ form_content_format = self.workflow_manage.reset_prompt(form_content_format)
+ prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
+ value = prompt_template.format(form=form, context=context)
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ "result": value,
+ "form_content_format": self.context.get('form_content_format'),
+ "form_field_list": self.context.get('form_field_list'),
+ 'form_data': self.context.get('form_data'),
+ 'start_time': self.context.get('start_time'),
+ 'is_submit': self.context.get('is_submit'),
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
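
The form node serializes its settings into a `<form_rander>` tag and interpolates it into `form_content_format` through a jinja2 `PromptTemplate`; the front end later parses the tag back out to render the form. A standalone sketch of that rendering step (requires only `json` and langchain-core; the ids are hypothetical):

```python
import json
from langchain_core.prompts import PromptTemplate

form_setting = {'form_field_list': [{'field': 'name', 'label': 'Name'}],
                'runtime_node_id': 'form-node-1',  # hypothetical ids
                'chat_record_id': 'record-1',
                'is_submit': False}
form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'

template = PromptTemplate.from_template('{{ context }}\n{{ form }}', template_format='jinja2')
print(template.format(context='Please fill in the form below:', form=form))
```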
diff --git a/apps/application/flow/step_node/function_lib_node/__init__.py b/apps/application/flow/step_node/function_lib_node/__init__.py
new file mode 100644
index 000000000..7422965c3
--- /dev/null
+++ b/apps/application/flow/step_node/function_lib_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: __init__.py
+ @date:2024/8/8 17:45
+ @desc:
+"""
+from .impl import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/function_lib_node/i_function_lib_node.py b/apps/application/flow/step_node/function_lib_node/i_function_lib_node.py
new file mode 100644
index 000000000..c84782ff6
--- /dev/null
+++ b/apps/application/flow/step_node/function_lib_node/i_function_lib_node.py
@@ -0,0 +1,48 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: i_function_lib_node.py
+ @date:2024/8/8 16:21
+ @desc:
+"""
+from typing import Type
+
+from django.db.models import QuerySet
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.field.common import ObjectField
+from common.util.field_message import ErrMessage
+from function_lib.models.function import FunctionLib
+from django.utils.translation import gettext_lazy as _
+
+
+class InputField(serializers.Serializer):
+ name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
+ value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
+
+
+class FunctionLibNodeParamsSerializer(serializers.Serializer):
+ function_lib_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Library ID')))
+ input_field_list = InputField(required=True, many=True)
+ is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+ f_lib = QuerySet(FunctionLib).filter(id=self.data.get('function_lib_id')).first()
+ if f_lib is None:
+ raise Exception(_('The function has been deleted'))
+
+
+class IFunctionLibNode(INode):
+ type = 'function-lib-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return FunctionLibNodeParamsSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/function_lib_node/impl/__init__.py b/apps/application/flow/step_node/function_lib_node/impl/__init__.py
new file mode 100644
index 000000000..96681474f
--- /dev/null
+++ b/apps/application/flow/step_node/function_lib_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: __init__.py
+ @date:2024/8/8 17:48
+ @desc:
+"""
+from .base_function_lib_node import BaseFunctionLibNodeNode
diff --git a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py
new file mode 100644
index 000000000..341bb91da
--- /dev/null
+++ b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py
@@ -0,0 +1,150 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: base_function_lib_node.py
+ @date:2024/8/8 17:49
+ @desc:
+"""
+import json
+import time
+from typing import Dict
+
+from django.db.models import QuerySet
+from django.utils.translation import gettext as _
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.function_lib_node.i_function_lib_node import IFunctionLibNode
+from common.exception.app_exception import AppApiException
+from common.util.function_code import FunctionExecutor
+from common.util.rsa_util import rsa_long_decrypt
+from function_lib.models.function import FunctionLib
+from smartdoc.const import CONFIG
+
+function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
+
+
+def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
+ if step_variable is not None:
+ for key in step_variable:
+ node.context[key] = step_variable[key]
+ if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
+ result = str(step_variable['result']) + '\n'
+ yield result
+ node.answer_text = result
+ node.context['run_time'] = time.time() - node.context['start_time']
+
+
+def get_field_value(debug_field_list, name, is_required):
+ result = [field for field in debug_field_list if field.get('name') == name]
+ if len(result) > 0:
+ return result[-1]['value']
+ if is_required:
+ raise AppApiException(500, _('Field: {name} No value set').format(name=name))
+ return None
+
+
+def valid_reference_value(_type, value, name):
+ if _type == 'int':
+ instance_type = int | float
+ elif _type == 'float':
+ instance_type = float | int
+ elif _type == 'dict':
+ instance_type = dict
+ elif _type == 'array':
+ instance_type = list
+ elif _type == 'string':
+ instance_type = str
+ else:
+ raise Exception(_('Field: {name} Type: {_type} Value: {value} Unsupported types').format(name=name,
+ _type=_type))
+ if not isinstance(value, instance_type):
+ raise Exception(
+ _('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
+ value=value))
+
+
+def convert_value(name: str, value, _type, is_required, source, node):
+ if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
+ return None
+ if not is_required and source == 'reference' and (value is None or len(value) == 0):
+ return None
+ if source == 'reference':
+ value = node.workflow_manage.get_reference_field(
+ value[0],
+ value[1:])
+ valid_reference_value(_type, value, name)
+ if _type == 'int':
+ return int(value)
+ if _type == 'float':
+ return float(value)
+ return value
+ try:
+ if _type == 'int':
+ return int(value)
+ if _type == 'float':
+ return float(value)
+ if _type == 'dict':
+ v = json.loads(value)
+ if isinstance(v, dict):
+ return v
+ raise Exception(_('type error'))
+ if _type == 'array':
+ v = json.loads(value)
+ if isinstance(v, list):
+ return v
+ raise Exception(_('type error'))
+ return value
+ except Exception as e:
+ raise Exception(
+ _('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
+ value=value))
+
+
+def valid_function(function_lib, user_id):
+ if function_lib is None:
+ raise Exception(_('Function does not exist'))
+ if function_lib.permission_type == 'PRIVATE' and str(function_lib.user_id) != str(user_id):
+ raise Exception(_('No permission to use this function {name}').format(name=function_lib.name))
+ if not function_lib.is_active:
+ raise Exception(_('Function {name} is unavailable').format(name=function_lib.name))
+
+
+class BaseFunctionLibNodeNode(IFunctionLibNode):
+ def save_context(self, details, workflow_manage):
+ self.context['result'] = details.get('result')
+ if self.node_params.get('is_result'):
+ self.answer_text = str(details.get('result'))
+
+ def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
+ function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
+ valid_function(function_lib, self.flow_params_serializer.data.get('user_id'))
+        params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'),
+                                                   field.get('is_required'), field.get('source'), self)
+                  for field in
+                  [{'value': get_field_value(input_field_list, field.get('name'), field.get('is_required')),
+                    **field}
+                   for field in function_lib.input_field_list]}
+
+ self.context['params'] = params
+        # Merge the decrypted init params with the runtime params (runtime values take precedence)
+ if function_lib.init_params is not None:
+ all_params = json.loads(rsa_long_decrypt(function_lib.init_params)) | params
+ else:
+ all_params = params
+ result = function_executor.exec_code(function_lib.code, all_params)
+ return NodeResult({'result': result}, {}, _write_context=write_context)
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ "result": self.context.get('result'),
+ "params": self.context.get('params'),
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
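
`convert_value` has two paths: `reference` values are resolved from upstream nodes and type-checked, while `custom` values are parsed from strings, with `dict` and `array` going through `json.loads`. A minimal sketch of just the custom-source parsing rules:

```python
import json

def convert_custom(name, value, _type):
    # Custom-source values arrive as strings and are coerced per the declared type.
    try:
        if _type == 'int':
            return int(value)
        if _type == 'float':
            return float(value)
        if _type in ('dict', 'array'):
            v = json.loads(value)
            expected = dict if _type == 'dict' else list
            if isinstance(v, expected):
                return v
            raise Exception('type error')
        return value
    except Exception:
        raise Exception(f'Field: {name} Type: {_type} Value: {value} Type error')

print(convert_custom('limit', '10', 'int'))           # 10
print(convert_custom('filters', '{"k": 1}', 'dict'))  # {'k': 1}
```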
diff --git a/apps/application/flow/step_node/function_node/__init__.py b/apps/application/flow/step_node/function_node/__init__.py
new file mode 100644
index 000000000..ebfbe8d8b
--- /dev/null
+++ b/apps/application/flow/step_node/function_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: __init__.py
+ @date:2024/8/13 10:43
+ @desc:
+"""
+from .impl import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/function_node/i_function_node.py b/apps/application/flow/step_node/function_node/i_function_node.py
new file mode 100644
index 000000000..bbaae6c73
--- /dev/null
+++ b/apps/application/flow/step_node/function_node/i_function_node.py
@@ -0,0 +1,63 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: i_function_node.py
+ @date:2024/8/8 16:21
+ @desc:
+"""
+import re
+from typing import Type
+
+from django.core import validators
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.exception.app_exception import AppApiException
+from common.field.common import ObjectField
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+from rest_framework.utils.formatting import lazy_format
+
+
+class InputField(serializers.Serializer):
+ name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
+ is_required = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean(_("Is this field required")))
+ type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("type")), validators=[
+ validators.RegexValidator(regex=re.compile("^string|int|dict|array|float$"),
+ message=_("The field only supports string|int|dict|array|float"), code=500)
+ ])
+ source = serializers.CharField(required=True, error_messages=ErrMessage.char(_("source")), validators=[
+ validators.RegexValidator(regex=re.compile("^custom|reference$"),
+ message=_("The field only supports custom|reference"), code=500)
+ ])
+ value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+ is_required = self.data.get('is_required')
+ if is_required and self.data.get('value') is None:
+ message = lazy_format(_('{field}, this field is required.'), field=self.data.get("name"))
+ raise AppApiException(500, message)
+
+
+class FunctionNodeParamsSerializer(serializers.Serializer):
+ input_field_list = InputField(required=True, many=True)
+ code = serializers.CharField(required=True, error_messages=ErrMessage.char(_("function")))
+ is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+
+
+class IFunctionNode(INode):
+ type = 'function-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return FunctionNodeParamsSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, input_field_list, code, **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/function_node/impl/__init__.py b/apps/application/flow/step_node/function_node/impl/__init__.py
new file mode 100644
index 000000000..1a096368f
--- /dev/null
+++ b/apps/application/flow/step_node/function_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: __init__.py
+ @date:2024/8/13 11:19
+ @desc:
+"""
+from .base_function_node import BaseFunctionNodeNode
diff --git a/apps/application/flow/step_node/function_node/impl/base_function_node.py b/apps/application/flow/step_node/function_node/impl/base_function_node.py
new file mode 100644
index 000000000..d659227f1
--- /dev/null
+++ b/apps/application/flow/step_node/function_node/impl/base_function_node.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: base_function_node.py
+ @date:2024/8/8 17:49
+ @desc:
+"""
+import json
+import time
+
+from typing import Dict
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.function_node.i_function_node import IFunctionNode
+from common.exception.app_exception import AppApiException
+from common.util.function_code import FunctionExecutor
+from smartdoc.const import CONFIG
+
+function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
+
+
+def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
+ if step_variable is not None:
+ for key in step_variable:
+ node.context[key] = step_variable[key]
+ if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
+ result = str(step_variable['result']) + '\n'
+ yield result
+ node.answer_text = result
+ node.context['run_time'] = time.time() - node.context['start_time']
+
+
+def valid_reference_value(_type, value, name):
+ if _type == 'int':
+ instance_type = int | float
+ elif _type == 'float':
+ instance_type = float | int
+ elif _type == 'dict':
+ instance_type = dict
+ elif _type == 'array':
+ instance_type = list
+ elif _type == 'string':
+ instance_type = str
+ else:
+        raise AppApiException(500, f'Field: {name} type: {_type} is not a supported type')
+    if not isinstance(value, instance_type):
+        raise AppApiException(500, f'Field: {name} type: {_type} value: {value} has the wrong type')
+
+
+def convert_value(name: str, value, _type, is_required, source, node):
+ if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
+ return None
+ if source == 'reference':
+ value = node.workflow_manage.get_reference_field(
+ value[0],
+ value[1:])
+ valid_reference_value(_type, value, name)
+ if _type == 'int':
+ return int(value)
+ if _type == 'float':
+ return float(value)
+ return value
+ try:
+ if _type == 'int':
+ return int(value)
+ if _type == 'float':
+ return float(value)
+ if _type == 'dict':
+ v = json.loads(value)
+ if isinstance(v, dict):
+ return v
+ raise Exception("类型错误")
+ if _type == 'array':
+ v = json.loads(value)
+ if isinstance(v, list):
+ return v
+ raise Exception("类型错误")
+ return value
+    except Exception:
+        raise AppApiException(500, f'Field: {name} type: {_type} value: {value} has the wrong type')
+
+
+class BaseFunctionNodeNode(IFunctionNode):
+ def save_context(self, details, workflow_manage):
+ self.context['result'] = details.get('result')
+ if self.node_params.get('is_result', False):
+ self.answer_text = str(details.get('result'))
+
+ def execute(self, input_field_list, code, **kwargs) -> NodeResult:
+ params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'),
+ field.get('is_required'), field.get('source'), self)
+ for field in input_field_list}
+ result = function_executor.exec_code(code, params)
+ self.context['params'] = params
+ return NodeResult({'result': result}, {}, _write_context=write_context)
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ "result": self.context.get('result'),
+ "params": self.context.get('params'),
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
diff --git a/apps/application/flow/step_node/image_generate_step_node/__init__.py b/apps/application/flow/step_node/image_generate_step_node/__init__.py
new file mode 100644
index 000000000..f3feecc9c
--- /dev/null
+++ b/apps/application/flow/step_node/image_generate_step_node/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .impl import *
diff --git a/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py
new file mode 100644
index 000000000..56a214cf9
--- /dev/null
+++ b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class ImageGenerateNodeSerializer(serializers.Serializer):
+ model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
+
+ prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word (positive)")))
+
+ negative_prompt = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Prompt word (negative)")),
+ allow_null=True, allow_blank=True, )
+    # Number of multi-round conversation turns
+ dialogue_number = serializers.IntegerField(required=False, default=0,
+ error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
+
+ dialogue_type = serializers.CharField(required=False, default='NODE',
+ error_messages=ErrMessage.char(_("Conversation storage type")))
+
+ is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
+
+ model_params_setting = serializers.JSONField(required=False, default=dict,
+ error_messages=ErrMessage.json(_("Model parameter settings")))
+
+
+class IImageGenerateNode(INode):
+ type = 'image-generate-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return ImageGenerateNodeSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
+ model_params_setting,
+ chat_record_id,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py b/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py
new file mode 100644
index 000000000..14a21a915
--- /dev/null
+++ b/apps/application/flow/step_node/image_generate_step_node/impl/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .base_image_generate_node import BaseImageGenerateNode
diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py
new file mode 100644
index 000000000..16423eafd
--- /dev/null
+++ b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py
@@ -0,0 +1,122 @@
+# coding=utf-8
+from functools import reduce
+from typing import List
+
+import requests
+from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
+from common.util.common import bytes_to_uploaded_file
+from dataset.serializers.file_serializers import FileSerializer
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+
+class BaseImageGenerateNode(IImageGenerateNode):
+ def save_context(self, details, workflow_manage):
+ self.context['answer'] = details.get('answer')
+ self.context['question'] = details.get('question')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('answer')
+
+ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
+ model_params_setting,
+ chat_record_id,
+ **kwargs) -> NodeResult:
+ application = self.workflow_manage.work_flow_post_handler.chat_info.application
+ tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
+ **model_params_setting)
+ history_message = self.get_history_message(history_chat_record, dialogue_number)
+ self.context['history_message'] = history_message
+ question = self.generate_prompt_question(prompt)
+ self.context['question'] = question
+ message_list = self.generate_message_list(question, history_message)
+ self.context['message_list'] = message_list
+ self.context['dialogue_type'] = dialogue_type
+ image_urls = tti_model.generate_image(question, negative_prompt)
+        # Save the generated images
+ file_urls = []
+ for image_url in image_urls:
+ file_name = 'generated_image.png'
+ file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
+ meta = {
+                'debug': not application.id,
+ 'chat_id': chat_id,
+ 'application_id': str(application.id) if application.id else None,
+ }
+ file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
+ file_urls.append(file_url)
+ self.context['image_list'] = [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls]
+ answer = ' '.join([f"" for path in file_urls])
+ return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
+ 'image': [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls],
+ 'history_message': history_message, 'question': question}, {})
+
+ def generate_history_ai_message(self, chat_record):
+ for val in chat_record.details.values():
+ if self.node.id == val['node_id'] and 'image_list' in val:
+ if val['dialogue_type'] == 'WORKFLOW':
+ return chat_record.get_ai_message()
+ image_list = val['image_list']
+                return AIMessage(content=[
+                    *[{'type': 'image_url', 'image_url': {'url': file_url.get('url')}} for file_url in image_list]
+                ])
+ return chat_record.get_ai_message()
+
+ def get_history_message(self, history_chat_record, dialogue_number):
+ start_index = len(history_chat_record) - dialogue_number
+ history_message = reduce(lambda x, y: [*x, *y], [
+ [self.generate_history_human_message(history_chat_record[index]),
+ self.generate_history_ai_message(history_chat_record[index])]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
+ return history_message
+
+ def generate_history_human_message(self, chat_record):
+ for data in chat_record.details.values():
+ if self.node.id == data['node_id'] and 'image_list' in data:
+ image_list = data['image_list']
+ if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
+ return HumanMessage(content=chat_record.problem_text)
+ return HumanMessage(content=data['question'])
+ return HumanMessage(content=chat_record.problem_text)
+
+ def generate_prompt_question(self, prompt):
+ return self.workflow_manage.generate_prompt(prompt)
+
+ def generate_message_list(self, question: str, history_message):
+ return [
+ *history_message,
+ question
+ ]
+
+ @staticmethod
+ def reset_message_list(message_list: List[BaseMessage], answer_text):
+ result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+ message
+ in
+ message_list]
+ result.append({'role': 'ai', 'content': answer_text})
+ return result
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'history_message': [{'content': message.content, 'role': message.type} for message in
+ (self.context.get('history_message') if self.context.get(
+ 'history_message') is not None else [])],
+ 'question': self.context.get('question'),
+ 'answer': self.context.get('answer'),
+ 'type': self.node.type,
+ 'message_tokens': self.context.get('message_tokens'),
+ 'answer_tokens': self.context.get('answer_tokens'),
+ 'status': self.status,
+ 'err_message': self.err_message,
+ 'image_list': self.context.get('image_list'),
+ 'dialogue_type': self.context.get('dialogue_type')
+ }
diff --git a/apps/application/flow/step_node/image_understand_step_node/__init__.py b/apps/application/flow/step_node/image_understand_step_node/__init__.py
new file mode 100644
index 000000000..f3feecc9c
--- /dev/null
+++ b/apps/application/flow/step_node/image_understand_step_node/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .impl import *
diff --git a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py
new file mode 100644
index 000000000..5ef4c1017
--- /dev/null
+++ b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py
@@ -0,0 +1,46 @@
+# coding=utf-8
+
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class ImageUnderstandNodeSerializer(serializers.Serializer):
+ model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
+ system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+ error_messages=ErrMessage.char(_("Role Setting")))
+ prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
+    # Number of multi-round conversation turns
+ dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
+
+ dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Conversation storage type")))
+
+ is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
+
+ image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))
+
+ model_params_setting = serializers.JSONField(required=False, default=dict,
+ error_messages=ErrMessage.json(_("Model parameter settings")))
+
+
+class IImageUnderstandNode(INode):
+ type = 'image-understand-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return ImageUnderstandNodeSerializer
+
+ def _run(self):
+ res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('image_list')[0],
+ self.node_params_serializer.data.get('image_list')[1:])
+ return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
+ model_params_setting,
+ chat_record_id,
+ image,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/__init__.py b/apps/application/flow/step_node/image_understand_step_node/impl/__init__.py
new file mode 100644
index 000000000..ba2512839
--- /dev/null
+++ b/apps/application/flow/step_node/image_understand_step_node/impl/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .base_image_understand_node import BaseImageUnderstandNode
diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
new file mode 100644
index 000000000..44765bc4f
--- /dev/null
+++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+import base64
+import os
+import time
+from functools import reduce
+from typing import List, Dict
+
+from django.db.models import QuerySet
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
+
+from application.flow.i_step_node import NodeResult, INode
+from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
+from dataset.models import File
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+from imghdr import what
+
+
+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
+ chat_model = node_variable.get('chat_model')
+ message_tokens = node_variable['usage_metadata']['output_tokens'] if 'usage_metadata' in node_variable else 0
+ answer_tokens = chat_model.get_num_tokens(answer)
+ node.context['message_tokens'] = message_tokens
+ node.context['answer_tokens'] = answer_tokens
+ node.context['answer'] = answer
+ node.context['history_message'] = node_variable['history_message']
+ node.context['question'] = node_variable['question']
+ node.context['run_time'] = time.time() - node.context['start_time']
+ if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
+ node.answer_text = answer
+
+
+def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据 (流式)
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result')
+ answer = ''
+ for chunk in response:
+ answer += chunk.content
+ yield chunk.content
+ _write_context(node_variable, workflow_variable, node, workflow, answer)
+
+
+def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点实例对象
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result')
+ answer = response.content
+ _write_context(node_variable, workflow_variable, node, workflow, answer)
+
+
+def file_id_to_base64(file_id: str):
+ file = QuerySet(File).filter(id=file_id).first()
+ file_bytes = file.get_byte()
+ base64_image = base64.b64encode(file_bytes).decode("utf-8")
+ return [base64_image, what(None, file_bytes.tobytes())]
+
+
+class BaseImageUnderstandNode(IImageUnderstandNode):
+ def save_context(self, details, workflow_manage):
+ self.context['answer'] = details.get('answer')
+ self.context['question'] = details.get('question')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('answer')
+
+ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
+ model_params_setting,
+ chat_record_id,
+ image,
+ **kwargs) -> NodeResult:
+        # Guard against malformed parameters
+ if image is None or not isinstance(image, list):
+ image = []
+ image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
+        # History messages shown in the execution details do not need the image content
+ history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
+ self.context['history_message'] = history_message
+ question = self.generate_prompt_question(prompt)
+ self.context['question'] = question.content
+        # Build the message list with the real history_message
+ message_list = self.generate_message_list(image_model, system, prompt,
+ self.get_history_message(history_chat_record, dialogue_number), image)
+ self.context['message_list'] = message_list
+ self.context['image_list'] = image
+ self.context['dialogue_type'] = dialogue_type
+ if stream:
+ r = image_model.stream(message_list)
+ return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context_stream)
+ else:
+ r = image_model.invoke(message_list)
+ return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context)
+
+ def get_history_message_for_details(self, history_chat_record, dialogue_number):
+ start_index = len(history_chat_record) - dialogue_number
+ history_message = reduce(lambda x, y: [*x, *y], [
+ [self.generate_history_human_message_for_details(history_chat_record[index]),
+ self.generate_history_ai_message(history_chat_record[index])]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
+ return history_message
+
+ def generate_history_ai_message(self, chat_record):
+ for val in chat_record.details.values():
+ if self.node.id == val['node_id'] and 'image_list' in val:
+ if val['dialogue_type'] == 'WORKFLOW':
+ return chat_record.get_ai_message()
+ return AIMessage(content=val['answer'])
+ return chat_record.get_ai_message()
+
+ def generate_history_human_message_for_details(self, chat_record):
+ for data in chat_record.details.values():
+ if self.node.id == data['node_id'] and 'image_list' in data:
+ image_list = data['image_list']
+ if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
+ return HumanMessage(content=chat_record.problem_text)
+ file_id_list = [image.get('file_id') for image in image_list]
+ return HumanMessage(content=[
+ {'type': 'text', 'text': data['question']},
+ *[{'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}} for file_id in file_id_list]
+
+ ])
+ return HumanMessage(content=chat_record.problem_text)
+
+ def get_history_message(self, history_chat_record, dialogue_number):
+ start_index = len(history_chat_record) - dialogue_number
+ history_message = reduce(lambda x, y: [*x, *y], [
+ [self.generate_history_human_message(history_chat_record[index]),
+ self.generate_history_ai_message(history_chat_record[index])]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
+ return history_message
+
+ def generate_history_human_message(self, chat_record):
+ for data in chat_record.details.values():
+ if self.node.id == data['node_id'] and 'image_list' in data:
+ image_list = data['image_list']
+ if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
+ return HumanMessage(content=chat_record.problem_text)
+ image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]
+ return HumanMessage(
+ content=[
+ {'type': 'text', 'text': data['question']},
+ *[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
+ base64_image in image_base64_list]
+ ])
+ return HumanMessage(content=chat_record.problem_text)
+
+ def generate_prompt_question(self, prompt):
+ return HumanMessage(self.workflow_manage.generate_prompt(prompt))
+
+ def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
+ if image is not None and len(image) > 0:
+            # Handle multiple images
+ images = []
+ for img in image:
+ file_id = img['file_id']
+ file = QuerySet(File).filter(id=file_id).first()
+ image_bytes = file.get_byte()
+ base64_image = base64.b64encode(image_bytes).decode("utf-8")
+ image_format = what(None, image_bytes.tobytes())
+ images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
+ messages = [HumanMessage(
+ content=[
+ {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
+ *images
+ ])]
+ else:
+ messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+
+ if system is not None and len(system) > 0:
+ return [
+ SystemMessage(self.workflow_manage.generate_prompt(system)),
+ *history_message,
+ *messages
+ ]
+ else:
+ return [
+ *history_message,
+ *messages
+ ]
+
+ @staticmethod
+ def reset_message_list(message_list: List[BaseMessage], answer_text):
+ result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+ message
+ in
+ message_list]
+ result.append({'role': 'ai', 'content': answer_text})
+ return result
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'system': self.node_params.get('system'),
+ 'history_message': [{'content': message.content, 'role': message.type} for message in
+ (self.context.get('history_message') if self.context.get(
+ 'history_message') is not None else [])],
+ 'question': self.context.get('question'),
+ 'answer': self.context.get('answer'),
+ 'type': self.node.type,
+ 'message_tokens': self.context.get('message_tokens'),
+ 'answer_tokens': self.context.get('answer_tokens'),
+ 'status': self.status,
+ 'err_message': self.err_message,
+ 'image_list': self.context.get('image_list'),
+ 'dialogue_type': self.context.get('dialogue_type')
+ }
diff --git a/apps/application/flow/step_node/mcp_node/__init__.py b/apps/application/flow/step_node/mcp_node/__init__.py
new file mode 100644
index 000000000..f3feecc9c
--- /dev/null
+++ b/apps/application/flow/step_node/mcp_node/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .impl import *
diff --git a/apps/application/flow/step_node/mcp_node/i_mcp_node.py b/apps/application/flow/step_node/mcp_node/i_mcp_node.py
new file mode 100644
index 000000000..94cb4da77
--- /dev/null
+++ b/apps/application/flow/step_node/mcp_node/i_mcp_node.py
@@ -0,0 +1,35 @@
+# coding=utf-8
+
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class McpNodeSerializer(serializers.Serializer):
+ mcp_servers = serializers.JSONField(required=True,
+ error_messages=ErrMessage.char(_("Mcp servers")))
+
+ mcp_server = serializers.CharField(required=True,
+ error_messages=ErrMessage.char(_("Mcp server")))
+
+ mcp_tool = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Mcp tool")))
+
+ tool_params = serializers.DictField(required=True,
+ error_messages=ErrMessage.char(_("Tool parameters")))
+
+
+class IMcpNode(INode):
+ type = 'mcp-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return McpNodeSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/mcp_node/impl/__init__.py b/apps/application/flow/step_node/mcp_node/impl/__init__.py
new file mode 100644
index 000000000..8c9a5ee19
--- /dev/null
+++ b/apps/application/flow/step_node/mcp_node/impl/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .base_mcp_node import BaseMcpNode
diff --git a/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py
new file mode 100644
index 000000000..e49ef7019
--- /dev/null
+++ b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+import asyncio
+import json
+from typing import List
+
+from langchain_mcp_adapters.client import MultiServerMCPClient
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.mcp_node.i_mcp_node import IMcpNode
+
+
+class BaseMcpNode(IMcpNode):
+ def save_context(self, details, workflow_manage):
+ self.context['result'] = details.get('result')
+ self.context['tool_params'] = details.get('tool_params')
+ self.context['mcp_tool'] = details.get('mcp_tool')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('result')
+
+ def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult:
+ servers = json.loads(mcp_servers)
+ params = json.loads(json.dumps(tool_params))
+ params = self.handle_variables(params)
+
+        async def call_tool(server_config, session, tool, args):
+            async with MultiServerMCPClient(server_config) as client:
+                return await client.sessions[session].call_tool(tool, args)
+
+ res = asyncio.run(call_tool(servers, mcp_server, mcp_tool, params))
+ return NodeResult(
+ {'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {})
+
+ def handle_variables(self, tool_params):
+        # Resolve variables referenced in the parameters
+ for k, v in tool_params.items():
+            if isinstance(v, str):
+                tool_params[k] = self.workflow_manage.generate_prompt(v)
+            if isinstance(v, dict):
+                self.handle_variables(v)
+            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], str):
+                tool_params[k] = self.get_reference_content(v)
+ return tool_params
+
+ def get_reference_content(self, fields: List[str]):
+ return str(self.workflow_manage.get_reference_field(
+ fields[0],
+ fields[1:]))
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'status': self.status,
+ 'err_message': self.err_message,
+ 'type': self.node.type,
+ 'mcp_tool': self.context.get('mcp_tool'),
+ 'tool_params': self.context.get('tool_params'),
+ 'result': self.context.get('result'),
+ }
diff --git a/apps/application/flow/step_node/question_node/__init__.py b/apps/application/flow/step_node/question_node/__init__.py
new file mode 100644
index 000000000..98a1afcd9
--- /dev/null
+++ b/apps/application/flow/step_node/question_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:30
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py
new file mode 100644
index 000000000..57898bf22
--- /dev/null
+++ b/apps/application/flow/step_node/question_node/i_question_node.py
@@ -0,0 +1,42 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_chat_node.py
+ @date:2024/6/4 13:58
+ @desc:
+"""
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class QuestionNodeSerializer(serializers.Serializer):
+ model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
+ system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+ error_messages=ErrMessage.char(_("Role Setting")))
+ prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
+    # Number of multi-round conversation turns
+ dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
+
+ is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
+    model_params_setting = serializers.DictField(required=False,
+                                                 error_messages=ErrMessage.json(_("Model parameter settings")))
+
+
+class IQuestionNode(INode):
+ type = 'question-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return QuestionNodeSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
+ model_params_setting=None,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/question_node/impl/__init__.py b/apps/application/flow/step_node/question_node/impl/__init__.py
new file mode 100644
index 000000000..d85aa8724
--- /dev/null
+++ b/apps/application/flow/step_node/question_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:35
+ @desc:
+"""
+from .base_question_node import BaseQuestionNode
diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py
new file mode 100644
index 000000000..e1fd5b860
--- /dev/null
+++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py
@@ -0,0 +1,159 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_question_node.py
+ @date:2024/6/4 14:30
+ @desc:
+"""
+import re
+import time
+from functools import reduce
+from typing import List, Dict
+
+from django.db.models import QuerySet
+from langchain.schema import HumanMessage, SystemMessage
+from langchain_core.messages import BaseMessage
+
+from application.flow.i_step_node import NodeResult, INode
+from application.flow.step_node.question_node.i_question_node import IQuestionNode
+from setting.models import Model
+from setting.models_provider import get_model_credential
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+
+def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
+ chat_model = node_variable.get('chat_model')
+ message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+ answer_tokens = chat_model.get_num_tokens(answer)
+ node.context['message_tokens'] = message_tokens
+ node.context['answer_tokens'] = answer_tokens
+ node.context['answer'] = answer
+ node.context['history_message'] = node_variable['history_message']
+ node.context['question'] = node_variable['question']
+ node.context['run_time'] = time.time() - node.context['start_time']
+ if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
+ node.answer_text = answer
+
+
+def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据 (流式)
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result')
+ answer = ''
+ for chunk in response:
+ answer += chunk.content
+ yield chunk.content
+ _write_context(node_variable, workflow_variable, node, workflow, answer)
+
+
+def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点实例对象
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result')
+ answer = response.content
+ _write_context(node_variable, workflow_variable, node, workflow, answer)
+
+
+def get_default_model_params_setting(model_id):
+ model = QuerySet(Model).filter(id=model_id).first()
+ credential = get_model_credential(model.provider, model.model_type, model.model_name)
+ model_params_setting = credential.get_model_params_setting_form(
+ model.model_name).get_default_form_data()
+ return model_params_setting
+
+
+class BaseQuestionNode(IQuestionNode):
+ def save_context(self, details, workflow_manage):
+ self.context['run_time'] = details.get('run_time')
+ self.context['question'] = details.get('question')
+ self.context['answer'] = details.get('answer')
+ self.context['message_tokens'] = details.get('message_tokens')
+ self.context['answer_tokens'] = details.get('answer_tokens')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('answer')
+
+ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
+ model_params_setting=None,
+ **kwargs) -> NodeResult:
+ if model_params_setting is None:
+ model_params_setting = get_default_model_params_setting(model_id)
+ chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
+ **model_params_setting)
+ history_message = self.get_history_message(history_chat_record, dialogue_number)
+ self.context['history_message'] = history_message
+ question = self.generate_prompt_question(prompt)
+ self.context['question'] = question.content
+ system = self.workflow_manage.generate_prompt(system)
+ self.context['system'] = system
+ message_list = self.generate_message_list(system, prompt, history_message)
+ self.context['message_list'] = message_list
+ if stream:
+ r = chat_model.stream(message_list)
+ return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context_stream)
+ else:
+ r = chat_model.invoke(message_list)
+ return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context)
+
+ @staticmethod
+ def get_history_message(history_chat_record, dialogue_number):
+ start_index = len(history_chat_record) - dialogue_number
+ history_message = reduce(lambda x, y: [*x, *y], [
+ [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
+ for message in history_message:
+ if isinstance(message.content, str):
+                message.content = re.sub(r'[\d\D]*?</form_rander>', '', message.content)
+ return history_message
+
+ def generate_prompt_question(self, prompt):
+ return HumanMessage(self.workflow_manage.generate_prompt(prompt))
+
+ def generate_message_list(self, system: str, prompt: str, history_message):
+        if system is not None and len(system) > 0:
+ return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
+ HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+ else:
+ return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+
+ @staticmethod
+ def reset_message_list(message_list: List[BaseMessage], answer_text):
+ result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+ message
+ in
+ message_list]
+ result.append({'role': 'ai', 'content': answer_text})
+ return result
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'system': self.context.get('system'),
+ 'history_message': [{'content': message.content, 'role': message.type} for message in
+ (self.context.get('history_message') if self.context.get(
+ 'history_message') is not None else [])],
+ 'question': self.context.get('question'),
+ 'answer': self.context.get('answer'),
+ 'type': self.node.type,
+ 'message_tokens': self.context.get('message_tokens'),
+ 'answer_tokens': self.context.get('answer_tokens'),
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
diff --git a/apps/application/flow/step_node/reranker_node/__init__.py b/apps/application/flow/step_node/reranker_node/__init__.py
new file mode 100644
index 000000000..881d0f8a3
--- /dev/null
+++ b/apps/application/flow/step_node/reranker_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: __init__.py
+ @date:2024/9/4 11:37
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/reranker_node/i_reranker_node.py b/apps/application/flow/step_node/reranker_node/i_reranker_node.py
new file mode 100644
index 000000000..3b95e4dd6
--- /dev/null
+++ b/apps/application/flow/step_node/reranker_node/i_reranker_node.py
@@ -0,0 +1,60 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: i_reranker_node.py
+ @date:2024/9/4 10:40
+ @desc:
+"""
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class RerankerSettingSerializer(serializers.Serializer):
+    # Number of segments to retrieve
+ top_n = serializers.IntegerField(required=True,
+ error_messages=ErrMessage.integer(_("Reference segment number")))
+    # Similarity threshold, between 0 and 1
+ similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
+                                        error_messages=ErrMessage.float(_('similarity')))
+ max_paragraph_char_number = serializers.IntegerField(required=True,
+ error_messages=ErrMessage.float(_("Maximum number of words in a quoted segment")))
+
+
+class RerankerStepNodeSerializer(serializers.Serializer):
+ reranker_setting = RerankerSettingSerializer(required=True)
+
+ question_reference_address = serializers.ListField(required=True)
+ reranker_model_id = serializers.UUIDField(required=True)
+ reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+
+
+class IRerankerNode(INode):
+ type = 'reranker-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return RerankerStepNodeSerializer
+
+ def _run(self):
+ question = self.workflow_manage.get_reference_field(
+ self.node_params_serializer.data.get('question_reference_address')[0],
+ self.node_params_serializer.data.get('question_reference_address')[1:])
+ reranker_list = [self.workflow_manage.get_reference_field(
+ reference[0],
+ reference[1:]) for reference in
+ self.node_params_serializer.data.get('reranker_reference_list')]
+        return self.execute(**self.node_params_serializer.data, question=str(question),
+                            reranker_list=reranker_list)
+
+ def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/reranker_node/impl/__init__.py b/apps/application/flow/step_node/reranker_node/impl/__init__.py
new file mode 100644
index 000000000..ef5ca8058
--- /dev/null
+++ b/apps/application/flow/step_node/reranker_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: __init__.py
+ @date:2024/9/4 11:39
+ @desc:
+"""
+from .base_reranker_node import *
diff --git a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py
new file mode 100644
index 000000000..ee92b88a5
--- /dev/null
+++ b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py
@@ -0,0 +1,106 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: base_reranker_node.py
+ @date:2024/9/4 11:41
+ @desc:
+"""
+from typing import List
+
+from langchain_core.documents import Document
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+
+def merge_reranker_list(reranker_list, result=None):
+ if result is None:
+ result = []
+ for document in reranker_list:
+ if isinstance(document, list):
+ merge_reranker_list(document, result)
+ elif isinstance(document, dict):
+ content = document.get('title', '') + document.get('content', '')
+ title = document.get("title")
+ dataset_name = document.get("dataset_name")
+ document_name = document.get('document_name')
+ result.append(
+ Document(page_content=str(document) if len(content) == 0 else content,
+ metadata={'title': title, 'dataset_name': dataset_name, 'document_name': document_name}))
+ else:
+ result.append(Document(page_content=str(document), metadata={}))
+ return result
+
+
+def filter_result(document_list: List[Document], max_paragraph_char_number, top_n, similarity):
+ use_len = 0
+ result = []
+ for index in range(len(document_list)):
+ document = document_list[index]
+ if use_len >= max_paragraph_char_number or index >= top_n or document.metadata.get(
+ 'relevance_score') < similarity:
+ break
+ content = document.page_content[0:max_paragraph_char_number - use_len]
+ use_len = use_len + len(content)
+ result.append({'page_content': content, 'metadata': document.metadata})
+ return result
+
+
+def reset_result_list(result_list: List[Document], document_list: List[Document]):
+ r = []
+ document_list = document_list.copy()
+ for result in result_list:
+ filter_result_list = [document for document in document_list if document.page_content == result.page_content]
+ if len(filter_result_list) > 0:
+ item = filter_result_list[0]
+ document_list.remove(item)
+ r.append(Document(page_content=item.page_content,
+ metadata={**item.metadata, 'relevance_score': result.metadata.get('relevance_score')}))
+ else:
+ r.append(result)
+ return r
+
+
+class BaseRerankerNode(IRerankerNode):
+ def save_context(self, details, workflow_manage):
+ self.context['document_list'] = details.get('document_list', [])
+ self.context['question'] = details.get('question')
+ self.context['run_time'] = details.get('run_time')
+ self.context['result_list'] = details.get('result_list')
+ self.context['result'] = details.get('result')
+
+ def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
+ **kwargs) -> NodeResult:
+ documents = merge_reranker_list(reranker_list)
+ top_n = reranker_setting.get('top_n', 3)
+ self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for
+ document in documents]
+ self.context['question'] = question
+ reranker_model = get_model_instance_by_model_user_id(reranker_model_id,
+ self.flow_params_serializer.data.get('user_id'),
+ top_n=top_n)
+ result = reranker_model.compress_documents(
+ documents,
+ question)
+ similarity = reranker_setting.get('similarity', 0.6)
+ max_paragraph_char_number = reranker_setting.get('max_paragraph_char_number', 5000)
+ result = reset_result_list(result, documents)
+ r = filter_result(result, max_paragraph_char_number, top_n, similarity)
+ return NodeResult({'result_list': r, 'result': ''.join([item.get('page_content') for item in r])}, {})
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'document_list': self.context.get('document_list'),
+ "question": self.context.get('question'),
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'reranker_setting': self.node_params_serializer.data.get('reranker_setting'),
+ 'result_list': self.context.get('result_list'),
+ 'result': self.context.get('result'),
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
diff --git a/apps/application/flow/step_node/search_dataset_node/__init__.py b/apps/application/flow/step_node/search_dataset_node/__init__.py
new file mode 100644
index 000000000..98a1afcd9
--- /dev/null
+++ b/apps/application/flow/step_node/search_dataset_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:30
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py
new file mode 100644
index 000000000..8f15c7a32
--- /dev/null
+++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py
@@ -0,0 +1,79 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_search_dataset_node.py
+ @date:2024/6/3 17:52
+ @desc:
+"""
+import re
+from typing import Type
+
+from django.core import validators
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.common import flat_map
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class DatasetSettingSerializer(serializers.Serializer):
+    # Number of segments to retrieve
+ top_n = serializers.IntegerField(required=True,
+ error_messages=ErrMessage.integer(_("Reference segment number")))
+    # Similarity threshold, between 0 and 1
+ similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
+ error_messages=ErrMessage.float(_('similarity')))
+ search_mode = serializers.CharField(required=True, validators=[
+ validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
+ message=_("The type only supports embedding|keywords|blend"), code=500)
+ ], error_messages=ErrMessage.char(_("Retrieval Mode")))
+ max_paragraph_char_number = serializers.IntegerField(required=True,
+ error_messages=ErrMessage.float(_("Maximum number of words in a quoted segment")))
+
+
+class SearchDatasetStepNodeSerializer(serializers.Serializer):
+    # List of dataset ids to query
+ dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
+ error_messages=ErrMessage.list(_("Dataset id list")))
+ dataset_setting = DatasetSettingSerializer(required=True)
+
+ question_reference_address = serializers.ListField(required=True)
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+
+
+def get_paragraph_list(chat_record, node_id):
+    return flat_map([chat_record.details[key].get('paragraph_list', []) for key in chat_record.details if
+                     chat_record.details[key].get('type', '') == 'search-dataset-node' and
+                     chat_record.details[key].get('paragraph_list', []) is not None and
+                     key == node_id])
+
+
+class ISearchDatasetStepNode(INode):
+ type = 'search-dataset-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return SearchDatasetStepNodeSerializer
+
+ def _run(self):
+ question = self.workflow_manage.get_reference_field(
+ self.node_params_serializer.data.get('question_reference_address')[0],
+ self.node_params_serializer.data.get('question_reference_address')[1:])
+ exclude_paragraph_id_list = []
+ if self.flow_params_serializer.data.get('re_chat', False):
+ history_chat_record = self.flow_params_serializer.data.get('history_chat_record', [])
+ paragraph_id_list = [p.get('id') for p in flat_map(
+ [get_paragraph_list(chat_record, self.runtime_node_id) for chat_record in history_chat_record if
+ chat_record.problem_text == question])]
+ exclude_paragraph_id_list = list(set(paragraph_id_list))
+
+ return self.execute(**self.node_params_serializer.data, question=str(question),
+ exclude_paragraph_id_list=exclude_paragraph_id_list)
+
+ def execute(self, dataset_id_list, dataset_setting, question,
+ exclude_paragraph_id_list=None,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/search_dataset_node/impl/__init__.py b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py
new file mode 100644
index 000000000..a9cff0d09
--- /dev/null
+++ b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:35
+ @desc:
+"""
+from .base_search_dataset_node import BaseSearchDatasetNode
diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py
new file mode 100644
index 000000000..5107d4ce2
--- /dev/null
+++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py
@@ -0,0 +1,146 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_search_dataset_node.py
+ @date:2024/6/4 11:56
+ @desc:
+"""
+import os
+from typing import List, Dict
+
+from django.db.models import QuerySet
+from django.db import connection
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
+from common.config.embedding_config import VectorStore
+from common.db.search import native_search
+from common.util.file_util import get_file_content
+from dataset.models import Document, Paragraph, DataSet
+from embedding.models import SearchMode
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+from smartdoc.conf import PROJECT_DIR
+
+
+def get_embedding_id(dataset_id_list):
+ dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+ if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+ raise Exception("关联知识库的向量模型不一致,无法召回分段。")
+ if len(dataset_list) == 0:
+ raise Exception("知识库设置错误,请重新设置知识库")
+ return dataset_list[0].embedding_mode_id
+
+
+def get_none_result(question):
+ return NodeResult(
+ {'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
+ 'directly_return': ''}, {})
+
+
+def reset_title(title):
+ if title is None or len(title.strip()) == 0:
+ return ""
+ else:
+ return f"#### {title}\n"
+
+
+class BaseSearchDatasetNode(ISearchDatasetStepNode):
+ def save_context(self, details, workflow_manage):
+ result = details.get('paragraph_list', [])
+ dataset_setting = self.node_params_serializer.data.get('dataset_setting')
+ directly_return = '\n'.join(
+ [f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in result if
+ paragraph.get('is_hit_handling_method')])
+ self.context['paragraph_list'] = result
+ self.context['question'] = details.get('question')
+ self.context['run_time'] = details.get('run_time')
+ self.context['is_hit_handling_method_list'] = [row for row in result if row.get('is_hit_handling_method')]
+ self.context['data'] = '\n'.join(
+ [f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in
+ result])[0:dataset_setting.get('max_paragraph_char_number', 5000)]
+ self.context['directly_return'] = directly_return
+
+ def execute(self, dataset_id_list, dataset_setting, question,
+ exclude_paragraph_id_list=None,
+ **kwargs) -> NodeResult:
+ self.context['question'] = question
+ if len(dataset_id_list) == 0:
+ return get_none_result(question)
+ model_id = get_embedding_id(dataset_id_list)
+ embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
+ embedding_value = embedding_model.embed_query(question)
+ vector = VectorStore.get_embedding_vector()
+ exclude_document_id_list = [str(document.id) for document in
+ QuerySet(Document).filter(
+ dataset_id__in=dataset_id_list,
+ is_active=False)]
+ embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
+ exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
+ dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
+        # Close the database connection manually
+ connection.close()
+ if embedding_list is None:
+ return get_none_result(question)
+ paragraph_list = self.list_paragraph(embedding_list, vector)
+ result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
+ result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)
+ return NodeResult({'paragraph_list': result,
+ 'is_hit_handling_method_list': [row for row in result if row.get('is_hit_handling_method')],
+ 'data': '\n'.join(
+ [f"{reset_title(paragraph.get('title', ''))}{paragraph.get('content')}" for paragraph in
+ result])[0:dataset_setting.get('max_paragraph_char_number', 5000)],
+ 'directly_return': '\n'.join(
+ [paragraph.get('content') for paragraph in
+ result if
+ paragraph.get('is_hit_handling_method')]),
+ 'question': question},
+
+ {})
+
+ @staticmethod
+ def reset_paragraph(paragraph: Dict, embedding_list: List):
+ filter_embedding_list = [embedding for embedding in embedding_list if
+ str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
+ if filter_embedding_list is not None and len(filter_embedding_list) > 0:
+ find_embedding = filter_embedding_list[-1]
+ return {
+ **paragraph,
+ 'similarity': find_embedding.get('similarity'),
+ 'is_hit_handling_method': find_embedding.get('similarity') > paragraph.get(
+ 'directly_return_similarity') and paragraph.get('hit_handling_method') == 'directly_return',
+ 'update_time': paragraph.get('update_time').strftime("%Y-%m-%d %H:%M:%S"),
+ 'create_time': paragraph.get('create_time').strftime("%Y-%m-%d %H:%M:%S"),
+ 'id': str(paragraph.get('id')),
+ 'dataset_id': str(paragraph.get('dataset_id')),
+ 'document_id': str(paragraph.get('document_id'))
+ }
+
+ @staticmethod
+ def list_paragraph(embedding_list: List, vector):
+ paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
+ if paragraph_id_list is None or len(paragraph_id_list) == 0:
+ return []
+ paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
+ get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "application", 'sql',
+ 'list_dataset_paragraph_by_paragraph_id.sql')),
+ with_table_name=True)
+        # If the vector store contains dirty data, delete it directly
+ if len(paragraph_list) != len(paragraph_id_list):
+ exist_paragraph_list = [row.get('id') for row in paragraph_list]
+ for paragraph_id in paragraph_id_list:
+                if paragraph_id not in exist_paragraph_list:
+ vector.delete_by_paragraph_id(paragraph_id)
+ return paragraph_list
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ 'question': self.context.get('question'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'paragraph_list': self.context.get('paragraph_list'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
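+
+
+def _direct_return_demo():
+    # Hedged sketch (illustrative values only): the direct-return decision made in
+    # reset_paragraph() above reduces to this predicate -- a paragraph is returned
+    # verbatim when its hit similarity exceeds the configured threshold and its
+    # hit handling method is 'directly_return'.
+    def is_direct_return(similarity: float, threshold: float, method: str) -> bool:
+        return similarity > threshold and method == 'directly_return'
+
+    assert is_direct_return(0.95, 0.9, 'directly_return') is True
+    assert is_direct_return(0.85, 0.9, 'directly_return') is False
+    assert is_direct_return(0.95, 0.9, 'optimization') is False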
diff --git a/apps/application/flow/step_node/speech_to_text_step_node/__init__.py b/apps/application/flow/step_node/speech_to_text_step_node/__init__.py
new file mode 100644
index 000000000..f3feecc9c
--- /dev/null
+++ b/apps/application/flow/step_node/speech_to_text_step_node/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .impl import *
diff --git a/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py
new file mode 100644
index 000000000..154762dca
--- /dev/null
+++ b/apps/application/flow/step_node/speech_to_text_step_node/i_speech_to_text_node.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class SpeechToTextNodeSerializer(serializers.Serializer):
+ stt_model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
+
+ is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
+
+ audio_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("The audio file cannot be empty")))
+
+
+class ISpeechToTextNode(INode):
+ type = 'speech-to-text-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return SpeechToTextNodeSerializer
+
+ def _run(self):
+ res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('audio_list')[0],
+ self.node_params_serializer.data.get('audio_list')[1:])
+ for audio in res:
+ if 'file_id' not in audio:
+                raise ValueError(_("Parameter value error: the uploaded audio is missing file_id; the audio upload failed"))
+
+ return self.execute(audio=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, stt_model_id, chat_id,
+ audio,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/__init__.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/__init__.py
new file mode 100644
index 000000000..9d2da6158
--- /dev/null
+++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .base_speech_to_text_node import BaseSpeechToTextNode
diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py
new file mode 100644
index 000000000..13b954e46
--- /dev/null
+++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py
@@ -0,0 +1,72 @@
+# coding=utf-8
+import os
+import tempfile
+
+from django.db.models import QuerySet
+from concurrent.futures import ThreadPoolExecutor
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode
+from common.util.common import split_and_transcribe, any_to_mp3
+from dataset.models import File
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+class BaseSpeechToTextNode(ISpeechToTextNode):
+
+ def save_context(self, details, workflow_manage):
+ self.context['answer'] = details.get('answer')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('answer')
+
+ def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
+ stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id'))
+ audio_list = audio
+ self.context['audio_list'] = audio
+
+ def process_audio_item(audio_item, model):
+ file = QuerySet(File).filter(id=audio_item['file_id']).first()
+            # Convert the file to mp3 based on its file_name extension
+ file_format = file.file_name.split('.')[-1]
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_format}') as temp_file:
+ temp_file.write(file.get_byte().tobytes())
+ temp_file_path = temp_file.name
+            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_mp3_file:
+                temp_mp3_path = temp_mp3_file.name
+ any_to_mp3(temp_file_path, temp_mp3_path)
+ try:
+ transcription = split_and_transcribe(temp_mp3_path, model)
+ return {file.file_name: transcription}
+ finally:
+ os.remove(temp_file_path)
+ os.remove(temp_mp3_path)
+
+ def process_audio_items(audio_list, model):
+ with ThreadPoolExecutor(max_workers=5) as executor:
+ results = list(executor.map(lambda item: process_audio_item(item, model), audio_list))
+ return results
+
+ result = process_audio_items(audio_list, stt_model)
+ content = []
+ result_content = []
+ for item in result:
+ for key, value in item.items():
+ content.append(f'### {key}\n{value}')
+ result_content.append(value)
+ return NodeResult({'answer': '\n'.join(result_content), 'result': '\n'.join(result_content),
+ 'content': content}, {})
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'answer': self.context.get('answer'),
+ 'content': self.context.get('content'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message,
+ 'audio_list': self.context.get('audio_list'),
+ }
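+
+
+def _transcription_merge_demo():
+    # Hedged sketch of the result shaping in execute() above: each transcription is
+    # keyed by file name and rendered as a markdown section. The sample data below
+    # is illustrative, not real model output.
+    sample = [{'a.mp3': 'hello'}, {'b.mp3': 'world'}]
+    content = [f'### {key}\n{value}' for item in sample for key, value in item.items()]
+    answer = '\n'.join(value for item in sample for value in item.values())
+    assert content[0] == '### a.mp3\nhello'
+    assert answer == 'hello\nworld'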
diff --git a/apps/application/flow/step_node/start_node/__init__.py b/apps/application/flow/step_node/start_node/__init__.py
new file mode 100644
index 000000000..98a1afcd9
--- /dev/null
+++ b/apps/application/flow/step_node/start_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:30
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/start_node/i_start_node.py b/apps/application/flow/step_node/start_node/i_start_node.py
new file mode 100644
index 000000000..41d73f218
--- /dev/null
+++ b/apps/application/flow/step_node/start_node/i_start_node.py
@@ -0,0 +1,20 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_start_node.py
+ @date:2024/6/3 16:54
+ @desc:
+"""
+
+from application.flow.i_step_node import INode, NodeResult
+
+
+class IStarNode(INode):
+ type = 'start-node'
+
+ def _run(self):
+ return self.execute(**self.flow_params_serializer.data)
+
+ def execute(self, question, **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/start_node/impl/__init__.py b/apps/application/flow/step_node/start_node/impl/__init__.py
new file mode 100644
index 000000000..b68a92d02
--- /dev/null
+++ b/apps/application/flow/step_node/start_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:36
+ @desc:
+"""
+from .base_start_node import BaseStartStepNode
diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py
new file mode 100644
index 000000000..24b968471
--- /dev/null
+++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_start_node.py
+ @date:2024/6/3 17:17
+ @desc:
+"""
+import time
+from datetime import datetime
+from typing import List, Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.start_node.i_start_node import IStarNode
+
+
+def get_default_global_variable(input_field_list: List):
+ return {item.get('variable'): item.get('default_value') for item in input_field_list if
+ item.get('default_value', None) is not None}
+
+
+def get_global_variable(node):
+ history_chat_record = node.flow_params_serializer.data.get('history_chat_record', [])
+ history_context = [{'question': chat_record.problem_text, 'answer': chat_record.answer_text} for chat_record in
+ history_chat_record]
+ chat_id = node.flow_params_serializer.data.get('chat_id')
+ return {'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time(),
+ 'history_context': history_context, 'chat_id': str(chat_id), **node.workflow_manage.form_data}
+
+
+class BaseStartStepNode(IStarNode):
+ def save_context(self, details, workflow_manage):
+ base_node = self.workflow_manage.get_base_node()
+ default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', []))
+ workflow_variable = {**default_global_variable, **get_global_variable(self)}
+ self.context['question'] = details.get('question')
+ self.context['run_time'] = details.get('run_time')
+ self.context['document'] = details.get('document_list')
+ self.context['image'] = details.get('image_list')
+ self.context['audio'] = details.get('audio_list')
+ self.context['other'] = details.get('other_list')
+ self.status = details.get('status')
+ self.err_message = details.get('err_message')
+ for key, value in workflow_variable.items():
+ workflow_manage.context[key] = value
+ for item in details.get('global_fields', []):
+ workflow_manage.context[item.get('key')] = item.get('value')
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ pass
+
+ def execute(self, question, **kwargs) -> NodeResult:
+ base_node = self.workflow_manage.get_base_node()
+ default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', []))
+ workflow_variable = {**default_global_variable, **get_global_variable(self)}
+ """
+        Start node: initialize the global variables
+ """
+ node_variable = {
+ 'question': question,
+ 'image': self.workflow_manage.image_list,
+ 'document': self.workflow_manage.document_list,
+ 'audio': self.workflow_manage.audio_list,
+ 'other': self.workflow_manage.other_list,
+ }
+ return NodeResult(node_variable, workflow_variable)
+
+ def get_details(self, index: int, **kwargs):
+ global_fields = []
+ for field in self.node.properties.get('config')['globalFields']:
+ key = field['value']
+ global_fields.append({
+ 'label': field['label'],
+ 'key': key,
+ 'value': self.workflow_manage.context[key] if key in self.workflow_manage.context else ''
+ })
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ "question": self.context.get('question'),
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message,
+ 'image_list': self.context.get('image'),
+ 'document_list': self.context.get('document'),
+ 'audio_list': self.context.get('audio'),
+ 'other_list': self.context.get('other'),
+ 'global_fields': global_fields
+ }
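+
+
+def _global_variable_merge_demo():
+    # Hedged sketch of the merge order used in save_context()/execute() above:
+    # runtime globals are spread last, so they win over input-field defaults.
+    # All values below are illustrative.
+    default_global_variable = {'lang': 'en', 'topic': 'demo'}
+    runtime_variable = {'lang': 'zh', 'chat_id': '42'}
+    workflow_variable = {**default_global_variable, **runtime_variable}
+    assert workflow_variable == {'lang': 'zh', 'topic': 'demo', 'chat_id': '42'}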
diff --git a/apps/application/flow/step_node/text_to_speech_step_node/__init__.py b/apps/application/flow/step_node/text_to_speech_step_node/__init__.py
new file mode 100644
index 000000000..f3feecc9c
--- /dev/null
+++ b/apps/application/flow/step_node/text_to_speech_step_node/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .impl import *
diff --git a/apps/application/flow/step_node/text_to_speech_step_node/i_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/i_text_to_speech_node.py
new file mode 100644
index 000000000..68b53ea92
--- /dev/null
+++ b/apps/application/flow/step_node/text_to_speech_step_node/i_text_to_speech_node.py
@@ -0,0 +1,36 @@
+# coding=utf-8
+
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+
+
+class TextToSpeechNodeSerializer(serializers.Serializer):
+ tts_model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
+
+ is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
+
+ content_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("Text content")))
+ model_params_setting = serializers.DictField(required=False,
+ error_messages=ErrMessage.integer(_("Model parameter settings")))
+
+
+class ITextToSpeechNode(INode):
+ type = 'text-to-speech-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return TextToSpeechNodeSerializer
+
+ def _run(self):
+ content = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('content_list')[0],
+ self.node_params_serializer.data.get('content_list')[1:])
+ return self.execute(content=content, **self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, tts_model_id, chat_id,
+ content, model_params_setting=None,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/__init__.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/__init__.py
new file mode 100644
index 000000000..385b9718f
--- /dev/null
+++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .base_text_to_speech_node import BaseTextToSpeechNode
diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py
new file mode 100644
index 000000000..970447295
--- /dev/null
+++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py
@@ -0,0 +1,76 @@
+# coding=utf-8
+import io
+import mimetypes
+
+from django.core.files.uploadedfile import InMemoryUploadedFile
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
+from dataset.serializers.file_serializers import FileSerializer
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+
+def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
+ content_type, _ = mimetypes.guess_type(file_name)
+ if content_type is None:
+        # If the type cannot be guessed, fall back to the default binary content type
+ content_type = "application/octet-stream"
+    # Create an in-memory byte stream
+ file_stream = io.BytesIO(file_bytes)
+
+    # Get the file size
+ file_size = len(file_bytes)
+
+ uploaded_file = InMemoryUploadedFile(
+ file=file_stream,
+ field_name=None,
+ name=file_name,
+ content_type=content_type,
+ size=file_size,
+ charset=None,
+ )
+ return uploaded_file
+
+
+class BaseTextToSpeechNode(ITextToSpeechNode):
+ def save_context(self, details, workflow_manage):
+ self.context['answer'] = details.get('answer')
+ if self.node_params.get('is_result', False):
+ self.answer_text = details.get('answer')
+
+ def execute(self, tts_model_id, chat_id,
+ content, model_params_setting=None,
+ **kwargs) -> NodeResult:
+ self.context['content'] = content
+ model = get_model_instance_by_model_user_id(tts_model_id, self.flow_params_serializer.data.get('user_id'),
+                                                     **(model_params_setting or {}))
+ audio_byte = model.text_to_speech(content)
+        # Persist this audio file to the database
+ file_name = 'generated_audio.mp3'
+ file = bytes_to_uploaded_file(audio_byte, file_name)
+ application = self.workflow_manage.work_flow_post_handler.chat_info.application
+ meta = {
+ 'debug': False if application.id else True,
+ 'chat_id': chat_id,
+ 'application_id': str(application.id) if application.id else None,
+ }
+ file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
+        # Build an <audio> tag whose src points at the uploaded file
+        audio_label = f'<audio src="{file_url}" controls style="width: 300px; height: 43px"></audio>'
+ file_id = file_url.split('/')[-1]
+ audio_list = [{'file_id': file_id, 'file_name': file_name, 'url': file_url}]
+ return NodeResult({'answer': audio_label, 'result': audio_list}, {})
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'content': self.context.get('content'),
+ 'err_message': self.err_message,
+ 'answer': self.context.get('answer'),
+ }
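+
+
+def _bytes_to_uploaded_file_demo():
+    # Hedged usage sketch for bytes_to_uploaded_file() above: wrap raw model output
+    # so it can be handed to the file serializer. The payload is dummy bytes, not a
+    # real mp3.
+    dummy_audio = b'ID3\x00'
+    uploaded = bytes_to_uploaded_file(dummy_audio, 'generated_audio.mp3')
+    assert uploaded.size == len(dummy_audio)
+    assert uploaded.content_type == 'audio/mpeg'  # guessed from the .mp3 suffix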
diff --git a/apps/application/flow/step_node/variable_assign_node/__init__.py b/apps/application/flow/step_node/variable_assign_node/__init__.py
new file mode 100644
index 000000000..2d231e606
--- /dev/null
+++ b/apps/application/flow/step_node/variable_assign_node/__init__.py
@@ -0,0 +1,3 @@
+# coding=utf-8
+
+from .impl import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/variable_assign_node/i_variable_assign_node.py b/apps/application/flow/step_node/variable_assign_node/i_variable_assign_node.py
new file mode 100644
index 000000000..e4594183f
--- /dev/null
+++ b/apps/application/flow/step_node/variable_assign_node/i_variable_assign_node.py
@@ -0,0 +1,27 @@
+# coding=utf-8
+
+from typing import Type
+
+from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+
+
+class VariableAssignNodeParamsSerializer(serializers.Serializer):
+ variable_list = serializers.ListField(required=True,
+ error_messages=ErrMessage.list(_("Reference Field")))
+
+
+class IVariableAssignNode(INode):
+ type = 'variable-assign-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return VariableAssignNodeParamsSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, variable_list, stream, **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/variable_assign_node/impl/__init__.py b/apps/application/flow/step_node/variable_assign_node/impl/__init__.py
new file mode 100644
index 000000000..7585cdd8f
--- /dev/null
+++ b/apps/application/flow/step_node/variable_assign_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 17:49
+ @desc:
+"""
+from .base_variable_assign_node import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py b/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py
new file mode 100644
index 000000000..ce2906e62
--- /dev/null
+++ b/apps/application/flow/step_node/variable_assign_node/impl/base_variable_assign_node.py
@@ -0,0 +1,65 @@
+# coding=utf-8
+import json
+from typing import List
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.variable_assign_node.i_variable_assign_node import IVariableAssignNode
+
+
+class BaseVariableAssignNode(IVariableAssignNode):
+ def save_context(self, details, workflow_manage):
+ self.context['variable_list'] = details.get('variable_list')
+ self.context['result_list'] = details.get('result_list')
+
+ def execute(self, variable_list, stream, **kwargs) -> NodeResult:
+ result_list = []
+ for variable in variable_list:
+ if 'fields' not in variable:
+ continue
+            if variable['fields'][0] == 'global':
+ result = {
+ 'name': variable['name'],
+ 'input_value': self.get_reference_content(variable['fields']),
+ }
+ if variable['source'] == 'custom':
+ if variable['type'] == 'json':
+ if isinstance(variable['value'], dict) or isinstance(variable['value'], list):
+ val = variable['value']
+ else:
+ val = json.loads(variable['value'])
+ self.workflow_manage.context[variable['fields'][1]] = val
+ result['output_value'] = variable['value'] = val
+ elif variable['type'] == 'string':
+                    # Resolve variables, e.g. {{global.xxx}}
+ val = self.workflow_manage.generate_prompt(variable['value'])
+ self.workflow_manage.context[variable['fields'][1]] = val
+ result['output_value'] = val
+ else:
+ val = variable['value']
+ self.workflow_manage.context[variable['fields'][1]] = val
+ result['output_value'] = val
+ else:
+ reference = self.get_reference_content(variable['reference'])
+ self.workflow_manage.context[variable['fields'][1]] = reference
+ result['output_value'] = reference
+ result_list.append(result)
+
+ return NodeResult({'variable_list': variable_list, 'result_list': result_list}, {})
+
+ def get_reference_content(self, fields: List[str]):
+ return str(self.workflow_manage.get_reference_field(
+ fields[0],
+ fields[1:]))
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'variable_list': self.context.get('variable_list'),
+ 'result_list': self.context.get('result_list'),
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
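+
+
+def _json_value_coercion_demo():
+    # Hedged sketch of the 'custom' + 'json' branch in execute() above: string
+    # values are parsed with json.loads, dict/list values pass through unchanged.
+    def coerce(value):
+        return value if isinstance(value, (dict, list)) else json.loads(value)
+
+    assert coerce('{"a": 1}') == {'a': 1}
+    assert coerce([1, 2]) == [1, 2]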
diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py
new file mode 100644
index 000000000..dfbf69b35
--- /dev/null
+++ b/apps/application/flow/tools.py
@@ -0,0 +1,191 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: tools.py
+ @date:2024/6/6 15:15
+ @desc:
+"""
+import json
+from typing import Iterator
+
+from django.http import StreamingHttpResponse
+from langchain_core.messages import BaseMessageChunk, BaseMessage
+
+from application.flow.i_step_node import WorkFlowPostHandler
+from common.response import result
+
+
+class Reasoning:
+ def __init__(self, reasoning_content_start, reasoning_content_end):
+ self.content = ""
+ self.reasoning_content = ""
+ self.all_content = ""
+ self.reasoning_content_start_tag = reasoning_content_start
+ self.reasoning_content_end_tag = reasoning_content_end
+ self.reasoning_content_start_tag_len = len(
+ reasoning_content_start) if reasoning_content_start is not None else 0
+ self.reasoning_content_end_tag_len = len(reasoning_content_end) if reasoning_content_end is not None else 0
+ self.reasoning_content_end_tag_prefix = reasoning_content_end[
+ 0] if self.reasoning_content_end_tag_len > 0 else ''
+ self.reasoning_content_is_start = False
+ self.reasoning_content_is_end = False
+ self.reasoning_content_chunk = ""
+
+ def get_end_reasoning_content(self):
+ if not self.reasoning_content_is_start and not self.reasoning_content_is_end:
+ r = {'content': self.all_content, 'reasoning_content': ''}
+ self.reasoning_content_chunk = ""
+ return r
+ if self.reasoning_content_is_start and not self.reasoning_content_is_end:
+ r = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
+ self.reasoning_content_chunk = ""
+ return r
+ return {'content': '', 'reasoning_content': ''}
+
+ def get_reasoning_content(self, chunk):
+        # If there is no reasoning start tag, the whole stream is answer content
+ if self.reasoning_content_start_tag is None or len(self.reasoning_content_start_tag) == 0:
+ self.content += chunk.content
+ return {'content': chunk.content, 'reasoning_content': ''}
+        # If there is no reasoning end tag, the whole stream is reasoning content
+ if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
+ return {'content': '', 'reasoning_content': chunk.content}
+ self.all_content += chunk.content
+ if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
+ if self.all_content.startswith(self.reasoning_content_start_tag):
+ self.reasoning_content_is_start = True
+ self.reasoning_content_chunk = self.all_content[self.reasoning_content_start_tag_len:]
+ else:
+ if not self.reasoning_content_is_end:
+ self.reasoning_content_is_end = True
+ self.content += self.all_content
+ return {'content': self.all_content, 'reasoning_content': ''}
+ else:
+ if self.reasoning_content_is_start:
+ self.reasoning_content_chunk += chunk.content
+ reasoning_content_end_tag_prefix_index = self.reasoning_content_chunk.find(
+ self.reasoning_content_end_tag_prefix)
+ if self.reasoning_content_is_end:
+ self.content += chunk.content
+ return {'content': chunk.content, 'reasoning_content': ''}
+            # Check whether the end tag may be contained
+ if reasoning_content_end_tag_prefix_index > -1:
+ if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index >= self.reasoning_content_end_tag_len:
+ reasoning_content_end_tag_index = self.reasoning_content_chunk.find(self.reasoning_content_end_tag)
+ if reasoning_content_end_tag_index > -1:
+ reasoning_content_chunk = self.reasoning_content_chunk[0:reasoning_content_end_tag_index]
+ content_chunk = self.reasoning_content_chunk[
+ reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:]
+ self.reasoning_content += reasoning_content_chunk
+ self.content += content_chunk
+ self.reasoning_content_chunk = ""
+ self.reasoning_content_is_end = True
+ return {'content': content_chunk, 'reasoning_content': reasoning_content_chunk}
+ else:
+ reasoning_content_chunk = self.reasoning_content_chunk[0:reasoning_content_end_tag_prefix_index + 1]
+ self.reasoning_content_chunk = self.reasoning_content_chunk.replace(reasoning_content_chunk, '')
+ self.reasoning_content += reasoning_content_chunk
+ return {'content': '', 'reasoning_content': reasoning_content_chunk}
+ else:
+ return {'content': '', 'reasoning_content': ''}
+
+ else:
+ if self.reasoning_content_is_end:
+ self.content += chunk.content
+ return {'content': chunk.content, 'reasoning_content': ''}
+ else:
+                # Flush the buffered reasoning content
+ result = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
+ self.reasoning_content += self.reasoning_content_chunk
+ self.reasoning_content_chunk = ""
+ return result
+
+
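+def _reasoning_split_demo():
+    # Hedged sketch of how Reasoning splits a stream, using an illustrative
+    # <think>...</think> tag pair and a stand-in chunk type.
+    class FakeChunk:
+        def __init__(self, content):
+            self.content = content
+
+    reasoning = Reasoning('<think>', '</think>')
+    first = reasoning.get_reasoning_content(FakeChunk('<think>'))
+    middle = reasoning.get_reasoning_content(FakeChunk('step 1'))
+    last = reasoning.get_reasoning_content(FakeChunk('</think>answer'))
+    assert first is None  # the opening tag is only recorded, nothing is emitted yet
+    assert middle['reasoning_content'] == 'step 1'
+    assert last['content'] == 'answer'
+
+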
+def event_content(chat_id, chat_record_id, response, workflow,
+ write_context,
+ post_handler: WorkFlowPostHandler):
+ """
+    Handle streaming output.
+    @param chat_id: conversation id
+    @param chat_record_id: chat record id
+    @param response: response data
+    @param workflow: workflow manager
+    @param write_context: write the node context
+    @param post_handler: post-processing handler
+ """
+ answer = ''
+ try:
+ for chunk in response:
+ answer += chunk.content
+ yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': chunk.content, 'is_end': False}, ensure_ascii=False) + "\n\n"
+ write_context(answer, 200)
+ post_handler.handler(chat_id, chat_record_id, answer, workflow)
+ yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': '', 'is_end': True}, ensure_ascii=False) + "\n\n"
+ except Exception as e:
+ answer = str(e)
+ write_context(answer, 500)
+ post_handler.handler(chat_id, chat_record_id, answer, workflow)
+ yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': answer, 'is_end': True}, ensure_ascii=False) + "\n\n"
+
+
+def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
+ post_handler):
+ """
+    Convert the result into a streaming service response.
+    @param chat_id: conversation id
+    @param chat_record_id: chat record id
+    @param response: response data
+    @param workflow: workflow manager
+    @param write_context: write the node context
+    @param post_handler: post-processing handler
+    @return: response
+ """
+ r = StreamingHttpResponse(
+ streaming_content=event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler),
+ content_type='text/event-stream;charset=utf-8',
+ charset='utf-8')
+
+ r['Cache-Control'] = 'no-cache'
+ return r
+
+
+def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_context,
+ post_handler: WorkFlowPostHandler):
+ """
+    Convert the result into a blocking service response.
+
+    @param chat_id: conversation id
+    @param chat_record_id: chat record id
+    @param response: response data
+    @param workflow: workflow manager
+    @param write_context: write the node context
+    @param post_handler: post-processing handler
+    @return: response
+ """
+ answer = response.content
+ write_context(answer)
+ post_handler.handler(chat_id, chat_record_id, answer, workflow)
+ return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': answer, 'is_end': True})
+
+
+def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow,
+ post_handler: WorkFlowPostHandler):
+ answer = response.content
+ post_handler.handler(chat_id, chat_record_id, answer, workflow)
+ return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': answer, 'is_end': True})
+
+
+def to_stream_response_simple(stream_event):
+ r = StreamingHttpResponse(
+ streaming_content=stream_event,
+ content_type='text/event-stream;charset=utf-8',
+ charset='utf-8')
+
+ r['Cache-Control'] = 'no-cache'
+ return r
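+
+
+def _sse_frame_demo():
+    # Hedged sketch of a single SSE frame as produced by event_content() above;
+    # the ids and content are dummies.
+    frame = 'data: ' + json.dumps({'chat_id': '1', 'id': '2', 'operate': True,
+                                   'content': 'hi', 'is_end': False}, ensure_ascii=False) + "\n\n"
+    assert frame.startswith('data: ')
+    assert frame.endswith('\n\n')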
diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py
new file mode 100644
index 000000000..0f7bc9c75
--- /dev/null
+++ b/apps/application/flow/workflow_manage.py
@@ -0,0 +1,827 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: workflow_manage.py
+ @date:2024/1/9 17:40
+ @desc:
+"""
+import concurrent
+import json
+import threading
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+from functools import reduce
+from typing import List, Dict
+
+from django.db import close_old_connections
+from django.db.models import QuerySet
+from django.utils import translation
+from django.utils.translation import get_language
+from django.utils.translation import gettext as _
+from langchain_core.prompts import PromptTemplate
+from rest_framework import status
+from rest_framework.exceptions import ErrorDetail, ValidationError
+
+from application.flow import tools
+from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
+from application.flow.step_node import get_node
+from common.exception.app_exception import AppApiException
+from common.handle.base_to_response import BaseToResponse
+from common.handle.impl.response.system_to_response import SystemToResponse
+from function_lib.models.function import FunctionLib
+from setting.models import Model
+from setting.models_provider import get_model_credential
+
+executor = ThreadPoolExecutor(max_workers=200)
+
+
+class Edge:
+ def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
+ self.id = _id
+ self.type = _type
+ self.sourceNodeId = sourceNodeId
+ self.targetNodeId = targetNodeId
+ for keyword in keywords:
+ self.__setattr__(keyword, keywords.get(keyword))
+
+
+class Node:
+ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
+ self.id = _id
+ self.type = _type
+ self.x = x
+ self.y = y
+ self.properties = properties
+ for keyword in kwargs:
+ self.__setattr__(keyword, kwargs.get(keyword))
+
+
+end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
+ 'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node']
+
+
+class Flow:
+ def __init__(self, nodes: List[Node], edges: List[Edge]):
+ self.nodes = nodes
+ self.edges = edges
+
+ @staticmethod
+ def new_instance(flow_obj: Dict):
+ nodes = flow_obj.get('nodes')
+ edges = flow_obj.get('edges')
+ nodes = [Node(node.get('id'), node.get('type'), **node)
+ for node in nodes]
+ edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges]
+ return Flow(nodes, edges)
+
+ def get_start_node(self):
+ start_node_list = [node for node in self.nodes if node.id == 'start-node']
+ return start_node_list[0]
+
+ def get_search_node(self):
+ return [node for node in self.nodes if node.type == 'search-dataset-node']
+
+ def is_valid(self):
+ """
+        Validate the workflow data
+ """
+ self.is_valid_model_params()
+ self.is_valid_start_node()
+ self.is_valid_base_node()
+ self.is_valid_work_flow()
+
+ @staticmethod
+ def is_valid_node_params(node: Node):
+ get_node(node.type)(node, None, None)
+
+ def is_valid_node(self, node: Node):
+ self.is_valid_node_params(node)
+ if node.type == 'condition-node':
+ branch_list = node.properties.get('node_data').get('branch')
+ for branch in branch_list:
+ source_anchor_id = f"{node.id}_{branch.get('id')}_right"
+ edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id]
+ if len(edge_list) == 0:
+ raise AppApiException(500,
+ _('The branch {branch} of the {node} node needs to be connected').format(
+ node=node.properties.get("stepName"), branch=branch.get("type")))
+
+ else:
+ edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
+            if len(edge_list) == 0 and node.type not in end_nodes:
+                raise AppApiException(500, _("Node {node} cannot be used as an end node").format(
+ node=node.properties.get("stepName")))
+
+ def get_next_nodes(self, node: Node):
+ edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
+ node_list = reduce(lambda x, y: [*x, *y],
+ [[node for node in self.nodes if node.id == edge.targetNodeId] for edge in edge_list],
+ [])
+        if len(node_list) == 0 and node.type not in end_nodes:
+ raise AppApiException(500,
+                                  _("The next node does not exist"))
+ return node_list
+
+ def is_valid_work_flow(self, up_node=None):
+ if up_node is None:
+ up_node = self.get_start_node()
+ self.is_valid_node(up_node)
+ next_nodes = self.get_next_nodes(up_node)
+ for next_node in next_nodes:
+ self.is_valid_work_flow(next_node)
+
+ def is_valid_start_node(self):
+ start_node_list = [node for node in self.nodes if node.id == 'start-node']
+ if len(start_node_list) == 0:
+ raise AppApiException(500, _('The starting node is required'))
+ if len(start_node_list) > 1:
+ raise AppApiException(500, _('There can only be one starting node'))
+
+ def is_valid_model_params(self):
+ node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')]
+ for node in node_list:
+ model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first()
+ if model is None:
+ raise ValidationError(ErrorDetail(
+ _('The node {node} model does not exist').format(node=node.properties.get("stepName"))))
+ credential = get_model_credential(model.provider, model.model_type, model.model_name)
+ model_params_setting = node.properties.get('node_data', {}).get('model_params_setting')
+ model_params_setting_form = credential.get_model_params_setting_form(
+ model.model_name)
+ if model_params_setting is None:
+ model_params_setting = model_params_setting_form.get_default_form_data()
+ node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
+ if node.properties.get('status', 200) != 200:
+ raise ValidationError(
+                    ErrorDetail(_("Node {node} is unavailable").format(node=node.properties.get("stepName"))))
+ node_list = [node for node in self.nodes if (node.type == 'function-lib-node')]
+ for node in node_list:
+ function_lib_id = node.properties.get('node_data', {}).get('function_lib_id')
+ if function_lib_id is None:
+ raise ValidationError(ErrorDetail(
+ _('The library ID of node {node} cannot be empty').format(node=node.properties.get("stepName"))))
+ f_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
+ if f_lib is None:
+ raise ValidationError(ErrorDetail(_("The function library for node {node} is not available").format(
+ node=node.properties.get("stepName"))))
+
+ def is_valid_base_node(self):
+ base_node_list = [node for node in self.nodes if node.id == 'base-node']
+ if len(base_node_list) == 0:
+ raise AppApiException(500, _('Basic information node is required'))
+ if len(base_node_list) > 1:
+ raise AppApiException(500, _('There can only be one basic information node'))
+
+
+class NodeResultFuture:
+ def __init__(self, r, e, status=200):
+ self.r = r
+ self.e = e
+ self.status = status
+
+ def result(self):
+ if self.status == 200:
+ return self.r
+ else:
+ raise self.e
+
+
+def await_result(result, timeout=1):
+ try:
+ result.result(timeout)
+ return False
+    except Exception:
+ return True
+
+
+class NodeChunkManage:
+
+ def __init__(self, work_flow):
+ self.node_chunk_list = []
+ self.current_node_chunk = None
+ self.work_flow = work_flow
+
+ def add_node_chunk(self, node_chunk):
+ self.node_chunk_list.append(node_chunk)
+
+ def contains(self, node_chunk):
+        return node_chunk in self.node_chunk_list
+
+ def pop(self):
+ if self.current_node_chunk is None:
+ try:
+ current_node_chunk = self.node_chunk_list.pop(0)
+ self.current_node_chunk = current_node_chunk
+            except IndexError:
+ pass
+ if self.current_node_chunk is not None:
+ try:
+ chunk = self.current_node_chunk.chunk_list.pop(0)
+ return chunk
+            except IndexError:
+ if self.current_node_chunk.is_end():
+ self.current_node_chunk = None
+ if self.work_flow.answer_is_not_empty():
+ chunk = self.work_flow.base_to_response.to_stream_chunk_response(
+ self.work_flow.params['chat_id'],
+ self.work_flow.params['chat_record_id'],
+ '\n\n', False, 0, 0)
+ self.work_flow.append_answer('\n\n')
+ return chunk
+ return self.pop()
+ return None
+
+
+class WorkflowManage:
+ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
+ base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
+ document_list=None,
+ audio_list=None,
+ other_list=None,
+ start_node_id=None,
+ start_node_data=None, chat_record=None, child_node=None):
+ if form_data is None:
+ form_data = {}
+ if image_list is None:
+ image_list = []
+ if document_list is None:
+ document_list = []
+ if audio_list is None:
+ audio_list = []
+ if other_list is None:
+ other_list = []
+ self.start_node_id = start_node_id
+ self.start_node = None
+ self.form_data = form_data
+ self.image_list = image_list
+ self.document_list = document_list
+ self.audio_list = audio_list
+ self.other_list = other_list
+ self.params = params
+ self.flow = flow
+ self.context = {}
+ self.node_chunk_manage = NodeChunkManage(self)
+ self.work_flow_post_handler = work_flow_post_handler
+ self.current_node = None
+ self.current_result = None
+ self.answer = ""
+ self.answer_list = ['']
+ self.status = 200
+ self.base_to_response = base_to_response
+ self.chat_record = chat_record
+ self.child_node = child_node
+ self.future_list = []
+ self.lock = threading.Lock()
+ self.field_list = []
+ self.global_field_list = []
+ self.init_fields()
+ if start_node_id is not None:
+ self.load_node(chat_record, start_node_id, start_node_data)
+ else:
+ self.node_context = []
+
+ def init_fields(self):
+ field_list = []
+ global_field_list = []
+ for node in self.flow.nodes:
+ properties = node.properties
+ node_name = properties.get('stepName')
+ node_id = node.id
+ node_config = properties.get('config')
+ if node_config is not None:
+ fields = node_config.get('fields')
+ if fields is not None:
+ for field in fields:
+ field_list.append({**field, 'node_id': node_id, 'node_name': node_name})
+ global_fields = node_config.get('globalFields')
+ if global_fields is not None:
+ for global_field in global_fields:
+ global_field_list.append({**global_field, 'node_id': node_id, 'node_name': node_name})
+ field_list.sort(key=lambda f: len(f.get('node_name')), reverse=True)
+ global_field_list.sort(key=lambda f: len(f.get('node_name')), reverse=True)
+ self.field_list = field_list
+ self.global_field_list = global_field_list
+
+ def append_answer(self, content):
+ self.answer += content
+ self.answer_list[-1] += content
+
+ def answer_is_not_empty(self):
+ return len(self.answer_list[-1]) > 0
+
+ def load_node(self, chat_record, start_node_id, start_node_data):
+ self.node_context = []
+ self.answer = chat_record.answer_text
+ self.answer_list = chat_record.answer_text_list
+ self.answer_list.append('')
+ for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')):
+ node_id = node_details.get('node_id')
+ if node_details.get('runtime_node_id') == start_node_id:
+ def get_node_params(n):
+ is_result = False
+ if n.type == 'application-node':
+ is_result = True
+ return {**n.properties.get('node_data'), 'form_data': start_node_data, 'node_data': start_node_data,
+ 'child_node': self.child_node, 'is_result': is_result}
+
+ self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'),
+ get_node_params=get_node_params)
+ self.start_node.valid_args(
+ {**self.start_node.node_params, 'form_data': start_node_data}, self.start_node.workflow_params)
+ if self.start_node.type == 'application-node':
+ application_node_dict = node_details.get('application_node_dict', {})
+ self.start_node.context['application_node_dict'] = application_node_dict
+ self.node_context.append(self.start_node)
+ continue
+
+ node_id = node_details.get('node_id')
+ node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
+ node.valid_args(node.node_params, node.workflow_params)
+ node.save_context(node_details, self)
+ node.node_chunk.end()
+ self.node_context.append(node)
+
+ def run(self):
+ close_old_connections()
+ language = get_language()
+ if self.params.get('stream'):
+ return self.run_stream(self.start_node, None, language)
+ return self.run_block(language)
+
+ def run_block(self, language='zh'):
+ """
+        Blocking (non-streaming) response
+        @return: result
+ """
+ self.run_chain_async(None, None, language)
+ while self.is_run():
+ pass
+ details = self.get_runtime_details()
+ message_tokens = sum([row.get('message_tokens') for row in details.values() if
+ 'message_tokens' in row and row.get('message_tokens') is not None])
+ answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
+ 'answer_tokens' in row and row.get('answer_tokens') is not None])
+ answer_text_list = self.get_answer_text_list()
+ answer_text = '\n\n'.join(
+ '\n\n'.join([a.get('content') for a in answer]) for answer in
+ answer_text_list)
+ answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
+ self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
+ answer_text,
+ self)
+ return self.base_to_response.to_block_response(self.params['chat_id'],
+ self.params['chat_record_id'], answer_text, True
+ , message_tokens, answer_tokens,
+ _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
+ other_params={'answer_list': answer_list})
+
+ def run_stream(self, current_node, node_result_future, language='zh'):
+ """
+        Streaming response
+ @return:
+ """
+ self.run_chain_async(current_node, node_result_future, language)
+ return tools.to_stream_response_simple(self.await_result())
+
+ def is_run(self, timeout=0.5):
+ future_list_len = len(self.future_list)
+ try:
+ r = concurrent.futures.wait(self.future_list, timeout)
+ if len(r.not_done) > 0:
+ return True
+ else:
+ if future_list_len == len(self.future_list):
+ return False
+ else:
+ return True
+        except Exception:
+ return True
+
+ def await_result(self):
+ try:
+ while self.is_run():
+ while True:
+ chunk = self.node_chunk_manage.pop()
+ if chunk is not None:
+ yield chunk
+ else:
+ break
+ while True:
+ chunk = self.node_chunk_manage.pop()
+ if chunk is None:
+ break
+ yield chunk
+ finally:
+ while self.is_run():
+ pass
+ details = self.get_runtime_details()
+ message_tokens = sum([row.get('message_tokens') for row in details.values() if
+ 'message_tokens' in row and row.get('message_tokens') is not None])
+ answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
+ 'answer_tokens' in row and row.get('answer_tokens') is not None])
+ self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
+ self.answer,
+ self)
+ yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
+ self.params['chat_record_id'],
+ '',
+ [],
+ '', True, message_tokens, answer_tokens, {})
+
+ def run_chain_async(self, current_node, node_result_future, language='zh'):
+ future = executor.submit(self.run_chain_manage, current_node, node_result_future, language)
+ self.future_list.append(future)
+
+ def run_chain_manage(self, current_node, node_result_future, language='zh'):
+ translation.activate(language)
+ if current_node is None:
+ start_node = self.get_start_node()
+ current_node = get_node(start_node.type)(start_node, self.params, self)
+ self.node_chunk_manage.add_node_chunk(current_node.node_chunk)
+        # Register the node in the execution context
+ self.append_node(current_node)
+ result = self.run_chain(current_node, node_result_future)
+ if result is None:
+ return
+ node_list = self.get_next_node_list(current_node, result)
+ if len(node_list) == 1:
+ self.run_chain_manage(node_list[0], None, language)
+ elif len(node_list) > 1:
+ sorted_node_run_list = sorted(node_list, key=lambda n: n.node.y)
+            # Collect the runnable child nodes
+ result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None, language)} for
+ node in
+ sorted_node_run_list]
+ for r in result_list:
+ self.future_list.append(r.get('future'))
+
+ def run_chain(self, current_node, node_result_future=None):
+ if node_result_future is None:
+ node_result_future = self.run_node_future(current_node)
+ try:
+ is_stream = self.params.get('stream', True)
+ result = self.hand_event_node_result(current_node,
+ node_result_future) if is_stream else self.hand_node_result(
+ current_node, node_result_future)
+ return result
+ except Exception as e:
+ traceback.print_exc()
+ return None
+
+ def hand_node_result(self, current_node, node_result_future):
+ try:
+ current_result = node_result_future.result()
+ result = current_result.write_context(current_node, self)
+ if result is not None:
+                # Drain the generator to materialize the result
+ list(result)
+ return current_result
+ except Exception as e:
+ traceback.print_exc()
+ self.status = 500
+ current_node.get_write_error_context(e)
+ self.answer += str(e)
+ finally:
+ current_node.node_chunk.end()
+
+ def append_node(self, current_node):
+ for index in range(len(self.node_context)):
+ n = self.node_context[index]
+ if current_node.id == n.node.id and current_node.runtime_node_id == n.runtime_node_id:
+ self.node_context[index] = current_node
+ return
+ self.node_context.append(current_node)
+
+ def hand_event_node_result(self, current_node, node_result_future):
+ runtime_node_id = current_node.runtime_node_id
+ real_node_id = current_node.runtime_node_id
+ child_node = {}
+ view_type = current_node.view_type
+ try:
+ current_result = node_result_future.result()
+ result = current_result.write_context(current_node, self)
+ if result is not None:
+ if self.is_result(current_node, current_result):
+ for r in result:
+ reasoning_content = ''
+ content = r
+ child_node = {}
+ node_is_end = False
+ view_type = current_node.view_type
+ if isinstance(r, dict):
+ content = r.get('content')
+ child_node = {'runtime_node_id': r.get('runtime_node_id'),
+ 'chat_record_id': r.get('chat_record_id')
+ , 'child_node': r.get('child_node')}
+                            if 'real_node_id' in r:
+ real_node_id = r.get('real_node_id')
+                            if 'node_is_end' in r:
+ node_is_end = r.get('node_is_end')
+ view_type = r.get('view_type')
+ reasoning_content = r.get('reasoning_content')
+ chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
+ self.params['chat_record_id'],
+ current_node.id,
+ current_node.up_node_id_list,
+ content, False, 0, 0,
+ {'node_type': current_node.type,
+ 'runtime_node_id': runtime_node_id,
+ 'view_type': view_type,
+ 'child_node': child_node,
+ 'node_is_end': node_is_end,
+ 'real_node_id': real_node_id,
+ 'reasoning_content': reasoning_content})
+ current_node.node_chunk.add_chunk(chunk)
+ chunk = (self.base_to_response
+ .to_stream_chunk_response(self.params['chat_id'],
+ self.params['chat_record_id'],
+ current_node.id,
+ current_node.up_node_id_list,
+ '', False, 0, 0, {'node_is_end': True,
+ 'runtime_node_id': runtime_node_id,
+ 'node_type': current_node.type,
+ 'view_type': view_type,
+ 'child_node': child_node,
+ 'real_node_id': real_node_id,
+ 'reasoning_content': ''}))
+ current_node.node_chunk.add_chunk(chunk)
+ else:
+ list(result)
+ return current_result
+ except Exception as e:
+            # Emit the error as a chunk on the failed node
+ traceback.print_exc()
+ chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
+ self.params['chat_record_id'],
+ current_node.id,
+ current_node.up_node_id_list,
+ 'Exception:' + str(e), False, 0, 0,
+ {'node_is_end': True,
+ 'runtime_node_id': current_node.runtime_node_id,
+ 'node_type': current_node.type,
+ 'view_type': current_node.view_type,
+ 'child_node': {},
+ 'real_node_id': real_node_id})
+ current_node.node_chunk.add_chunk(chunk)
+ current_node.get_write_error_context(e)
+ self.status = 500
+ return None
+ finally:
+ current_node.node_chunk.end()
+
+ def run_node_async(self, node):
+ future = executor.submit(self.run_node, node)
+ return future
+
+ def run_node_future(self, node):
+ try:
+ node.valid_args(node.node_params, node.workflow_params)
+ result = self.run_node(node)
+ return NodeResultFuture(result, None, 200)
+ except Exception as e:
+ return NodeResultFuture(None, e, 500)
+
+ def run_node(self, node):
+ result = node.run()
+ return result
+
+ def is_result(self, current_node, current_node_result):
+ return current_node.node_params.get('is_result', not self._has_next_node(
+ current_node, current_node_result)) if current_node.node_params is not None else False
+
+ def get_chunk_content(self, chunk, is_end=False):
+ return 'data: ' + json.dumps(
+ {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
+ 'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n"
+
+ def _has_next_node(self, current_node, node_result: NodeResult | None):
+ """
+        Whether there is a next runnable node
+ """
+ if node_result is not None and node_result.is_assertion_result():
+ for edge in self.flow.edges:
+ if (edge.sourceNodeId == current_node.id and
+ f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
+ return True
+ else:
+ for edge in self.flow.edges:
+ if edge.sourceNodeId == current_node.id:
+ return True
+
+ def has_next_node(self, node_result: NodeResult | None):
+ """
+        Whether there is a next runnable node
+ """
+ return self._has_next_node(self.get_start_node() if self.current_node is None else self.current_node,
+ node_result)
+
+ def get_runtime_details(self):
+ details_result = {}
+ for index in range(len(self.node_context)):
+ node = self.node_context[index]
+ if self.chat_record is not None and self.chat_record.details is not None:
+ details = self.chat_record.details.get(node.runtime_node_id)
+ if details is not None and self.start_node.runtime_node_id != node.runtime_node_id:
+ details_result[node.runtime_node_id] = details
+ continue
+ details = node.get_details(index)
+ details['node_id'] = node.id
+ details['up_node_id_list'] = node.up_node_id_list
+ details['runtime_node_id'] = node.runtime_node_id
+ details_result[node.runtime_node_id] = details
+ return details_result
+
+ def get_answer_text_list(self):
+ result = []
+ answer_list = reduce(lambda x, y: [*x, *y],
+ [n.get_answer_list() for n in self.node_context if n.get_answer_list() is not None],
+ [])
+ up_node = None
+ for index in range(len(answer_list)):
+ current_answer = answer_list[index]
+ if len(current_answer.content) > 0:
+ if up_node is None or current_answer.view_type == 'single_view' or (
+ current_answer.view_type == 'many_view' and up_node.view_type == 'single_view'):
+ result.append([current_answer])
+ else:
+ if len(result) > 0:
+ exec_index = len(result) - 1
+ if isinstance(result[exec_index], list):
+ result[exec_index].append(current_answer)
+ else:
+ result.insert(0, [current_answer])
+ up_node = current_answer
+ if len(result) == 0:
+            # If there is no answer, respond with an empty list
+ return [[]]
+ return [[item.to_dict() for item in r] for r in result]
+
+ def get_next_node(self):
+ """
+        Get the next runnable node
+ """
+ if self.current_node is None:
+ node = self.get_start_node()
+ node_instance = get_node(node.type)(node, self.params, self)
+ return node_instance
+ if self.current_result is not None and self.current_result.is_assertion_result():
+ for edge in self.flow.edges:
+ if (edge.sourceNodeId == self.current_node.id and
+ f"{edge.sourceNodeId}_{self.current_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
+ return self.get_node_cls_by_id(edge.targetNodeId)
+ else:
+ for edge in self.flow.edges:
+ if edge.sourceNodeId == self.current_node.id:
+ return self.get_node_cls_by_id(edge.targetNodeId)
+
+ return None
+
+ @staticmethod
+ def dependent_node(up_node_id, node):
+ if not node.node_chunk.is_end():
+ return False
+ if node.id == up_node_id:
+ if node.type == 'form-node':
+ if node.context.get('form_data', None) is not None:
+ return True
+ return False
+ return True
+
+ def dependent_node_been_executed(self, node_id):
+ """
+        Check whether all dependent nodes have been executed
+        @param node_id: id of the node to check
+ @return:
+ """
+ up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
+ return all([any([self.dependent_node(up_node_id, node) for node in self.node_context]) for up_node_id in
+ up_node_id_list])
+
+ def get_up_node_id_list(self, node_id):
+ up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
+ return up_node_id_list
+
+ def get_next_node_list(self, current_node, current_node_result):
+ """
+        Get the list of next executable nodes
+        @param current_node: the node that just ran
+        @param current_node_result: the result of that node
+        @return: list of executable nodes
+ """
+        # Check whether execution should be interrupted
+ if current_node_result.is_interrupt_exec(current_node):
+ return []
+ node_list = []
+ if current_node_result is not None and current_node_result.is_assertion_result():
+ for edge in self.flow.edges:
+ if (edge.sourceNodeId == current_node.id and
+ f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
+ next_node = [node for node in self.flow.nodes if node.id == edge.targetNodeId]
+ if len(next_node) == 0:
+ continue
+ if next_node[0].properties.get('condition', "AND") == 'AND':
+ if self.dependent_node_been_executed(edge.targetNodeId):
+ node_list.append(
+ self.get_node_cls_by_id(edge.targetNodeId,
+ [*current_node.up_node_id_list, current_node.node.id]))
+ else:
+ node_list.append(
+ self.get_node_cls_by_id(edge.targetNodeId,
+ [*current_node.up_node_id_list, current_node.node.id]))
+ else:
+ for edge in self.flow.edges:
+ if edge.sourceNodeId == current_node.id:
+ next_node = [node for node in self.flow.nodes if node.id == edge.targetNodeId]
+ if len(next_node) == 0:
+ continue
+ if next_node[0].properties.get('condition', "AND") == 'AND':
+ if self.dependent_node_been_executed(edge.targetNodeId):
+ node_list.append(
+ self.get_node_cls_by_id(edge.targetNodeId,
+ [*current_node.up_node_id_list, current_node.node.id]))
+ else:
+ node_list.append(
+ self.get_node_cls_by_id(edge.targetNodeId,
+ [*current_node.up_node_id_list, current_node.node.id]))
+ return node_list
+
+ def get_reference_field(self, node_id: str, fields: List[str]):
+ """
+        @param node_id: node id
+        @param fields: field path
+ @return:
+ """
+ if node_id == 'global':
+ return INode.get_field(self.context, fields)
+ else:
+ return self.get_node_by_id(node_id).get_reference_field(fields)
+
+ def get_workflow_content(self):
+ context = {
+ 'global': self.context,
+ }
+
+ for node in self.node_context:
+ context[node.id] = node.context
+ return context
+
+ def reset_prompt(self, prompt: str):
+ placeholder = "{}"
+ for field in self.field_list:
+ globeLabel = f"{field.get('node_name')}.{field.get('value')}"
+ globeValue = f"context.get('{field.get('node_id')}',{placeholder}).get('{field.get('value', '')}','')"
+ prompt = prompt.replace(globeLabel, globeValue)
+ for field in self.global_field_list:
+ globeLabel = f"全局变量.{field.get('value')}"
+ globeLabelNew = f"global.{field.get('value')}"
+ globeValue = f"context.get('global').get('{field.get('value', '')}','')"
+ prompt = prompt.replace(globeLabel, globeValue).replace(globeLabelNew, globeValue)
+ return prompt
+
+ def generate_prompt(self, prompt: str):
+ """
+        Format and render the prompt
+        @param prompt: the prompt text
+        @return: the rendered prompt
+ """
+ context = self.get_workflow_content()
+ prompt = self.reset_prompt(prompt)
+ prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
+ value = prompt_template.format(context=context)
+ return value
+
+ def get_start_node(self):
+ """
+ 获取启动节点
+ @return:
+ """
+ start_node_list = [node for node in self.flow.nodes if node.type == 'start-node']
+ return start_node_list[0]
+
+ def get_base_node(self):
+ """
+ 获取基础节点
+ @return:
+ """
+ base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
+ return base_node_list[0]
+
+ def get_node_cls_by_id(self, node_id, up_node_id_list=None,
+                           get_node_params=lambda node: node.properties.get('node_data')):
+        # Instantiate a new node object for the flow node with the given id;
+        # up_node_id_list records the execution path that led to it.
+ for node in self.flow.nodes:
+ if node.id == node_id:
+ node_instance = get_node(node.type)(node,
+ self.params, self, up_node_id_list, get_node_params)
+ return node_instance
+ return None
+
+ def get_node_by_id(self, node_id):
+ for node in self.node_context:
+ if node.id == node_id:
+ return node
+ return None
+
+ def get_node_reference(self, reference_address: Dict):
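+        # Expected shape (inferred from the lookups below):
+        # {'node_id': '<executed node id>', 'node_field': '<key in node.context>'}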
+ node = self.get_node_by_id(reference_address.get('node_id'))
+ return node.context[reference_address.get('node_field')]
diff --git a/apps/application/migrations/0001_initial.py b/apps/application/migrations/0001_initial.py
index 707f30322..52dadda82 100644
--- a/apps/application/migrations/0001_initial.py
+++ b/apps/application/migrations/0001_initial.py
@@ -1,4 +1,5 @@
-# Generated by Django 5.2.1 on 2025-05-27 06:42
+# Generated by Django 5.2 on 2025-05-27 07:50
+from django.db.models import QuerySet
import application.models.application
import django.contrib.postgres.fields
@@ -9,11 +10,17 @@ import uuid_utils.compat
from django.db import migrations, models
-class Migration(migrations.Migration):
+def insert_default_data(apps, schema_editor):
+    # Create the root folder (it has no parent node).
+ QuerySet(application.models.application.ApplicationFolder).create(id='root', name='根目录',
+ user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
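+    # Note: this uses the live model class; data migrations conventionally
+    # resolve the historical model via apps.get_model('application',
+    # 'ApplicationFolder') so the migration keeps working as the model changes.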
+
+class Migration(migrations.Migration):
initial = True
dependencies = [
+ ('knowledge', '0001_initial'),
('models_provider', '0001_initial'),
('users', '0002_alter_user_nick_name'),
]
@@ -24,22 +31,31 @@ class Migration(migrations.Migration):
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
- ('id', models.UUIDField(default=uuid_utils.compat.uuid7, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
- ('workspace_id', models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
+ ('id',
+ models.UUIDField(default=uuid_utils.compat.uuid7, editable=False, primary_key=True, serialize=False,
+ verbose_name='主键id')),
+ ('workspace_id',
+ models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
('is_publish', models.BooleanField(default=False, verbose_name='是否发布')),
('name', models.CharField(max_length=128, verbose_name='应用名称')),
('desc', models.CharField(default='', max_length=512, verbose_name='引用描述')),
('prologue', models.CharField(default='', max_length=40960, verbose_name='开场白')),
('dialogue_number', models.IntegerField(default=0, verbose_name='会话数量')),
- ('dataset_setting', models.JSONField(default=application.models.application.get_dataset_setting_dict, verbose_name='数据集参数设置')),
- ('model_setting', models.JSONField(default=application.models.application.get_model_setting_dict, verbose_name='模型参数相关设置')),
+ ('knowledge_setting', models.JSONField(default=application.models.application.get_dataset_setting_dict,
+ verbose_name='数据集参数设置')),
+ ('model_setting', models.JSONField(default=application.models.application.get_model_setting_dict,
+ verbose_name='模型参数相关设置')),
('model_params_setting', models.JSONField(default=dict, verbose_name='模型参数相关设置')),
('tts_model_params_setting', models.JSONField(default=dict, verbose_name='模型参数相关设置')),
('problem_optimization', models.BooleanField(default=False, verbose_name='问题优化')),
('icon', models.CharField(default='/ui/favicon.ico', max_length=256, verbose_name='应用icon')),
('work_flow', models.JSONField(default=dict, verbose_name='工作流数据')),
- ('type', models.CharField(choices=[('SIMPLE', '简易'), ('WORK_FLOW', '工作流')], default='SIMPLE', max_length=256, verbose_name='应用类型')),
- ('problem_optimization_prompt', models.CharField(blank=True, default='()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中', max_length=102400, null=True, verbose_name='问题优化提示词')),
+ ('type', models.CharField(choices=[('SIMPLE', '简易'), ('WORK_FLOW', '工作流')], default='SIMPLE',
+ max_length=256, verbose_name='应用类型')),
+ ('problem_optimization_prompt', models.CharField(blank=True,
+ default='()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中',
+ max_length=102400, null=True,
+ verbose_name='问题优化提示词')),
('tts_model_enable', models.BooleanField(default=False, verbose_name='语音合成模型是否启用')),
('stt_model_enable', models.BooleanField(default=False, verbose_name='语音识别模型是否启用')),
('tts_type', models.CharField(default='BROWSER', max_length=20, verbose_name='语音播放类型')),
@@ -48,9 +64,14 @@ class Migration(migrations.Migration):
('clean_time', models.IntegerField(default=180, verbose_name='清理时间')),
('file_upload_enable', models.BooleanField(default=False, verbose_name='文件上传是否启用')),
('file_upload_setting', models.JSONField(default=dict, verbose_name='文件上传相关设置')),
- ('model', models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, to='models_provider.model')),
- ('stt_model', models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='stt_model_id', to='models_provider.model')),
- ('tts_model', models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='tts_model_id', to='models_provider.model')),
+ ('model', models.ForeignKey(blank=True, db_constraint=False, null=True,
+ on_delete=django.db.models.deletion.SET_NULL, to='models_provider.model')),
+ ('stt_model', models.ForeignKey(blank=True, db_constraint=False, null=True,
+ on_delete=django.db.models.deletion.SET_NULL,
+ related_name='stt_model_id', to='models_provider.model')),
+ ('tts_model', models.ForeignKey(blank=True, db_constraint=False, null=True,
+ on_delete=django.db.models.deletion.SET_NULL,
+ related_name='tts_model_id', to='models_provider.model')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user')),
],
options={
@@ -62,12 +83,16 @@ class Migration(migrations.Migration):
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
- ('application', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to='application.application', verbose_name='应用id')),
+ ('application',
+ models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False,
+ to='application.application', verbose_name='应用id')),
('access_token', models.CharField(max_length=128, unique=True, verbose_name='用户公开访问 认证token')),
('is_active', models.BooleanField(default=True, verbose_name='是否开启公开访问')),
('access_num', models.IntegerField(default=100, verbose_name='访问次数')),
('white_active', models.BooleanField(default=False, verbose_name='是否开启白名单')),
- ('white_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128), default=list, size=None, verbose_name='白名单列表')),
+ ('white_list',
+ django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128),
+ default=list, size=None, verbose_name='白名单列表')),
('show_source', models.BooleanField(default=False, verbose_name='是否显示知识来源')),
('language', models.CharField(default=None, max_length=10, null=True, verbose_name='语言')),
],
@@ -80,14 +105,21 @@ class Migration(migrations.Migration):
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
- ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
+ ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False,
+ verbose_name='主键id')),
('secret_key', models.CharField(max_length=1024, unique=True, verbose_name='秘钥')),
- ('workspace_id', models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
+ ('workspace_id',
+ models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
('is_active', models.BooleanField(default=True, verbose_name='是否开启')),
('allow_cross_domain', models.BooleanField(default=False, verbose_name='是否允许跨域')),
- ('cross_domain_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128), default=list, size=None, verbose_name='跨域列表')),
- ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')),
- ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='users.user', verbose_name='用户id')),
+ ('cross_domain_list',
+ django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128),
+ default=list, size=None, verbose_name='跨域列表')),
+ ('application',
+ models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application',
+ verbose_name='应用id')),
+ ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='users.user',
+ verbose_name='用户id')),
],
options={
'db_table': 'application_api_key',
@@ -98,16 +130,21 @@ class Migration(migrations.Migration):
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
- ('id', models.CharField(editable=False, max_length=64, primary_key=True, serialize=False, verbose_name='主键id')),
+ ('id', models.CharField(editable=False, max_length=64, primary_key=True, serialize=False,
+ verbose_name='主键id')),
('name', models.CharField(max_length=64, verbose_name='文件夹名称')),
('desc', models.CharField(blank=True, max_length=200, null=True, verbose_name='描述')),
- ('workspace_id', models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
+ ('workspace_id',
+ models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')),
('lft', models.PositiveIntegerField(editable=False)),
('rght', models.PositiveIntegerField(editable=False)),
('tree_id', models.PositiveIntegerField(db_index=True, editable=False)),
('level', models.PositiveIntegerField(editable=False)),
- ('parent', mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='children', to='application.applicationfolder')),
- ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='用户id')),
+ ('parent',
+ mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING,
+ related_name='children', to='application.applicationfolder')),
+ ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
+ verbose_name='用户id')),
],
options={
'db_table': 'application_folder',
@@ -116,18 +153,25 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name='application',
name='folder',
- field=models.ForeignKey(default='root', on_delete=django.db.models.deletion.DO_NOTHING, to='application.applicationfolder', verbose_name='文件夹id'),
+ field=models.ForeignKey(default='root', on_delete=django.db.models.deletion.DO_NOTHING,
+ to='application.applicationfolder', verbose_name='文件夹id'),
),
migrations.CreateModel(
name='ApplicationKnowledgeMapping',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
- ('id', models.UUIDField(default=uuid_utils.compat.uuid7, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
- ('application', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='application.application')),
+ ('id',
+ models.UUIDField(default=uuid_utils.compat.uuid7, editable=False, primary_key=True, serialize=False,
+ verbose_name='主键id')),
+ ('application',
+ models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='application.application')),
+ ('knowledge',
+ models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='knowledge.knowledge')),
],
options={
'db_table': 'application_knowledge_mapping',
},
),
+        migrations.RunPython(insert_default_data),
]
diff --git a/apps/application/migrations/0002_initial.py b/apps/application/migrations/0002_initial.py
deleted file mode 100644
index fbec8994e..000000000
--- a/apps/application/migrations/0002_initial.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Generated by Django 5.2.1 on 2025-05-27 06:42
-
-import django.db.models.deletion
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
- initial = True
-
- dependencies = [
- ('application', '0001_initial'),
- ('knowledge', '0001_initial'),
- ]
-
- operations = [
- migrations.AddField(
- model_name='applicationknowledgemapping',
- name='knowledge',
- field=models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='knowledge.knowledge'),
- ),
- ]
diff --git a/apps/application/models/application.py b/apps/application/models/application.py
index 50ee76087..243c14652 100644
--- a/apps/application/models/application.py
+++ b/apps/application/models/application.py
@@ -67,7 +67,7 @@ class Application(AppModelMixin):
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
- dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
+ knowledge_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default=dict)
tts_model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default=dict)
diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py
index 879962290..abb5a24f9 100644
--- a/apps/application/serializers/application.py
+++ b/apps/application/serializers/application.py
@@ -56,7 +56,7 @@ class ModelKnowledgeAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True, label=_("User ID"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
label=_("Model id"))
- Knowledge_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
+ knowledge_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
label=_(
"Knowledge base id")),
label=_("Knowledge Base List"))
@@ -68,7 +68,7 @@ class ModelKnowledgeAssociation(serializers.Serializer):
if model_id is not None and len(model_id) > 0:
if not QuerySet(Model).filter(id=model_id).exists():
raise AppApiException(500, f'{_("Model does not exist")}【{model_id}】')
- knowledge_id_list = list(set(self.data.get('knowledge_id_list')))
+ knowledge_id_list = list(set(self.data.get('knowledge_id_list', [])))
exist_knowledge_id_list = [str(knowledge.id) for knowledge in
QuerySet(Knowledge).filter(id__in=knowledge_id_list, user_id=user_id)]
for knowledge_id in knowledge_id_list:
@@ -110,6 +110,7 @@ class ApplicationCreateSerializer(serializers.Serializer):
work_flow = serializers.DictField(required=True, label=_("Workflow Objects"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
label=_("Opening remarks"))
+ folder_id = serializers.CharField(required=True, label=_('folder id'))
@staticmethod
def to_application_model(user_id: str, workspace_id: str, application: Dict):
@@ -123,6 +124,7 @@ class ApplicationCreateSerializer(serializers.Serializer):
name=application.get('name'),
desc=application.get('desc'),
workspace_id=workspace_id,
+ folder_id=application.get('folder_id', 'root'),
prologue="",
dialogue_number=0,
user_id=user_id, model_id=None,
@@ -135,7 +137,7 @@ class ApplicationCreateSerializer(serializers.Serializer):
tts_model_id=application.get('tts_model', None),
tts_model_enable=application.get('tts_model_enable', False),
tts_model_params_setting=application.get('tts_model_params_setting', {}),
- tts_type=application.get('tts_type', None),
+ tts_type=application.get('tts_type', 'BROWSER'),
file_upload_enable=application.get('file_upload_enable', False),
file_upload_setting=application.get('file_upload_setting', {}),
work_flow=default_workflow
@@ -147,6 +149,7 @@ class ApplicationCreateSerializer(serializers.Serializer):
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
max_length=256, min_length=1,
label=_("application describe"))
+ folder_id = serializers.CharField(required=True, label=_('folder id'))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
label=_("Model"))
dialogue_number = serializers.IntegerField(required=True,
@@ -179,6 +182,20 @@ class ApplicationCreateSerializer(serializers.Serializer):
model_params_setting = serializers.DictField(required=False,
label=_('Model parameters'))
+ tts_model_enable = serializers.BooleanField(required=False, label=_('Voice playback enabled'))
+
+ tts_model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Voice playback model ID"))
+
+ tts_type = serializers.CharField(required=False, label=_('Voice playback type'))
+
+ tts_autoplay = serializers.BooleanField(required=False, label=_('Voice playback autoplay'))
+
+ stt_model_enable = serializers.BooleanField(required=False, label=_('Voice recognition enabled'))
+
+ stt_model_id = serializers.UUIDField(required=False, allow_null=True, label=_('Speech recognition model ID'))
+
+ stt_autosend = serializers.BooleanField(required=False, label=_('Voice recognition automatic transmission'))
+
def is_valid(self, *, user_id=None, raise_exception=False):
super().is_valid(raise_exception=True)
ModelKnowledgeAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
@@ -190,7 +207,8 @@ class ApplicationCreateSerializer(serializers.Serializer):
prologue=application.get('prologue'),
dialogue_number=application.get('dialogue_number', 0),
user_id=user_id, model_id=application.get('model_id'),
- dataset_setting=application.get('dataset_setting'),
+ folder_id=application.get('folder_id', 'root'),
+ knowledge_setting=application.get('knowledge_setting'),
model_setting=application.get('model_setting'),
problem_optimization=application.get('problem_optimization'),
type=ApplicationTypeChoices.SIMPLE,
@@ -198,10 +216,11 @@ class ApplicationCreateSerializer(serializers.Serializer):
problem_optimization_prompt=application.get('problem_optimization_prompt', None),
stt_model_enable=application.get('stt_model_enable', False),
stt_model_id=application.get('stt_model', None),
+ stt_autosend=application.get('stt_autosend', False),
tts_model_id=application.get('tts_model', None),
tts_model_enable=application.get('tts_model_enable', False),
tts_model_params_setting=application.get('tts_model_params_setting', {}),
- tts_type=application.get('tts_type', None),
+ tts_type=application.get('tts_type', 'BROWSER'),
file_upload_enable=application.get('file_upload_enable', False),
file_upload_setting=application.get('file_upload_setting', {}),
work_flow={}
@@ -222,8 +241,10 @@ class ApplicationSerializer(serializers.Serializer):
def insert_workflow(self, instance: Dict):
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
- ApplicationCreateSerializer.WorkflowRequest(data=instance).is_valid(raise_exception=True)
- application_model = ApplicationCreateSerializer.WorkflowRequest.to_application_model(user_id, instance)
+ workspace_id = self.data.get('workspace_id')
+        workflow_request = ApplicationCreateSerializer.WorkflowRequest(data=instance)
+        workflow_request.is_valid(raise_exception=True)
+        application_model = workflow_request.to_application_model(user_id, workspace_id, instance)
application_model.save()
# 插入认证信息
ApplicationAccessToken(application_id=application_model.id,
diff --git a/apps/application/serializers/application_folder.py b/apps/application/serializers/application_folder.py
new file mode 100644
index 000000000..d79015e74
--- /dev/null
+++ b/apps/application/serializers/application_folder.py
@@ -0,0 +1,21 @@
+from rest_framework import serializers
+
+from application.models import ApplicationFolder
+
+
+class ApplicationFolderTreeSerializer(serializers.ModelSerializer):
+ children = serializers.SerializerMethodField()
+
+ class Meta:
+ model = ApplicationFolder
+ fields = ['id', 'name', 'desc', 'user_id', 'workspace_id', 'parent_id', 'children']
+
+ def get_children(self, obj):
+ return ApplicationFolderTreeSerializer(obj.get_children(), many=True).data
+
+
+class ApplicationFolderFlatSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ApplicationFolder
+ fields = ['id', 'name', 'desc', 'user_id', 'workspace_id', 'parent_id']
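+
+# Illustrative use (assumes the 'root' folder from the data migration exists):
+# ApplicationFolderTreeSerializer(ApplicationFolder.objects.get(id='root')).data
+# yields the folder with its nested children; the Flat variant omits children.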
diff --git a/apps/folders/serializers/folder.py b/apps/folders/serializers/folder.py
index 366b7e8c3..b2c86b83a 100644
--- a/apps/folders/serializers/folder.py
+++ b/apps/folders/serializers/folder.py
@@ -7,6 +7,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from application.models.application import Application, ApplicationFolder
+from application.serializers.application_folder import ApplicationFolderTreeSerializer
from common.constants.permission_constants import Group
from folders.api.folder import FolderCreateRequest
from knowledge.models import KnowledgeFolder, Knowledge
@@ -42,9 +43,7 @@ def get_folder_tree_serializer(source):
if source == Group.TOOL.name:
return ToolFolderTreeSerializer
elif source == Group.APPLICATION.name:
- # todo app folder
- return None
- # return ApplicationFolderTreeSerializer
+ return ApplicationFolderTreeSerializer
elif source == Group.KNOWLEDGE.name:
return KnowledgeFolderTreeSerializer
else: