feat: 工作编排

This commit is contained in:
shaohuzhang1 2024-06-11 17:53:58 +08:00
parent 6f14abcd56
commit c60fc829ae
40 changed files with 1051 additions and 64 deletions

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/6/7 14:43
@desc:
"""

View File

@ -41,15 +41,18 @@ class WorkFlowPostHandler:
answer,
workflow):
question = workflow.params['question']
details = workflow.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if 'message_tokens' in row])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row])
chat_record = ChatRecord(id=chat_record_id,
chat_id=chat_id,
problem_text=question,
answer_text=answer,
details=workflow.get_details(),
message_tokens=workflow.context['message_tokens'],
answer_tokens=workflow.context['answer_tokens'],
run_time=workflow.context['run_time'],
index=len(self.chat_info.chat_record_list) + 1)
details=details,
message_tokens=message_tokens,
answer_tokens=answer_tokens,
run_time=time.time() - workflow.context['time'],
index=0)
self.chat_info.append_chat_record(chat_record, self.client_id)
# 重新设置缓存
chat_cache.set(chat_id,
@ -76,6 +79,9 @@ class NodeResult:
return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
post_handler)
def is_assertion_result(self):
return 'branch_id' in self.node_variable
class ReferenceAddressSerializer(serializers.Serializer):
node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
@ -103,15 +109,16 @@ class FlowParamsSerializer(serializers.Serializer):
class INode:
def __init__(self, _id, node_params, workflow_params, workflow_manage):
def __init__(self, node, workflow_params, workflow_manage):
# 当前步骤上下文,用于存储当前步骤信息
self.node_params = node_params
self.node = node
self.node_params = node.properties.get('node_data')
self.workflow_manage = workflow_manage
self.node_params_serializer = None
self.flow_params_serializer = None
self.context = {}
self.id = _id
self.valid_args(node_params, workflow_params)
self.id = node.id
self.valid_args(self.node_params, workflow_params)
def valid_args(self, node_params, flow_params):
flow_params_serializer_class = self.get_flow_params_serializer_class()
@ -160,9 +167,9 @@ class INode:
def execute(self, **kwargs) -> NodeResult:
pass
def get_details(self, **kwargs):
def get_details(self, index: int, **kwargs):
"""
运行详情
:return: 步骤详情
"""
return None
return {}

View File

@ -0,0 +1,23 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/6/7 14:43
@desc:
"""
from .ai_chat_step_node import *
from .condition_node import *
from .question_node import *
from .search_dataset_node import *
from .start_node import *
from .direct_reply_node import *
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode]
def get_node(node_type):
    """Return the node class whose ``type`` matches *node_type*.

    @param node_type: node type identifier, e.g. 'condition-node'
    @return: the matching node class from ``node_list``, or None when unknown
    """
    # next() short-circuits on the first hit instead of building the whole
    # filtered list just to take its head.
    return next((node for node in node_list if node.type == node_type), None)

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:29
@desc:
"""
from .impl import *

View File

@ -16,7 +16,8 @@ from common.util.field_message import ErrMessage
class ChatNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
system = serializers.CharField(required=True, error_messages=ErrMessage.char("角色设定"))
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char("角色设定"))
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:34
@desc:
"""
from .base_chat_node import BaseChatNode

View File

