mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 10:12:51 +00:00
feat: 工作编排
This commit is contained in:
parent
6f14abcd56
commit
c60fc829ae
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/6/7 14:43
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -41,15 +41,18 @@ class WorkFlowPostHandler:
|
|||
answer,
|
||||
workflow):
|
||||
question = workflow.params['question']
|
||||
details = workflow.get_runtime_details()
|
||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if 'message_tokens' in row])
|
||||
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row])
|
||||
chat_record = ChatRecord(id=chat_record_id,
|
||||
chat_id=chat_id,
|
||||
problem_text=question,
|
||||
answer_text=answer,
|
||||
details=workflow.get_details(),
|
||||
message_tokens=workflow.context['message_tokens'],
|
||||
answer_tokens=workflow.context['answer_tokens'],
|
||||
run_time=workflow.context['run_time'],
|
||||
index=len(self.chat_info.chat_record_list) + 1)
|
||||
details=details,
|
||||
message_tokens=message_tokens,
|
||||
answer_tokens=answer_tokens,
|
||||
run_time=time.time() - workflow.context['time'],
|
||||
index=0)
|
||||
self.chat_info.append_chat_record(chat_record, self.client_id)
|
||||
# 重新设置缓存
|
||||
chat_cache.set(chat_id,
|
||||
|
|
@ -76,6 +79,9 @@ class NodeResult:
|
|||
return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
|
||||
post_handler)
|
||||
|
||||
def is_assertion_result(self):
|
||||
return 'branch_id' in self.node_variable
|
||||
|
||||
|
||||
class ReferenceAddressSerializer(serializers.Serializer):
|
||||
node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
|
||||
|
|
@ -103,15 +109,16 @@ class FlowParamsSerializer(serializers.Serializer):
|
|||
|
||||
|
||||
class INode:
|
||||
def __init__(self, _id, node_params, workflow_params, workflow_manage):
|
||||
def __init__(self, node, workflow_params, workflow_manage):
|
||||
# 当前步骤上下文,用于存储当前步骤信息
|
||||
self.node_params = node_params
|
||||
self.node = node
|
||||
self.node_params = node.properties.get('node_data')
|
||||
self.workflow_manage = workflow_manage
|
||||
self.node_params_serializer = None
|
||||
self.flow_params_serializer = None
|
||||
self.context = {}
|
||||
self.id = _id
|
||||
self.valid_args(node_params, workflow_params)
|
||||
self.id = node.id
|
||||
self.valid_args(self.node_params, workflow_params)
|
||||
|
||||
def valid_args(self, node_params, flow_params):
|
||||
flow_params_serializer_class = self.get_flow_params_serializer_class()
|
||||
|
|
@ -160,9 +167,9 @@ class INode:
|
|||
def execute(self, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
||||
def get_details(self, **kwargs):
|
||||
def get_details(self, index: int, **kwargs):
|
||||
"""
|
||||
运行详情
|
||||
:return: 步骤详情
|
||||
"""
|
||||
return None
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/6/7 14:43
|
||||
@desc:
|
||||
"""
|
||||
from .ai_chat_step_node import *
|
||||
from .condition_node import *
|
||||
from .question_node import *
|
||||
from .search_dataset_node import *
|
||||
from .start_node import *
|
||||
from .direct_reply_node import *
|
||||
|
||||
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode]
|
||||
|
||||
|
||||
def get_node(node_type):
|
||||
find_list = [node for node in node_list if node.type == node_type]
|
||||
if len(find_list) > 0:
|
||||
return find_list[0]
|
||||
return None
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:29
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -16,7 +16,8 @@ from common.util.field_message import ErrMessage
|
|||
|
||||
class ChatNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
|
||||
system = serializers.CharField(required=True, error_messages=ErrMessage.char("角色设定"))
|
||||
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char("角色设定"))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:34
|
||||
@desc:
|
||||
"""
|
||||
from .base_chat_node import BaseChatNode
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_chat_node.py
|
||||
@file: base_question_node.py
|
||||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -39,6 +39,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
|
|||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
|
|
@ -57,6 +59,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
|
|||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
|
||||
|
||||
def get_to_response_write_context(node_variable: Dict, node: INode):
|
||||
|
|
@ -67,6 +71,8 @@ def get_to_response_write_context(node_variable: Dict, node: INode):
|
|||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
|
||||
return _write_context
|
||||
|
||||
|
|
@ -115,25 +121,38 @@ class BaseChatNode(IChatNode):
|
|||
json.loads(
|
||||
rsa_long_decrypt(model.credential)),
|
||||
streaming=True)
|
||||
message_list = self.generate_message_list(system, prompt, history_chat_record, dialogue_number)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
question = self.generate_prompt_question(prompt)
|
||||
message_list = self.generate_message_list(system, prompt, history_message)
|
||||
if stream:
|
||||
r = chat_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'get_to_response_write_context': get_to_response_write_context}, {},
|
||||
'history_message': history_message, 'question': question}, {},
|
||||
_write_context=write_context_stream,
|
||||
_to_response=to_stream_response)
|
||||
else:
|
||||
r = chat_model.invoke(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list}, {},
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question}, {},
|
||||
_write_context=write_context, _to_response=to_response)
|
||||
|
||||
def generate_message_list(self, system: str, prompt: str, history_chat_record, dialogue_number):
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))]
|
||||
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
|
||||
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
return history_message
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
|
||||
|
||||
def generate_message_list(self, system: str, prompt: str, history_message):
|
||||
if system is None or len(system) == 0:
|
||||
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
|
||||
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
else:
|
||||
return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
|
|
@ -143,3 +162,17 @@ class BaseChatNode(IChatNode):
|
|||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'system': self.node_params.get('system'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
self.context.get('history_message')],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context['message_tokens'],
|
||||
'answer_tokens': self.context['answer_tokens']
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/6/7 14:43
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/6/7 14:43
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from .contain_compare import *
|
||||
from .equal_compare import *
|
||||
from .gt_compare import *
|
||||
from .ge_compare import *
|
||||
from .le_compare import *
|
||||
from .lt_compare import *
|
||||
from .len_ge_compare import *
|
||||
from .len_gt_compare import *
|
||||
from .len_le_compare import *
|
||||
from .len_lt_compare import *
|
||||
from .len_equal_compare import *
|
||||
|
||||
compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
|
||||
LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare()]
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: compare.py
|
||||
@date:2024/6/7 14:37
|
||||
@desc:
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
|
||||
class Compare:
|
||||
@abstractmethod
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compare(self, source_value, compare, target_value):
|
||||
pass
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: contain_compare.py
|
||||
@date:2024/6/11 10:02
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class ContainCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'contain':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
return any([str(item) == str(target_value) for item in source_value])
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: equal_compare.py
|
||||
@date:2024/6/7 14:44
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class EqualCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'eq':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
return str(source_value) == str(target_value)
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 大于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class GECompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'ge':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return float(source_value) >= float(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 大于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class GTCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'gt':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return float(source_value) > float(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 小于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LECompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'le':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return float(source_value) <= float(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: equal_compare.py
|
||||
@date:2024/6/7 14:44
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenEqualCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_eq':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) == int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 大于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenGECompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_ge':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) >= int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 大于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenGTCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_gt':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) > int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 小于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenLECompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_le':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) <= int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 小于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenLTCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'len_lt':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return len(source_value) < int(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: lt_compare.py
|
||||
@date:2024/6/11 9:52
|
||||
@desc: 小于比较器
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LTCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'lt':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return float(source_value) < float(target_value)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_condition_node.py
|
||||
@date:2024/6/7 9:54
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ConditionSerializer(serializers.Serializer):
|
||||
compare = serializers.CharField(required=True, error_messages=ErrMessage.char("比较器"))
|
||||
value = serializers.CharField(required=True, error_messages=ErrMessage.char(""))
|
||||
field = serializers.ListField(required=True, error_messages=ErrMessage.char("字段"))
|
||||
|
||||
|
||||
class ConditionBranchSerializer(serializers.Serializer):
|
||||
id = serializers.CharField(required=True, error_messages=ErrMessage.char("分支id"))
|
||||
condition = serializers.CharField(required=True, error_messages=ErrMessage.char("条件or|and"))
|
||||
conditions = ConditionSerializer(many=True)
|
||||
|
||||
|
||||
class ConditionNodeParamsSerializer(serializers.Serializer):
|
||||
branch = ConditionBranchSerializer(many=True)
|
||||
|
||||
|
||||
j = """
|
||||
{ "branch": [
|
||||
{
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"34902d3d-a3ff-497f-b8e1-0c34a44d7dd5",
|
||||
"paragraph_list"
|
||||
],
|
||||
"compare": "len_eq",
|
||||
"value": "0"
|
||||
}
|
||||
],
|
||||
"id": "2391",
|
||||
"condition": "and"
|
||||
},
|
||||
{
|
||||
"conditions": [
|
||||
{
|
||||
"field": [
|
||||
"34902d3d-a3ff-497f-b8e1-0c34a44d7dd5",
|
||||
"paragraph_list"
|
||||
],
|
||||
"compare": "len_eq",
|
||||
"value": "1"
|
||||
}
|
||||
],
|
||||
"id": "1143",
|
||||
"condition": "and"
|
||||
},
|
||||
{
|
||||
"conditions": [
|
||||
|
||||
],
|
||||
"id": "9208",
|
||||
"condition": "and"
|
||||
}
|
||||
]}
|
||||
"""
|
||||
a = json.loads(j)
|
||||
c = ConditionNodeParamsSerializer(data=a)
|
||||
c.is_valid(raise_exception=True)
|
||||
print(c.data)
|
||||
|
||||
|
||||
class IConditionNode(INode):
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ConditionNodeParamsSerializer
|
||||
|
||||
type = 'condition-node'
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:35
|
||||
@desc:
|
||||
"""
|
||||
from .base_condition_node import BaseConditionNode
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_condition_node.py
|
||||
@date:2024/6/7 11:29
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.condition_node.compare import compare_handle_list
|
||||
from application.flow.step_node.condition_node.i_condition_node import IConditionNode
|
||||
|
||||
|
||||
class BaseConditionNode(IConditionNode):
|
||||
def execute(self, **kwargs) -> NodeResult:
|
||||
branch_list = self.node_params_serializer.data['branch']
|
||||
branch = self._execute(branch_list)
|
||||
r = NodeResult({'branch_id': branch.get('id')}, {})
|
||||
return r
|
||||
|
||||
def _execute(self, branch_list: List):
|
||||
for branch in branch_list:
|
||||
if self.branch_assertion(branch):
|
||||
return branch
|
||||
|
||||
def branch_assertion(self, branch):
|
||||
condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
|
||||
branch.get('conditions')]
|
||||
condition = branch.get('condition')
|
||||
return all(condition_list) if condition == 'and' else any(condition_list)
|
||||
|
||||
def assertion(self, field_list: List[str], compare: str, value):
|
||||
field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
|
||||
for compare_handler in compare_handle_list:
|
||||
if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value):
|
||||
return compare_handler.compare(field_value, compare, value)
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'branch_id': self.context.get('branch_id'),
|
||||
'branch_name': self.context.get('branch_name'),
|
||||
'type': self.node.type
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 17:50
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_reply_node.py
|
||||
@date:2024/6/11 16:25
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ReplyNodeParamsSerializer(serializers.Serializer):
|
||||
reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char("回复类型"))
|
||||
fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
|
||||
content = serializers.CharField(required=False, error_messages=ErrMessage.char("直接回答内容"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if self.data.get('reply_type') == 'referencing':
|
||||
if 'fields' not in self.data:
|
||||
raise AppApiException(500, "引用字段不能为空")
|
||||
if len(self.data.get('fields')) < 2:
|
||||
raise AppApiException(500, "引用字段错误")
|
||||
else:
|
||||
if 'content' not in self.data or self.data.get('content') is None:
|
||||
raise AppApiException(500, "内容不能为空")
|
||||
|
||||
|
||||
class IReplyNode(INode):
|
||||
type = 'reply-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ReplyNodeParamsSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 17:49
|
||||
@desc:
|
||||
"""
|
||||
from .base_reply_node import *
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_reply_node.py
|
||||
@date:2024/6/11 17:25
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
|
||||
from application.flow import tools
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
|
||||
|
||||
|
||||
def get_to_response_write_context(node_variable: Dict, node: INode):
|
||||
def _write_context(answer):
|
||||
node.context['answer'] = answer
|
||||
|
||||
return _write_context
|
||||
|
||||
|
||||
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将流式数据 转换为 流式响应
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器 输出结果后执行
|
||||
@return: 流式响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
|
||||
|
||||
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将结果转换
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器
|
||||
@return: 响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
|
||||
|
||||
class BaseReplyNode(IReplyNode):
|
||||
def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
|
||||
if reply_type == 'referencing':
|
||||
result = self.get_reference_content(fields)
|
||||
else:
|
||||
result = self.generate_reply_content(content)
|
||||
if stream:
|
||||
return NodeResult({'result': iter(AIMessageChunk(content=result))}, {},
|
||||
_to_response=to_stream_response)
|
||||
else:
|
||||
return NodeResult({'result': AIMessage(content=result)}, {}, _to_response=to_response)
|
||||
|
||||
def generate_reply_content(self, prompt):
|
||||
return self.workflow_manage.generate_prompt(prompt)
|
||||
|
||||
def get_reference_content(self, fields: List[str]):
|
||||
return self.workflow_manage.get_reference_field(
|
||||
fields[0],
|
||||
fields[1:])
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'answer': self.context.get('answer'),
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:30
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_chat_node.py
|
||||
@date:2024/6/4 13:58
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class QuestionNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
|
||||
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char("角色设定"))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||
|
||||
|
||||
class IQuestionNode(INode):
|
||||
type = 'question-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return QuestionNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:35
|
||||
@desc:
|
||||
"""
|
||||
from .base_question_node import BaseQuestionNode
|
||||
|
|
@ -0,0 +1,177 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_question_node.py
|
||||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from application.flow import tools
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.question_node.i_question_node import IQuestionNode
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from setting.models import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据 (流式)
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
answer = ''
|
||||
for chunk in response:
|
||||
answer += chunk.content
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点实例对象
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
chat_model = node_variable.get('chat_model')
|
||||
answer = response.content
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['question'] = node_variable['question']
|
||||
|
||||
|
||||
def get_to_response_write_context(node_variable: Dict, node: INode):
|
||||
def _write_context(answer):
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
|
||||
return _write_context
|
||||
|
||||
|
||||
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将流式数据 转换为 流式响应
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器 输出结果后执行
|
||||
@return: 流式响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
|
||||
|
||||
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
|
||||
post_handler):
|
||||
"""
|
||||
将结果转换
|
||||
@param chat_id: 会话id
|
||||
@param chat_record_id: 对话记录id
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 工作流数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
@param post_handler: 后置处理器
|
||||
@return: 响应
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
_write_context = get_to_response_write_context(node_variable, node)
|
||||
return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
|
||||
|
||||
|
||||
class BaseQuestionNode(IQuestionNode):
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(model.credential)),
|
||||
streaming=True)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
question = self.generate_prompt_question(prompt)
|
||||
message_list = self.generate_message_list(system, prompt, history_message)
|
||||
if stream:
|
||||
r = chat_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'get_to_response_write_context': get_to_response_write_context,
|
||||
'history_message': history_message, 'question': question}, {},
|
||||
_write_context=write_context_stream,
|
||||
_to_response=to_stream_response)
|
||||
else:
|
||||
r = chat_model.invoke(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question}, {},
|
||||
_write_context=write_context, _to_response=to_response)
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))]
|
||||
return history_message
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
|
||||
|
||||
def generate_message_list(self, system: str, prompt: str, history_message):
|
||||
if system is None or len(system) == 0:
|
||||
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
|
||||
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
else:
|
||||
return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'system': self.node_params.get('system'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
self.context.get('history_message')],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context['message_tokens'],
|
||||
'answer_tokens': self.context['answer_tokens']
|
||||
}
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:30
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -16,10 +16,7 @@ from application.flow.i_step_node import INode, ReferenceAddressSerializer, Node
|
|||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class SearchDatasetStepNodeSerializer(serializers.Serializer):
|
||||
# 需要查询的数据集id列表
|
||||
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
error_messages=ErrMessage.list("数据集id列表"))
|
||||
class DatasetSettingSerializer(serializers.Serializer):
|
||||
# 需要查询的条数
|
||||
top_n = serializers.IntegerField(required=True,
|
||||
error_messages=ErrMessage.integer("引用分段数"))
|
||||
|
|
@ -30,9 +27,17 @@ class SearchDatasetStepNodeSerializer(serializers.Serializer):
|
|||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||
message="类型只支持register|reset_password", code=500)
|
||||
], error_messages=ErrMessage.char("检索模式"))
|
||||
max_paragraph_char_number = serializers.IntegerField(required=True,
|
||||
error_messages=ErrMessage.float("最大引用分段字数"))
|
||||
|
||||
question_reference_address = ReferenceAddressSerializer(required=False,
|
||||
error_messages=ErrMessage.char("问题应用地址"))
|
||||
|
||||
class SearchDatasetStepNodeSerializer(serializers.Serializer):
|
||||
# 需要查询的数据集id列表
|
||||
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
error_messages=ErrMessage.list("数据集id列表"))
|
||||
dataset_setting = DatasetSettingSerializer(required=True)
|
||||
|
||||
question_reference_address = serializers.ListField(required=True, )
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
|
@ -46,11 +51,11 @@ class ISearchDatasetStepNode(INode):
|
|||
|
||||
def _run(self):
|
||||
question = self.workflow_manage.get_reference_field(
|
||||
self.node_params_serializer.data.get('question_reference_address').get('node_id'),
|
||||
self.node_params_serializer.data.get('question_reference_address').get('fields'))
|
||||
self.node_params_serializer.data.get('question_reference_address')[0],
|
||||
self.node_params_serializer.data.get('question_reference_address')[1:])
|
||||
return self.execute(**self.node_params_serializer.data, question=question, exclude_paragraph_id_list=[])
|
||||
|
||||
def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
|
||||
def execute(self, dataset_id_list, dataset_setting, question,
|
||||
exclude_paragraph_id_list=None,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:35
|
||||
@desc:
|
||||
"""
|
||||
from .base_search_dataset_node import BaseSearchDatasetNode
|
||||
|
|
@ -22,7 +22,7 @@ from smartdoc.conf import PROJECT_DIR
|
|||
|
||||
|
||||
class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
||||
def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
|
||||
def execute(self, dataset_id_list, dataset_setting, question,
|
||||
exclude_paragraph_id_list=None,
|
||||
**kwargs) -> NodeResult:
|
||||
embedding_model = EmbeddingModel.get_embedding_model()
|
||||
|
|
@ -33,7 +33,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
|||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)]
|
||||
embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
|
||||
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
|
||||
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
|
||||
if embedding_list is None:
|
||||
return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
|
||||
paragraph_list = self.list_paragraph(embedding_list, vector)
|
||||
|
|
@ -71,3 +72,11 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
|||
if not exist_paragraph_list.__contains__(paragraph_id):
|
||||
vector.delete_by_paragraph_id(paragraph_id)
|
||||
return paragraph_list
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'paragraph_list': self.context.get('paragraph_list'),
|
||||
'type': self.node.type
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:30
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:36
|
||||
@desc:
|
||||
"""
|
||||
from .base_start_node import BaseStartStepNode
|
||||
|
|
@ -18,3 +18,11 @@ class BaseStartStepNode(IStarNode):
|
|||
开始节点 初始化全局变量
|
||||
"""
|
||||
return NodeResult({'question': question}, {'time': time.time()})
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
"index": index,
|
||||
"question": self.context.get('question'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,10 +10,8 @@ from typing import List, Dict
|
|||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from application.flow.i_step_node import INode, WorkFlowPostHandler
|
||||
from application.flow.step_node.ai_chat_step_node.impl.base_chat_node import BaseChatNode
|
||||
from application.flow.step_node.search_dataset_node.impl.base_search_dataset_node import BaseSearchDatasetNode
|
||||
from application.flow.step_node.start_node.impl.base_start_node import BaseStartStepNode
|
||||
from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
|
||||
from application.flow.step_node import get_node
|
||||
|
||||
|
||||
class Edge:
|
||||
|
|
@ -52,41 +50,36 @@ class Flow:
|
|||
return Flow(nodes, edges)
|
||||
|
||||
|
||||
flow_node_dict = {
|
||||
'start-node': BaseStartStepNode,
|
||||
'search-dataset-node': BaseSearchDatasetNode,
|
||||
'chat-node': BaseChatNode
|
||||
}
|
||||
|
||||
|
||||
class WorkflowManage:
|
||||
def __init__(self, flow: Flow, params):
|
||||
self.params = params
|
||||
self.flow = flow
|
||||
self.context = {}
|
||||
self.node_dict = {}
|
||||
self.runtime_nodes = []
|
||||
self.node_context = []
|
||||
self.current_node = None
|
||||
self.current_result = None
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
运行工作流
|
||||
"""
|
||||
while self.has_next_node():
|
||||
while self.has_next_node(self.current_result):
|
||||
self.current_node = self.get_next_node()
|
||||
self.node_dict[self.current_node.id] = self.current_node
|
||||
result = self.current_node.run()
|
||||
if self.has_next_node():
|
||||
result.write_context(self.current_node, self)
|
||||
self.node_context.append(self.current_node)
|
||||
self.current_result = self.current_node.run()
|
||||
if self.has_next_node(self.current_result):
|
||||
self.current_result.write_context(self.current_node, self)
|
||||
else:
|
||||
r = result.to_response(self.params['chat_id'], self.params['chat_record_id'], self.current_node, self,
|
||||
WorkFlowPostHandler(client_id=self.params['client_id'], chat_info=None,
|
||||
client_type='ss'))
|
||||
r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
|
||||
self.current_node, self,
|
||||
WorkFlowPostHandler(client_id=self.params['client_id'],
|
||||
chat_info=None,
|
||||
client_type='APPLICATION_ACCESS_TOKEN'))
|
||||
for row in r:
|
||||
print(row)
|
||||
print(self)
|
||||
|
||||
def has_next_node(self):
|
||||
def has_next_node(self, node_result: NodeResult | None):
|
||||
"""
|
||||
是否有下一个可运行的节点
|
||||
"""
|
||||
|
|
@ -94,13 +87,24 @@ class WorkflowManage:
|
|||
if self.get_start_node() is not None:
|
||||
return True
|
||||
else:
|
||||
for edge in self.flow.edges:
|
||||
if edge.sourceNodeId == self.current_node.id:
|
||||
return True
|
||||
if node_result is not None and node_result.is_assertion_result():
|
||||
for edge in self.flow.edges:
|
||||
if (edge.sourceNodeId == self.current_node.id and
|
||||
f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
|
||||
return True
|
||||
else:
|
||||
for edge in self.flow.edges:
|
||||
if edge.sourceNodeId == self.current_node.id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_runtime_details(self):
|
||||
return {}
|
||||
details_result = {}
|
||||
for index in range(len(self.node_context)):
|
||||
node = self.node_context[index]
|
||||
details = node.get_details({'index': index})
|
||||
details_result[node.id] = details
|
||||
return details_result
|
||||
|
||||
def get_next_node(self):
|
||||
"""
|
||||
|
|
@ -108,8 +112,7 @@ class WorkflowManage:
|
|||
"""
|
||||
if self.current_node is None:
|
||||
node = self.get_start_node()
|
||||
node_instance = flow_node_dict[node.type](node.id, node.properties.get('node_data'),
|
||||
self.params, self.context)
|
||||
node_instance = get_node(node.type)(node, self.params, self.context)
|
||||
return node_instance
|
||||
for edge in self.flow.edges:
|
||||
if edge.sourceNodeId == self.current_node.id:
|
||||
|
|
@ -138,8 +141,9 @@ class WorkflowManage:
|
|||
context = {
|
||||
'global': self.context,
|
||||
}
|
||||
for key in self.node_dict:
|
||||
context[key] = self.node_dict[key].context
|
||||
|
||||
for node in self.node_context:
|
||||
context[node.id] = node.context
|
||||
value = prompt_template.format(context=context)
|
||||
return value
|
||||
|
||||
|
|
@ -148,18 +152,22 @@ class WorkflowManage:
|
|||
获取启动节点
|
||||
@return:
|
||||
"""
|
||||
return self.flow.nodes[0]
|
||||
start_node_list = [node for node in self.flow.nodes if node.type == 'start-node']
|
||||
return start_node_list[0]
|
||||
|
||||
def get_node_cls_by_id(self, node_id):
|
||||
for node in self.flow.nodes:
|
||||
if node.id == node_id:
|
||||
node_instance = flow_node_dict[node.type](node.id, node.properties.get('node_data'),
|
||||
self.params, self)
|
||||
node_instance = get_node(node.type)(node,
|
||||
self.params, self)
|
||||
return node_instance
|
||||
return None
|
||||
|
||||
def get_node_by_id(self, node_id):
|
||||
return self.node_dict[node_id]
|
||||
for node in self.node_context:
|
||||
if node.id == node_id:
|
||||
return node
|
||||
return None
|
||||
|
||||
def get_node_reference(self, reference_address: Dict):
|
||||
node = self.get_node_by_id(reference_address.get('node_id'))
|
||||
|
|
|
|||
Loading…
Reference in New Issue