@ -2,7 +2,7 @@
"""
@project: maxkb
@Author
@file base_chat_node.py
@file base_question_node.py
@date2024/6/4 14:30
@desc:
"""
@ -39,6 +39,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@ -57,6 +59,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
def get_to_response_write_context(node_variable: Dict, node: INode):
@ -67,6 +71,8 @@ def get_to_response_write_context(node_variable: Dict, node: INode):
node.context['message_tokens'] = message_tokens
node.context['answer_tokens'] = answer_tokens
node.context['answer'] = answer
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
return _write_context
@ -115,25 +121,38 @@ class BaseChatNode(IChatNode):
json.loads(
rsa_long_decrypt(model.credential)),
streaming=True)
message_list = self.generate_message_list(system, prompt, history_chat_record, dialogue_number)
history_message = self.get_history_message(history_chat_record, dialogue_number)
question = self.generate_prompt_question(prompt)
message_list = self.generate_message_list(system, prompt, history_message)
if stream:
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'get_to_response_write_context': get_to_response_write_context}, {},
'history_message': history_message, 'question': question}, {},
_write_context=write_context_stream,
_to_response=to_stream_response)
else:
r = chat_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list}, {},
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question}, {},
_write_context=write_context, _to_response=to_response)
def generate_message_list(self, system: str, prompt: str, history_chat_record, dialogue_number):
@staticmethod
def get_history_message(history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
return history_message
def generate_prompt_question(self, prompt):
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
def generate_message_list(self, system: str, prompt: str, history_message):
    """Assemble the LLM message list: optional system prompt, history, question.

    @param system: role/system prompt; may be None or empty (then omitted)
    @param prompt: user prompt template
    @param history_message: prior [human, ai] messages to include
    """
    # BUG FIX: the original condition was inverted — it prepended a
    # SystemMessage only when *system* was None/empty (rendering an empty
    # template) and dropped it whenever a role prompt was actually configured.
    if system is not None and len(system) > 0:
        return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
                HumanMessage(self.workflow_manage.generate_prompt(prompt))]
    else:
        return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
@staticmethod
def reset_message_list(message_list: List[BaseMessage], answer_text):
@ -143,3 +162,17 @@ class BaseChatNode(IChatNode):
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result
def get_details(self, index: int, **kwargs):
    """Return this node's run details for persistence in the chat record.

    @param index: execution order of this node within the workflow run
    """
    return {
        "index": index,
        'run_time': self.context.get('run_time'),
        'system': self.node_params.get('system'),
        'history_message': [{'content': message.content, 'role': message.type} for message in
                            self.context.get('history_message')],
        'question': self.context.get('question'),
        'answer': self.context.get('answer'),
        'type': self.node.type,
        # .get() for consistency with the other context reads: these keys are
        # only set after _write_context has run, so direct indexing could raise
        # KeyError on an interrupted run.
        'message_tokens': self.context.get('message_tokens'),
        'answer_tokens': self.context.get('answer_tokens')
    }

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/6/7 14:43
@desc:
"""
from .impl import *

View File

@ -0,0 +1,23 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/6/7 14:43
@desc:
"""
from .contain_compare import *
from .equal_compare import *
from .gt_compare import *
from .ge_compare import *
from .le_compare import *
from .lt_compare import *
from .len_ge_compare import *
from .len_gt_compare import *
from .len_le_compare import *
from .len_lt_compare import *
from .len_equal_compare import *
compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare()]

View File

@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file compare.py
@date2024/6/7 14:37
@desc:
"""
from abc import abstractmethod
from typing import List
class Compare:
    """Abstract comparison handler used by the condition node.

    Each concrete subclass implements exactly one operator (eq/gt/len_eq/...):
    ``support`` reports whether this handler can evaluate *compare*, and
    ``compare`` performs the actual evaluation.
    """

    @abstractmethod
    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return truthy when this handler implements the given *compare* operator.
        pass

    @abstractmethod
    def compare(self, source_value, compare, target_value):
        # Evaluate *source_value* against *target_value* for the operator.
        pass

View File

@ -0,0 +1,21 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file contain_compare.py
@date2024/6/11 10:02
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class ContainCompare(Compare):
    """'contain' comparator: true when *target_value* equals any element of *source_value*."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'contain'

    def compare(self, source_value, compare, target_value):
        # Compare as strings so e.g. '1' matches 1 regardless of element type.
        return any(str(item) == str(target_value) for item in source_value)

View File

@ -0,0 +1,21 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file equal_compare.py
@date2024/6/7 14:44
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class EqualCompare(Compare):
    """'eq' comparator: string equality of source and target values."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'eq'

    def compare(self, source_value, compare, target_value):
        # String comparison so numeric and textual values compare uniformly.
        return str(source_value) == str(target_value)

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 大于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class GECompare(Compare):
    """'ge' comparator: numeric greater-than-or-equal."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'ge'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) >= float(target_value)
        except Exception:
            # Non-numeric operands cannot satisfy a numeric comparison.
            return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 大于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class GTCompare(Compare):
    """'gt' comparator: numeric strictly-greater-than."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'gt'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) > float(target_value)
        except Exception:
            # Non-numeric operands cannot satisfy a numeric comparison.
            return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 小于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LECompare(Compare):
    """'le' comparator: numeric less-than-or-equal."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'le'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) <= float(target_value)
        except Exception:
            # Non-numeric operands cannot satisfy a numeric comparison.
            return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file equal_compare.py
@date2024/6/7 14:44
@desc:
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenEqualCompare(Compare):
    """'len_eq' comparator: length of the source equals the target number."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'len_eq'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) == int(target_value)
        except Exception:
            # Unsized source or non-integer target cannot satisfy a length check.
            return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 大于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenGECompare(Compare):
    """'len_ge' comparator: length of the source >= the target number."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'len_ge'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) >= int(target_value)
        except Exception:
            # Unsized source or non-integer target cannot satisfy a length check.
            return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 大于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenGTCompare(Compare):
    """'len_gt' comparator: length of the source > the target number."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'len_gt'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) > int(target_value)
        except Exception:
            # Unsized source or non-integer target cannot satisfy a length check.
            return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 小于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenLECompare(Compare):
    """'len_le' comparator: length of the source <= the target number."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'len_le'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) <= int(target_value)
        except Exception:
            # Unsized source or non-integer target cannot satisfy a length check.
            return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 小于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LenLTCompare(Compare):
    """'len_lt' comparator: length of the source < the target number."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'len_lt'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) < int(target_value)
        except Exception:
            # Unsized source or non-integer target cannot satisfy a length check.
            return False

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file lt_compare.py
@date2024/6/11 9:52
@desc: 小于比较器
"""
from typing import List
from application.flow.step_node.condition_node.compare.compare import Compare
class LTCompare(Compare):
    """'lt' comparator: numeric strictly-less-than."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        # Return an explicit bool; the original fell through to an implicit None.
        return compare == 'lt'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) < float(target_value)
        except Exception:
            # Non-numeric operands cannot satisfy a numeric comparison.
            return False

View File

@ -0,0 +1,83 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_condition_node.py
@date2024/6/7 9:54
@desc:
"""
import json
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode
from common.util.field_message import ErrMessage
class ConditionSerializer(serializers.Serializer):
    """One comparison inside a branch: <field> <compare> <value>."""
    compare = serializers.CharField(required=True, error_messages=ErrMessage.char("比较器"))
    value = serializers.CharField(required=True, error_messages=ErrMessage.char(""))
    # Reference path: [node_id, field, ...] pointing at an upstream node's output.
    field = serializers.ListField(required=True, error_messages=ErrMessage.char("字段"))
class ConditionBranchSerializer(serializers.Serializer):
    """One branch of a condition node: its id plus how its conditions combine."""
    id = serializers.CharField(required=True, error_messages=ErrMessage.char("分支id"))
    # 'and' requires every condition; any other value is treated as 'or'.
    condition = serializers.CharField(required=True, error_messages=ErrMessage.char("条件or|and"))
    conditions = ConditionSerializer(many=True)
class ConditionNodeParamsSerializer(serializers.Serializer):
    """Node params of the condition node: the ordered list of branches."""
    branch = ConditionBranchSerializer(many=True)
# NOTE(review): leftover ad-hoc smoke test for ConditionNodeParamsSerializer.
# It previously executed at import time (serializer validation plus a print);
# guarded so it only runs when this module is executed directly.
if __name__ == "__main__":
    j = """
    { "branch": [
        {
            "conditions": [
                {
                    "field": [
                        "34902d3d-a3ff-497f-b8e1-0c34a44d7dd5",
                        "paragraph_list"
                    ],
                    "compare": "len_eq",
                    "value": "0"
                }
            ],
            "id": "2391",
            "condition": "and"
        },
        {
            "conditions": [
                {
                    "field": [
                        "34902d3d-a3ff-497f-b8e1-0c34a44d7dd5",
                        "paragraph_list"
                    ],
                    "compare": "len_eq",
                    "value": "1"
                }
            ],
            "id": "1143",
            "condition": "and"
        },
        {
            "conditions": [
            ],
            "id": "9208",
            "condition": "and"
        }
    ]}
    """
    a = json.loads(j)
    c = ConditionNodeParamsSerializer(data=a)
    c.is_valid(raise_exception=True)
    print(c.data)
class IConditionNode(INode):
    """Workflow step that routes execution to one of several condition branches."""

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return ConditionNodeParamsSerializer

    # Node type identifier used by the step-node registry.
    type = 'condition-node'

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:35
@desc:
"""
from .base_condition_node import BaseConditionNode

View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_condition_node.py
@date2024/6/7 11:29
@desc:
"""
from typing import List
from application.flow.i_step_node import NodeResult
from application.flow.step_node.condition_node.compare import compare_handle_list
from application.flow.step_node.condition_node.i_condition_node import IConditionNode
class BaseConditionNode(IConditionNode):
    """Condition node implementation: evaluates branches in declaration order and
    emits the id of the first branch whose conditions hold."""

    def execute(self, **kwargs) -> NodeResult:
        branch_list = self.node_params_serializer.data['branch']
        branch = self._execute(branch_list)
        # NOTE(review): if no branch matches, _execute returns None and this line
        # raises AttributeError. Flows appear to rely on a catch-all last branch
        # (empty conditions => all([]) is True) — confirm with flow definitions.
        r = NodeResult({'branch_id': branch.get('id')}, {})
        return r

    def _execute(self, branch_list: List):
        # First matching branch wins; returns None when nothing matches.
        for branch in branch_list:
            if self.branch_assertion(branch):
                return branch

    def branch_assertion(self, branch):
        condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
                          branch.get('conditions')]
        condition = branch.get('condition')
        # 'and' requires every condition; any other value behaves as 'or'.
        return all(condition_list) if condition == 'and' else any(condition_list)

    def assertion(self, field_list: List[str], compare: str, value):
        """Resolve the referenced field and evaluate one condition against it.

        @param field_list: [node_id, field, ...] reference path
        @param compare: comparator key, e.g. 'eq', 'len_ge'
        @param value: target value from the flow definition
        """
        field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
        for compare_handler in compare_handle_list:
            if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value):
                return compare_handler.compare(field_value, compare, value)
        # Explicit False (was an implicit None) when no comparator supports *compare*.
        return False

    def get_details(self, index: int, **kwargs):
        """Run details of this node for persistence in the chat record."""
        return {
            "index": index,
            'run_time': self.context.get('run_time'),
            'branch_id': self.context.get('branch_id'),
            'branch_name': self.context.get('branch_name'),
            'type': self.node.type
        }

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 17:50
@desc:
"""
from .impl import *

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_reply_node.py
@date2024/6/11 16:25
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage
class ReplyNodeParamsSerializer(serializers.Serializer):
    """Parameters of the direct-reply node.

    reply_type 'referencing' answers with an upstream field referenced by
    ``fields``; any other reply_type answers with the literal ``content``.
    """
    reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char("回复类型"))
    fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
    content = serializers.CharField(required=False, error_messages=ErrMessage.char("直接回答内容"))

    def is_valid(self, *, raise_exception=False):
        # NOTE(review): always validates with raise_exception=True regardless of
        # the caller's argument — confirm this is intentional.
        super().is_valid(raise_exception=True)
        if self.data.get('reply_type') == 'referencing':
            if 'fields' not in self.data:
                raise AppApiException(500, "引用字段不能为空")
            # A reference needs at least [node_id, field_name].
            if len(self.data.get('fields')) < 2:
                raise AppApiException(500, "引用字段错误")
        else:
            if 'content' not in self.data or self.data.get('content') is None:
                raise AppApiException(500, "内容不能为空")
class IReplyNode(INode):
    """Interface of the direct-reply workflow step."""
    # Node type identifier used by the step-node registry.
    type = 'reply-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return ReplyNodeParamsSerializer

    def _run(self):
        # Only flow params feed execute(); node params are validated separately.
        return self.execute(**self.flow_params_serializer.data)

    def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
        pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 17:49
@desc:
"""
from .base_reply_node import *

View File

@ -0,0 +1,87 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_reply_node.py
@date2024/6/11 17:25
@desc:
"""
from typing import List, Dict
from langchain_core.messages import AIMessage, AIMessageChunk
from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
def get_to_response_write_context(node_variable: Dict, node: INode):
    """Build the callback that records the final answer on the node's context."""

    def _store(answer):
        # The reply node has no token accounting; only the answer is recorded.
        node.context['answer'] = answer

    return _store
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                       post_handler):
    """
    Convert a streaming result into a streaming response.
    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node-scoped variables
    @param workflow_variable: workflow-scoped variables
    @param node: the current node
    @param workflow: the workflow manager
    @param post_handler: post-processor executed after the result is emitted
    @return: streaming response
    """
    response = node_variable.get('result')
    _write_context = get_to_response_write_context(node_variable, node)
    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                post_handler):
    """
    Convert a blocking (non-streaming) result into a response.
    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node-scoped variables
    @param workflow_variable: workflow-scoped variables
    @param node: the current node
    @param workflow: the workflow manager
    @param post_handler: post-processor executed after the result is emitted
    @return: response
    """
    response = node_variable.get('result')
    _write_context = get_to_response_write_context(node_variable, node)
    return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
class BaseReplyNode(IReplyNode):
    """Direct-reply node: answers with referenced upstream content or a fixed template."""

    def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
        if reply_type == 'referencing':
            result = self.get_reference_content(fields)
        else:
            result = self.generate_reply_content(content)
        if stream:
            # BUG FIX: iter(AIMessageChunk(...)) iterates the pydantic model's
            # fields, not a stream of chunks; downstream expects an iterator of
            # chunk objects with a .content attribute, so wrap in a list first.
            return NodeResult({'result': iter([AIMessageChunk(content=result)])}, {},
                              _to_response=to_stream_response)
        else:
            return NodeResult({'result': AIMessage(content=result)}, {}, _to_response=to_response)

    def generate_reply_content(self, prompt):
        # Render the configured content as a prompt template (variable substitution).
        return self.workflow_manage.generate_prompt(prompt)

    def get_reference_content(self, fields: List[str]):
        # fields = [node_id, *field_path] pointing at an upstream node's output.
        return self.workflow_manage.get_reference_field(
            fields[0],
            fields[1:])

    def get_details(self, index: int, **kwargs):
        """Run details of this node for persistence in the chat record."""
        return {
            "index": index,
            'run_time': self.context.get('run_time'),
            'type': self.node.type,
            'answer': self.context.get('answer'),
        }

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:30
@desc:
"""
from .impl import *

View File

@ -0,0 +1,37 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_chat_node.py
@date2024/6/4 13:58
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
class QuestionNodeSerializer(serializers.Serializer):
    """Validated parameters of the question node."""
    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
    # Optional role/system prompt; blank or null means no system message is sent.
    system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                   error_messages=ErrMessage.char("角色设定"))
    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
    # Number of historical dialogue rounds to include as context.
    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
class IQuestionNode(INode):
    """Interface of the question workflow step (LLM-based question handling)."""
    # Node type identifier used by the step-node registry.
    type = 'question-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return QuestionNodeSerializer

    def _run(self):
        # Merge validated node params (model, prompts) with flow params (ids, stream).
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                **kwargs) -> NodeResult:
        pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:35
@desc:
"""
from .base_question_node import BaseQuestionNode

View File

@ -0,0 +1,177 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_question_node.py
@date2024/6/4 14:30
@desc:
"""
import json
from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage
from application.flow import tools
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode
from common.util.rsa_util import rsa_long_decrypt
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """
    Write context data for a streaming result: drains the stream to assemble the
    full answer, then records token counts and conversation data on the node context.
    @param node_variable: node-scoped variables
    @param workflow_variable: workflow-scoped (global) variables
    @param node: the current node
    @param workflow: the workflow manager
    """
    response = node_variable.get('result')
    answer = ''
    # Draining the iterator consumes the stream; callers must not reuse it.
    for chunk in response:
        answer += chunk.content
    chat_model = node_variable.get('chat_model')
    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    answer_tokens = chat_model.get_num_tokens(answer)
    node.context['message_tokens'] = message_tokens
    node.context['answer_tokens'] = answer_tokens
    node.context['answer'] = answer
    node.context['history_message'] = node_variable['history_message']
    node.context['question'] = node_variable['question']
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """
    Write context data for a blocking (non-streaming) result: records token
    counts and conversation data on the node context.
    @param node_variable: node-scoped variables
    @param workflow_variable: workflow-scoped (global) variables
    @param node: the current node instance
    @param workflow: the workflow manager
    """
    response = node_variable.get('result')
    chat_model = node_variable.get('chat_model')
    answer = response.content
    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    answer_tokens = chat_model.get_num_tokens(answer)
    node.context['message_tokens'] = message_tokens
    node.context['answer_tokens'] = answer_tokens
    node.context['answer'] = answer
    node.context['history_message'] = node_variable['history_message']
    node.context['question'] = node_variable['question']
def get_to_response_write_context(node_variable: Dict, node: INode):
    """Build the deferred context writer invoked once the response has been emitted."""

    def _write_context(answer):
        chat_model = node_variable.get('chat_model')
        message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
        answer_tokens = chat_model.get_num_tokens(answer)
        node.context['message_tokens'] = message_tokens
        node.context['answer_tokens'] = answer_tokens
        node.context['answer'] = answer

    return _write_context
def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                       post_handler):
    """
    Convert a streaming result into a streaming response.
    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node-scoped variables
    @param workflow_variable: workflow-scoped variables
    @param node: the current node
    @param workflow: the workflow manager
    @param post_handler: post-processor executed after the result is emitted
    @return: streaming response
    """
    response = node_variable.get('result')
    _write_context = get_to_response_write_context(node_variable, node)
    return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
                post_handler):
    """
    Convert a blocking (non-streaming) result into a response.
    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param node_variable: node-scoped variables
    @param workflow_variable: workflow-scoped variables
    @param node: the current node
    @param workflow: the workflow manager
    @param post_handler: post-processor executed after the result is emitted
    @return: response
    """
    response = node_variable.get('result')
    _write_context = get_to_response_write_context(node_variable, node)
    return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
class BaseQuestionNode(IQuestionNode):
    """Question node implementation: invokes an LLM over the configured prompt
    plus recent conversation history, streaming or blocking."""

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                **kwargs) -> NodeResult:
        model = QuerySet(Model).filter(id=model_id).first()
        chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
                                                                           json.loads(
                                                                               rsa_long_decrypt(model.credential)),
                                                                           streaming=True)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        question = self.generate_prompt_question(prompt)
        message_list = self.generate_message_list(system, prompt, history_message)
        if stream:
            r = chat_model.stream(message_list)
            return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                               'get_to_response_write_context': get_to_response_write_context,
                               'history_message': history_message, 'question': question}, {},
                              _write_context=write_context_stream,
                              _to_response=to_stream_response)
        else:
            r = chat_model.invoke(message_list)
            return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
                               'history_message': history_message, 'question': question}, {},
                              _write_context=write_context, _to_response=to_response)

    @staticmethod
    def get_history_message(history_chat_record, dialogue_number):
        # Keep only the last *dialogue_number* rounds of the conversation.
        # NOTE(review): each entry is a [human, ai] pair, not a flat message —
        # confirm downstream consumers (message_list splice, get_details) expect pairs.
        start_index = len(history_chat_record) - dialogue_number
        history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
                           for index in
                           range(start_index if start_index > 0 else 0, len(history_chat_record))]
        return history_message

    def generate_prompt_question(self, prompt):
        # Render the prompt template into the HumanMessage used as the question.
        return HumanMessage(self.workflow_manage.generate_prompt(prompt))

    def generate_message_list(self, system: str, prompt: str, history_message):
        # BUG FIX: the original condition was inverted — it prepended a
        # SystemMessage only when *system* was None/empty (rendering an empty
        # template) and dropped it whenever a role prompt was actually set.
        if system is not None and len(system) > 0:
            return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
                    HumanMessage(self.workflow_manage.generate_prompt(prompt))]
        else:
            return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Flatten the message list into role/content dicts and append the answer."""
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                  message
                  in
                  message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """Run details of this node for persistence in the chat record."""
        return {
            "index": index,
            'run_time': self.context.get('run_time'),
            'system': self.node_params.get('system'),
            'history_message': [{'content': message.content, 'role': message.type} for message in
                                self.context.get('history_message')],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            # .get() for consistency with the other context reads; these keys are
            # only set after _write_context has run.
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens')
        }

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:30
@desc:
"""
from .impl import *

View File

@ -16,10 +16,7 @@ from application.flow.i_step_node import INode, ReferenceAddressSerializer, Node
from common.util.field_message import ErrMessage
class SearchDatasetStepNodeSerializer(serializers.Serializer):
# 需要查询的数据集id列表
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list("数据集id列表"))
class DatasetSettingSerializer(serializers.Serializer):
# 需要查询的条数
top_n = serializers.IntegerField(required=True,
error_messages=ErrMessage.integer("引用分段数"))
@ -30,9 +27,17 @@ class SearchDatasetStepNodeSerializer(serializers.Serializer):
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message="类型只支持register|reset_password", code=500)
], error_messages=ErrMessage.char("检索模式"))
max_paragraph_char_number = serializers.IntegerField(required=True,
error_messages=ErrMessage.float("最大引用分段字数"))
question_reference_address = ReferenceAddressSerializer(required=False,
error_messages=ErrMessage.char("问题应用地址"))
class SearchDatasetStepNodeSerializer(serializers.Serializer):
# 需要查询的数据集id列表
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
error_messages=ErrMessage.list("数据集id列表"))
dataset_setting = DatasetSettingSerializer(required=True)
question_reference_address = serializers.ListField(required=True, )
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -46,11 +51,11 @@ class ISearchDatasetStepNode(INode):
def _run(self):
question = self.workflow_manage.get_reference_field(
self.node_params_serializer.data.get('question_reference_address').get('node_id'),
self.node_params_serializer.data.get('question_reference_address').get('fields'))
self.node_params_serializer.data.get('question_reference_address')[0],
self.node_params_serializer.data.get('question_reference_address')[1:])
return self.execute(**self.node_params_serializer.data, question=question, exclude_paragraph_id_list=[])
def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:35
@desc:
"""
from .base_search_dataset_node import BaseSearchDatasetNode

View File

@ -22,7 +22,7 @@ from smartdoc.conf import PROJECT_DIR
class BaseSearchDatasetNode(ISearchDatasetStepNode):
def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question,
def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
embedding_model = EmbeddingModel.get_embedding_model()
@ -33,7 +33,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
dataset_id__in=dataset_id_list,
is_active=False)]
embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
if embedding_list is None:
return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
paragraph_list = self.list_paragraph(embedding_list, vector)
@ -71,3 +72,11 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
if not exist_paragraph_list.__contains__(paragraph_id):
vector.delete_by_paragraph_id(paragraph_id)
return paragraph_list
def get_details(self, index: int, **kwargs):
    """Return a serializable summary of this dataset-search node's run.

    :param index: position of this node in the workflow run order.
    :return: dict with the run time, the retrieved paragraph list and the
             node type.
    """
    summary = dict(index=index)
    summary['run_time'] = self.context.get('run_time')
    summary['paragraph_list'] = self.context.get('paragraph_list')
    summary['type'] = self.node.type
    return summary

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:30
@desc:
"""
from .impl import *

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2024/6/11 15:36
@desc:
"""
from .base_start_node import BaseStartStepNode

View File

@ -18,3 +18,11 @@ class BaseStartStepNode(IStarNode):
开始节点 初始化全局变量
"""
return NodeResult({'question': question}, {'time': time.time()})
def get_details(self, index: int, **kwargs):
    """Return a serializable summary of the start node's run.

    :param index: position of this node in the workflow run order.
    :return: dict with the original question, run time and node type.
    """
    ctx = self.context
    return {
        "index": index,
        "question": ctx.get('question'),
        'run_time': ctx.get('run_time'),
        'type': self.node.type,
    }

View File

@ -10,10 +10,8 @@ from typing import List, Dict
from langchain_core.prompts import PromptTemplate
from application.flow.i_step_node import INode, WorkFlowPostHandler
from application.flow.step_node.ai_chat_step_node.impl.base_chat_node import BaseChatNode
from application.flow.step_node.search_dataset_node.impl.base_search_dataset_node import BaseSearchDatasetNode
from application.flow.step_node.start_node.impl.base_start_node import BaseStartStepNode
from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
from application.flow.step_node import get_node
class Edge:
@ -52,41 +50,36 @@ class Flow:
return Flow(nodes, edges)
flow_node_dict = {
'start-node': BaseStartStepNode,
'search-dataset-node': BaseSearchDatasetNode,
'chat-node': BaseChatNode
}
class WorkflowManage:
def __init__(self, flow: Flow, params):
self.params = params
self.flow = flow
self.context = {}
self.node_dict = {}
self.runtime_nodes = []
self.node_context = []
self.current_node = None
self.current_result = None
def run(self):
"""
运行工作流
"""
while self.has_next_node():
while self.has_next_node(self.current_result):
self.current_node = self.get_next_node()
self.node_dict[self.current_node.id] = self.current_node
result = self.current_node.run()
if self.has_next_node():
result.write_context(self.current_node, self)
self.node_context.append(self.current_node)
self.current_result = self.current_node.run()
if self.has_next_node(self.current_result):
self.current_result.write_context(self.current_node, self)
else:
r = result.to_response(self.params['chat_id'], self.params['chat_record_id'], self.current_node, self,
WorkFlowPostHandler(client_id=self.params['client_id'], chat_info=None,
client_type='ss'))
r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
self.current_node, self,
WorkFlowPostHandler(client_id=self.params['client_id'],
chat_info=None,
client_type='APPLICATION_ACCESS_TOKEN'))
for row in r:
print(row)
print(self)
def has_next_node(self):
def has_next_node(self, node_result: NodeResult | None):
"""
是否有下一个可运行的节点
"""
@ -94,13 +87,24 @@ class WorkflowManage:
if self.get_start_node() is not None:
return True
else:
for edge in self.flow.edges:
if edge.sourceNodeId == self.current_node.id:
return True
if node_result is not None and node_result.is_assertion_result():
for edge in self.flow.edges:
if (edge.sourceNodeId == self.current_node.id and
f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
return True
else:
for edge in self.flow.edges:
if edge.sourceNodeId == self.current_node.id:
return True
return False
def get_runtime_details(self):
    """Collect execution details from every node that has run.

    :return: dict mapping node id -> the payload of that node's
             ``get_details``, in the execution order of ``self.node_context``.
    """
    details_result = {}
    # NOTE: get_details expects the positional run index (an int), not a
    # dict; passing {'index': index} would store a dict under 'index'.
    for index, node in enumerate(self.node_context):
        details_result[node.id] = node.get_details(index)
    return details_result
def get_next_node(self):
"""
@ -108,8 +112,7 @@ class WorkflowManage:
"""
if self.current_node is None:
node = self.get_start_node()
node_instance = flow_node_dict[node.type](node.id, node.properties.get('node_data'),
self.params, self.context)
node_instance = get_node(node.type)(node, self.params, self.context)
return node_instance
for edge in self.flow.edges:
if edge.sourceNodeId == self.current_node.id:
@ -138,8 +141,9 @@ class WorkflowManage:
context = {
'global': self.context,
}
for key in self.node_dict:
context[key] = self.node_dict[key].context
for node in self.node_context:
context[node.id] = node.context
value = prompt_template.format(context=context)
return value
@ -148,18 +152,22 @@ class WorkflowManage:
获取启动节点
@return:
"""
return self.flow.nodes[0]
start_node_list = [node for node in self.flow.nodes if node.type == 'start-node']
return start_node_list[0]
def get_node_cls_by_id(self, node_id):
for node in self.flow.nodes:
if node.id == node_id:
node_instance = flow_node_dict[node.type](node.id, node.properties.get('node_data'),
self.params, self)
node_instance = get_node(node.type)(node,
self.params, self)
return node_instance
return None
def get_node_by_id(self, node_id):
    """Return the already-executed node instance with the given id.

    Searches ``self.node_context`` (the list of nodes that have run), not
    the flow definition; the stale ``node_dict`` lookup made the real
    search below it unreachable and is removed.

    :param node_id: id of the node to find.
    :return: the node instance, or None when no such node has run.
    """
    for node in self.node_context:
        if node.id == node_id:
            return node
    return None
def get_node_reference(self, reference_address: Dict):
node = self.get_node_by_id(reference_address.get('node_id'))