From c60fc829aec9983d6ff08682b5be7cc36a0e2754 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Tue, 11 Jun 2024 17:53:58 +0800 Subject: [PATCH] =?UTF-8?q?=20feat:=20=E5=B7=A5=E4=BD=9C=E7=BC=96=E6=8E=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/__init__.py | 8 + apps/application/flow/i_step_node.py | 29 +-- apps/application/flow/step_node/__init__.py | 23 +++ .../step_node/ai_chat_step_node/__init__.py | 9 + .../ai_chat_step_node/i_chat_node.py | 3 +- .../ai_chat_step_node/impl/__init__.py | 9 + .../ai_chat_step_node/impl/base_chat_node.py | 47 ++++- .../flow/step_node/condition_node/__init__.py | 9 + .../condition_node/compare/__init__.py | 23 +++ .../condition_node/compare/compare.py | 20 ++ .../condition_node/compare/contain_compare.py | 21 +++ .../condition_node/compare/equal_compare.py | 21 +++ .../condition_node/compare/ge_compare.py | 24 +++ .../condition_node/compare/gt_compare.py | 24 +++ .../condition_node/compare/le_compare.py | 24 +++ .../compare/len_equal_compare.py | 24 +++ .../condition_node/compare/len_ge_compare.py | 24 +++ .../condition_node/compare/len_gt_compare.py | 24 +++ .../condition_node/compare/len_le_compare.py | 24 +++ .../condition_node/compare/len_lt_compare.py | 24 +++ .../condition_node/compare/lt_compare.py | 24 +++ .../condition_node/i_condition_node.py | 83 ++++++++ .../step_node/condition_node/impl/__init__.py | 9 + .../impl/base_condition_node.py | 47 +++++ .../step_node/direct_reply_node/__init__.py | 9 + .../direct_reply_node/i_reply_node.py | 45 +++++ .../direct_reply_node/impl/__init__.py | 9 + .../direct_reply_node/impl/base_reply_node.py | 87 +++++++++ .../flow/step_node/question_node/__init__.py | 9 + .../question_node/i_question_node.py | 37 ++++ .../step_node/question_node/impl/__init__.py | 9 + .../question_node/impl/base_question_node.py | 177 ++++++++++++++++++ .../step_node/search_dataset_node/__init__.py | 9 + .../i_search_dataset_node.py | 
23 ++- .../search_dataset_node/impl/__init__.py | 9 + .../impl/base_search_dataset_node.py | 13 +- .../flow/step_node/start_node/__init__.py | 9 + .../step_node/start_node/impl/__init__.py | 9 + .../start_node/impl/base_start_node.py | 8 + apps/application/flow/workflow_manage.py | 76 ++++---- 40 files changed, 1051 insertions(+), 64 deletions(-) create mode 100644 apps/application/flow/__init__.py create mode 100644 apps/application/flow/step_node/__init__.py create mode 100644 apps/application/flow/step_node/ai_chat_step_node/__init__.py create mode 100644 apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/condition_node/__init__.py create mode 100644 apps/application/flow/step_node/condition_node/compare/__init__.py create mode 100644 apps/application/flow/step_node/condition_node/compare/compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/contain_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/equal_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/ge_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/gt_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/le_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_equal_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_ge_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_gt_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_le_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/len_lt_compare.py create mode 100644 apps/application/flow/step_node/condition_node/compare/lt_compare.py create mode 100644 apps/application/flow/step_node/condition_node/i_condition_node.py create mode 100644 
apps/application/flow/step_node/condition_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/condition_node/impl/base_condition_node.py create mode 100644 apps/application/flow/step_node/direct_reply_node/__init__.py create mode 100644 apps/application/flow/step_node/direct_reply_node/i_reply_node.py create mode 100644 apps/application/flow/step_node/direct_reply_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py create mode 100644 apps/application/flow/step_node/question_node/__init__.py create mode 100644 apps/application/flow/step_node/question_node/i_question_node.py create mode 100644 apps/application/flow/step_node/question_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/question_node/impl/base_question_node.py create mode 100644 apps/application/flow/step_node/search_dataset_node/__init__.py create mode 100644 apps/application/flow/step_node/search_dataset_node/impl/__init__.py create mode 100644 apps/application/flow/step_node/start_node/__init__.py create mode 100644 apps/application/flow/step_node/start_node/impl/__init__.py diff --git a/apps/application/flow/__init__.py b/apps/application/flow/__init__.py new file mode 100644 index 000000000..328e8f8ec --- /dev/null +++ b/apps/application/flow/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/6/7 14:43 + @desc: +""" diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 96b1b7c24..b499c96a6 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -41,15 +41,18 @@ class WorkFlowPostHandler: answer, workflow): question = workflow.params['question'] + details = workflow.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if 'message_tokens' in row]) + answer_tokens = sum([row.get('answer_tokens') for row in 
details.values() if 'answer_tokens' in row]) chat_record = ChatRecord(id=chat_record_id, chat_id=chat_id, problem_text=question, answer_text=answer, - details=workflow.get_details(), - message_tokens=workflow.context['message_tokens'], - answer_tokens=workflow.context['answer_tokens'], - run_time=workflow.context['run_time'], - index=len(self.chat_info.chat_record_list) + 1) + details=details, + message_tokens=message_tokens, + answer_tokens=answer_tokens, + run_time=time.time() - workflow.context['time'], + index=0) self.chat_info.append_chat_record(chat_record, self.client_id) # 重新设置缓存 chat_cache.set(chat_id, @@ -76,6 +79,9 @@ class NodeResult: return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow, post_handler) + def is_assertion_result(self): + return 'branch_id' in self.node_variable + class ReferenceAddressSerializer(serializers.Serializer): node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id")) @@ -103,15 +109,16 @@ class FlowParamsSerializer(serializers.Serializer): class INode: - def __init__(self, _id, node_params, workflow_params, workflow_manage): + def __init__(self, node, workflow_params, workflow_manage): # 当前步骤上下文,用于存储当前步骤信息 - self.node_params = node_params + self.node = node + self.node_params = node.properties.get('node_data') self.workflow_manage = workflow_manage self.node_params_serializer = None self.flow_params_serializer = None self.context = {} - self.id = _id - self.valid_args(node_params, workflow_params) + self.id = node.id + self.valid_args(self.node_params, workflow_params) def valid_args(self, node_params, flow_params): flow_params_serializer_class = self.get_flow_params_serializer_class() @@ -160,9 +167,9 @@ class INode: def execute(self, **kwargs) -> NodeResult: pass - def get_details(self, **kwargs): + def get_details(self, index: int, **kwargs): """ 运行详情 :return: 步骤详情 """ - return None + return {} diff --git 
# Every concrete node implementation known to the workflow engine; get_node() searches
# this list by each class's ``type`` attribute.
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode]


def get_node(node_type):
    """
    Find the registered node class for a node type.

    @param node_type: value compared against the ``type`` attribute of each class in ``node_list``
    @return: the first class in ``node_list`` whose ``type`` equals ``node_type``, or None
    """
    matches = (candidate for candidate in node_list if candidate.type == node_type)
    return next(matches, None)
serializers.CharField(required=True, error_messages=ErrMessage.char("提示词")) # 多轮对话数量 dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py b/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py new file mode 100644 index 000000000..79051a999 --- /dev/null +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:34 + @desc: +""" +from .base_chat_node import BaseChatNode diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index e688903e0..c6e934582 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -2,7 +2,7 @@ """ @project: maxkb @Author:虎 - @file: base_chat_node.py + @file: base_question_node.py @date:2024/6/4 14:30 @desc: """ @@ -39,6 +39,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo node.context['message_tokens'] = message_tokens node.context['answer_tokens'] = answer_tokens node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -57,6 +59,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor node.context['message_tokens'] = message_tokens node.context['answer_tokens'] = answer_tokens node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] def get_to_response_write_context(node_variable: Dict, node: INode): @@ -67,6 +71,8 
@staticmethod
def get_history_message(history_chat_record, dialogue_number):
    """
    Collect the messages of the last ``dialogue_number`` chat records.

    @param history_chat_record: sequence of chat records; each record must provide
                                ``get_human_message()`` and ``get_ai_message()``
    @param dialogue_number: how many records, counted from the end, to include
    @return: flat list ``[human, ai, human, ai, ...]`` in record order
    """
    # Asking for more rounds than exist simply means "take everything".
    start_index = max(len(history_chat_record) - dialogue_number, 0)
    history_message = []
    for record in history_chat_record[start_index:]:
        # Flatten each (human, ai) pair: generate_message_list splices this list directly
        # into the model input and get_details reads ``.content`` / ``.type`` on every
        # item, so the items must be messages, not nested two-element lists.
        history_message.append(record.get_human_message())
        history_message.append(record.get_ai_message())
    return history_message


def generate_prompt_question(self, prompt):
    """
    Build the human message for the current question.

    @param prompt: prompt template, rendered by ``workflow_manage.generate_prompt``
    @return: ``HumanMessage`` wrapping the rendered prompt
    """
    return HumanMessage(self.workflow_manage.generate_prompt(prompt))


def generate_message_list(self, system: str, prompt: str, history_message):
    """
    Assemble the message list sent to the chat model.

    @param system: optional role setting; may be None or empty
    @param prompt: prompt template for the user question
    @param history_message: flat list of earlier messages (see get_history_message)
    @return: ``[SystemMessage?, *history_message, HumanMessage]``
    """
    # The system message is optional: only render and prepend it when one was configured.
    # (The previous condition was inverted and rendered it exactly when it was missing.)
    if system is not None and len(system) > 0:
        return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
                HumanMessage(self.workflow_manage.generate_prompt(prompt))]
    return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]


def get_details(self, index: int, **kwargs):
    """
    Describe this node's run for the chat record details.

    @param index: position of this node in the executed flow
    @return: dict with timing, prompt context, answer and token usage
    """
    # The context has no 'history_message' until a write_context callback stored it.
    history_message = self.context.get('history_message') or []
    return {
        "index": index,
        'run_time': self.context.get('run_time'),
        'system': self.node_params.get('system'),
        'history_message': [{'content': message.content, 'role': message.type} for message in
                            history_message],
        'question': self.context.get('question'),
        'answer': self.context.get('answer'),
        'type': self.node.type,
        'message_tokens': self.context['message_tokens'],
        'answer_tokens': self.context['answer_tokens']
    }
class Compare:
    """
    Contract for a condition comparator.

    ``support`` reports whether the comparator handles the requested operator and
    ``compare`` evaluates that operator for a source/target pair.
    NOTE: the class does not derive from ``abc.ABC``, so ``@abstractmethod`` documents
    the contract but is not enforced at instantiation time.
    """

    @abstractmethod
    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        """
        @param node_id: id of the node the referenced field belongs to
        @param fields: path of the referenced field inside that node
        @param source_value: resolved value of the referenced field
        @param compare: operator key requested by the condition
        @param target_value: value configured in the condition
        @return: True when this comparator handles ``compare``
        """
        pass

    @abstractmethod
    def compare(self, source_value, compare, target_value):
        """
        @param source_value: resolved value of the referenced field
        @param compare: operator key requested by the condition
        @param target_value: value configured in the condition
        @return: True when the condition holds
        """
        pass


class ContainCompare(Compare):
    """Comparator for ``contain``: the source contains the target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'contain'

    def compare(self, source_value, compare, target_value):
        if source_value is None:
            return False
        if isinstance(source_value, str):
            # For text, "contain" means substring; iterating the string would only ever
            # compare single characters against the target.
            return str(target_value) in source_value
        try:
            return any(str(item) == str(target_value) for item in source_value)
        except TypeError:
            # A non-iterable source cannot contain anything.
            return False


class EqualCompare(Compare):
    """Comparator for ``eq``: string representations of both values are equal."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'eq'

    def compare(self, source_value, compare, target_value):
        return str(source_value) == str(target_value)


class GECompare(Compare):
    """Comparator for ``ge``: numeric source >= numeric target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'ge'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) >= float(target_value)
        except (TypeError, ValueError, OverflowError):
            # Operands that are not numbers never satisfy a numeric comparison.
            return False


class GTCompare(Compare):
    """Comparator for ``gt``: numeric source > numeric target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'gt'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) > float(target_value)
        except (TypeError, ValueError, OverflowError):
            return False


class LECompare(Compare):
    """Comparator for ``le``: numeric source <= numeric target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'le'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) <= float(target_value)
        except (TypeError, ValueError, OverflowError):
            return False


class LTCompare(Compare):
    """Comparator for ``lt``: numeric source < numeric target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'lt'

    def compare(self, source_value, compare, target_value):
        try:
            return float(source_value) < float(target_value)
        except (TypeError, ValueError, OverflowError):
            return False


class LenEqualCompare(Compare):
    """Comparator for ``len_eq``: len(source) == integer target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'len_eq'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) == int(target_value)
        except (TypeError, ValueError, OverflowError):
            # Unsized source or non-integer target: the length condition cannot hold.
            return False


class LenGECompare(Compare):
    """Comparator for ``len_ge``: len(source) >= integer target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'len_ge'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) >= int(target_value)
        except (TypeError, ValueError, OverflowError):
            return False


class LenGTCompare(Compare):
    """Comparator for ``len_gt``: len(source) > integer target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'len_gt'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) > int(target_value)
        except (TypeError, ValueError, OverflowError):
            return False


class LenLECompare(Compare):
    """Comparator for ``len_le``: len(source) <= integer target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'len_le'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) <= int(target_value)
        except (TypeError, ValueError, OverflowError):
            return False


class LenLTCompare(Compare):
    """Comparator for ``len_lt``: len(source) < integer target."""

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        return compare == 'len_lt'

    def compare(self, source_value, compare, target_value):
        try:
            return len(source_value) < int(target_value)
        except (TypeError, ValueError, OverflowError):
            return False


# Comparators are probed in this order; the first one whose support() accepts the
# operator evaluates the condition.
compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
                       LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare()]
class ConditionSerializer(serializers.Serializer):
    """One comparison: referenced field, operator key and the value to compare against."""
    # Operator key, e.g. 'eq', 'contain' or 'len_eq'.
    compare = serializers.CharField(required=True, error_messages=ErrMessage.char("比较器"))
    # Right-hand operand, always transported as a string.
    value = serializers.CharField(required=True, error_messages=ErrMessage.char(""))
    # Reference path: first item is the node id, the remaining items locate the field.
    field = serializers.ListField(required=True, error_messages=ErrMessage.char("字段"))


class ConditionBranchSerializer(serializers.Serializer):
    """One branch: its id, how its conditions combine, and the conditions themselves."""
    id = serializers.CharField(required=True, error_messages=ErrMessage.char("分支id"))
    # 'and' requires every condition to hold; any other value is treated as 'or'.
    condition = serializers.CharField(required=True, error_messages=ErrMessage.char("条件or|and"))
    conditions = ConditionSerializer(many=True)


class ConditionNodeParamsSerializer(serializers.Serializer):
    """
    Parameters of a condition node: an ordered list of branches.

    Example payload::

        {"branch": [
            {"id": "2391", "condition": "and",
             "conditions": [{"field": ["34902d3d-a3ff-497f-b8e1-0c34a44d7dd5", "paragraph_list"],
                             "compare": "len_eq", "value": "0"}]},
            {"id": "9208", "condition": "and", "conditions": []}
        ]}
    """
    branch = ConditionBranchSerializer(many=True)


class IConditionNode(INode):
    """Interface of the condition node; binds the node type to its parameter serializer."""

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return ConditionNodeParamsSerializer

    type = 'condition-node'


class BaseConditionNode(IConditionNode):
    """Condition node that selects the first branch whose conditions are satisfied."""

    def execute(self, **kwargs) -> NodeResult:
        """
        Evaluate the configured branches and report the selected one.

        @return: NodeResult whose node variable ``branch_id`` is the id of the chosen branch
        """
        branch_list = self.node_params_serializer.data['branch']
        branch = self._execute(branch_list)
        # NOTE(review): _execute returns None when no branch matches, which fails here with
        # an AttributeError. A branch with empty 'and' conditions always matches, so
        # presumably the editor always appends such a fallback branch; confirm.
        r = NodeResult({'branch_id': branch.get('id')}, {})
        return r

    def _execute(self, branch_list: List):
        """
        @param branch_list: branches in evaluation order
        @return: the first branch whose assertion holds, or None when none does
        """
        for branch in branch_list:
            if self.branch_assertion(branch):
                return branch
        return None

    def branch_assertion(self, branch):
        """
        @param branch: branch dict with ``conditions`` and the combinator ``condition``
        @return: True when the branch's conditions hold under its combinator
        """
        condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
                          branch.get('conditions')]
        condition = branch.get('condition')
        # all([]) is True, so an 'and' branch without conditions always matches.
        return all(condition_list) if condition == 'and' else any(condition_list)

    def assertion(self, field_list: List[str], compare: str, value):
        """
        Evaluate a single condition.

        @param field_list: reference path; first item is the node id, the rest the field path
        @param compare: operator key
        @param value: configured target value
        @return: result of the first comparator supporting ``compare``; False when none does
        """
        field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
        for compare_handler in compare_handle_list:
            if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value):
                return compare_handler.compare(field_value, compare, value)
        # Unknown operator: treat the condition as not satisfied.
        return False

    def get_details(self, index: int, **kwargs):
        """
        @param index: position of this node in the executed flow
        @return: dict describing the run (timing, chosen branch, node type)
        """
        return {
            "index": index,
            'run_time': self.context.get('run_time'),
            'branch_id': self.context.get('branch_id'),
            'branch_name': self.context.get('branch_name'),
            'type': self.node.type
        }
class ReplyNodeParamsSerializer(serializers.Serializer):
    """Parameters of the direct-reply node: either a field reference or literal content."""
    reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char("回复类型"))
    fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
    content = serializers.CharField(required=False, error_messages=ErrMessage.char("直接回答内容"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the fields, then the cross-field rule for the chosen reply type.

        Validation failures always raise, whatever ``raise_exception`` says.

        @param raise_exception: accepted for signature compatibility; not consulted
        @return: True when the data is valid
        @raise AppApiException: when the reference or the content required by
                                ``reply_type`` is missing or malformed
        """
        super().is_valid(raise_exception=True)
        if self.data.get('reply_type') == 'referencing':
            if 'fields' not in self.data:
                raise AppApiException(500, "引用字段不能为空")
            # A reference needs at least a node id plus one path element.
            if len(self.data.get('fields')) < 2:
                raise AppApiException(500, "引用字段错误")
        else:
            if 'content' not in self.data or self.data.get('content') is None:
                raise AppApiException(500, "内容不能为空")
        # The base is_valid contract is to return a boolean; callers may test it.
        return True


class IReplyNode(INode):
    """Interface of the direct-reply node."""
    type = 'reply-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return ReplyNodeParamsSerializer

    def _run(self):
        # execute() requires ``reply_type`` (and optionally ``fields`` / ``content``), which
        # are node parameters, so they must be forwarded together with the flow parameters.
        # NOTE(review): assumes flow params supply ``stream`` and share no keys with the
        # node params; confirm against FlowParamsSerializer.
        return self.execute(**{**self.node_params_serializer.data, **self.flow_params_serializer.data})

    def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
        """
        @param reply_type: 'referencing' to reply with a referenced field, otherwise literal content
        @param stream: whether the reply is delivered as a stream
        @param fields: reference path, used when ``reply_type`` is 'referencing'
        @param content: reply template, used otherwise
        @return: NodeResult carrying the reply
        """
        pass
b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py new file mode 100644 index 000000000..94a78a860 --- /dev/null +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -0,0 +1,87 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_reply_node.py + @date:2024/6/11 17:25 + @desc: +""" +from typing import List, Dict + +from langchain_core.messages import AIMessage, AIMessageChunk + +from application.flow import tools +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode + + +def get_to_response_write_context(node_variable: Dict, node: INode): + def _write_context(answer): + node.context['answer'] = answer + + return _write_context + + +def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, + post_handler): + """ + 将流式数据 转换为 流式响应 + @param chat_id: 会话id + @param chat_record_id: 对话记录id + @param node_variable: 节点数据 + @param workflow_variable: 工作流数据 + @param node: 节点 + @param workflow: 工作流管理器 + @param post_handler: 后置处理器 输出结果后执行 + @return: 流式响应 + """ + response = node_variable.get('result') + _write_context = get_to_response_write_context(node_variable, node) + return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) + + +def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, + post_handler): + """ + 将结果转换 + @param chat_id: 会话id + @param chat_record_id: 对话记录id + @param node_variable: 节点数据 + @param workflow_variable: 工作流数据 + @param node: 节点 + @param workflow: 工作流管理器 + @param post_handler: 后置处理器 + @return: 响应 + """ + response = node_variable.get('result') + _write_context = get_to_response_write_context(node_variable, node) + return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) + + +class BaseReplyNode(IReplyNode): + def 
execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult: + if reply_type == 'referencing': + result = self.get_reference_content(fields) + else: + result = self.generate_reply_content(content) + if stream: + return NodeResult({'result': iter([AIMessageChunk(content=result)])}, {}, + _to_response=to_stream_response)  # a bare chunk is not a stream of chunks, so wrap it in a one-element list + else: + return NodeResult({'result': AIMessage(content=result)}, {}, _to_response=to_response) + + def generate_reply_content(self, prompt): + return self.workflow_manage.generate_prompt(prompt) + + def get_reference_content(self, fields: List[str]): + return self.workflow_manage.get_reference_field( + fields[0], + fields[1:]) + + def get_details(self, index: int, **kwargs): + return { + "index": index, + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'answer': self.context.get('answer'), + } diff --git a/apps/application/flow/step_node/question_node/__init__.py b/apps/application/flow/step_node/question_node/__init__.py new file mode 100644 index 000000000..98a1afcd9 --- /dev/null +++ b/apps/application/flow/step_node/question_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py new file mode 100644 index 000000000..ede120def --- /dev/null +++ b/apps/application/flow/step_node/question_node/i_question_node.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_chat_node.py + @date:2024/6/4 13:58 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class QuestionNodeSerializer(serializers.Serializer): + model_id = serializers.CharField(required=True,
error_messages=ErrMessage.char("模型id")) + system = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char("角色设定")) + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词")) + # 多轮对话数量 + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + + +class IQuestionNode(INode): + type = 'question-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return QuestionNodeSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/question_node/impl/__init__.py b/apps/application/flow/step_node/question_node/impl/__init__.py new file mode 100644 index 000000000..d85aa8724 --- /dev/null +++ b/apps/application/flow/step_node/question_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:35 + @desc: +""" +from .base_question_node import BaseQuestionNode diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py new file mode 100644 index 000000000..371e55ac0 --- /dev/null +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -0,0 +1,177 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_question_node.py + @date:2024/6/4 14:30 + @desc: +""" +import json +from typing import List, Dict + +from django.db.models import QuerySet +from langchain.schema import HumanMessage, SystemMessage +from langchain_core.messages import BaseMessage + +from application.flow import tools +from application.flow.i_step_node import NodeResult, INode +from 
application.flow.step_node.question_node.i_question_node import IQuestionNode +from common.util.rsa_util import rsa_long_decrypt +from setting.models import Model +from setting.models_provider.constants.model_provider_constants import ModelProvideConstants + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 (流式) + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + answer = '' + for chunk in response: + answer += chunk.content + chat_model = node_variable.get('chat_model') + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点实例对象 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + chat_model = node_variable.get('chat_model') + answer = response.content + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + + +def get_to_response_write_context(node_variable: Dict, node: INode): + def _write_context(answer): + chat_model = node_variable.get('chat_model') + message_tokens = 
chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + + return _write_context + + +def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, + post_handler): + """ + 将流式数据 转换为 流式响应 + @param chat_id: 会话id + @param chat_record_id: 对话记录id + @param node_variable: 节点数据 + @param workflow_variable: 工作流数据 + @param node: 节点 + @param workflow: 工作流管理器 + @param post_handler: 后置处理器 输出结果后执行 + @return: 流式响应 + """ + response = node_variable.get('result') + _write_context = get_to_response_write_context(node_variable, node) + return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) + + +def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, + post_handler): + """ + 将结果转换 + @param chat_id: 会话id + @param chat_record_id: 对话记录id + @param node_variable: 节点数据 + @param workflow_variable: 工作流数据 + @param node: 节点 + @param workflow: 工作流管理器 + @param post_handler: 后置处理器 + @return: 响应 + """ + response = node_variable.get('result') + _write_context = get_to_response_write_context(node_variable, node) + return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) + + +class BaseQuestionNode(IQuestionNode): + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + **kwargs) -> NodeResult: + model = QuerySet(Model).filter(id=model_id).first() + chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, + json.loads( + rsa_long_decrypt(model.credential)), + streaming=True) + history_message = self.get_history_message(history_chat_record, dialogue_number) + question = self.generate_prompt_question(prompt) + 
message_list = self.generate_message_list(system, prompt, history_message) + if stream: + r = chat_model.stream(message_list) + return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'get_to_response_write_context': get_to_response_write_context, + 'history_message': history_message, 'question': question}, {}, + _write_context=write_context_stream, + _to_response=to_stream_response) + else: + r = chat_model.invoke(message_list) + return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'history_message': history_message, 'question': question}, {}, + _write_context=write_context, _to_response=to_response) + + @staticmethod + def get_history_message(history_chat_record, dialogue_number): + start_index = len(history_chat_record) - dialogue_number + history_message = [message for index in + range(start_index if start_index > 0 else 0, len(history_chat_record)) + for message in (history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message())]  # flat list of messages: callers spread it into the prompt and read .content on each item + return history_message + + def generate_prompt_question(self, prompt): + return HumanMessage(self.workflow_manage.generate_prompt(prompt)) + + def generate_message_list(self, system: str, prompt: str, history_message): + if system is not None and len(system) > 0:  # only prepend a system message when a role setting is configured + return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message, + HumanMessage(self.workflow_manage.generate_prompt(prompt))] + else: + return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))] + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result + + def get_details(self, index: int, **kwargs): + return { + "index": index, + 'run_time': self.context.get('run_time'), + 'system':
self.node_params.get('system'), + 'history_message': [{'content': message.content, 'role': message.type} for message in + self.context.get('history_message')], + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context['message_tokens'], + 'answer_tokens': self.context['answer_tokens'] + } diff --git a/apps/application/flow/step_node/search_dataset_node/__init__.py b/apps/application/flow/step_node/search_dataset_node/__init__.py new file mode 100644 index 000000000..98a1afcd9 --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py index b438e363d..7476ecd04 100644 --- a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py @@ -16,10 +16,7 @@ from application.flow.i_step_node import INode, ReferenceAddressSerializer, Node from common.util.field_message import ErrMessage -class SearchDatasetStepNodeSerializer(serializers.Serializer): - # 需要查询的数据集id列表 - dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), - error_messages=ErrMessage.list("数据集id列表")) +class DatasetSettingSerializer(serializers.Serializer): # 需要查询的条数 top_n = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("引用分段数")) @@ -30,9 +27,17 @@ class SearchDatasetStepNodeSerializer(serializers.Serializer): validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), message="类型只支持register|reset_password", code=500) ], error_messages=ErrMessage.char("检索模式")) + max_paragraph_char_number = 
serializers.IntegerField(required=True, + error_messages=ErrMessage.float("最大引用分段字数")) - question_reference_address = ReferenceAddressSerializer(required=False, - error_messages=ErrMessage.char("问题应用地址")) + +class SearchDatasetStepNodeSerializer(serializers.Serializer): + # 需要查询的数据集id列表 + dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("数据集id列表")) + dataset_setting = DatasetSettingSerializer(required=True) + + question_reference_address = serializers.ListField(required=True, ) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -46,11 +51,11 @@ class ISearchDatasetStepNode(INode): def _run(self): question = self.workflow_manage.get_reference_field( - self.node_params_serializer.data.get('question_reference_address').get('node_id'), - self.node_params_serializer.data.get('question_reference_address').get('fields')) + self.node_params_serializer.data.get('question_reference_address')[0], + self.node_params_serializer.data.get('question_reference_address')[1:]) return self.execute(**self.node_params_serializer.data, question=question, exclude_paragraph_id_list=[]) - def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question, + def execute(self, dataset_id_list, dataset_setting, question, exclude_paragraph_id_list=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/search_dataset_node/impl/__init__.py b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py new file mode 100644 index 000000000..a9cff0d09 --- /dev/null +++ b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:35 + @desc: +""" +from .base_search_dataset_node import BaseSearchDatasetNode diff --git 
a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 28b0eb3c2..26975e4b8 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -22,7 +22,7 @@ from smartdoc.conf import PROJECT_DIR class BaseSearchDatasetNode(ISearchDatasetStepNode): - def execute(self, dataset_id_list, top_n, similarity, search_mode, question_reference_address, question, + def execute(self, dataset_id_list, dataset_setting, question, exclude_paragraph_id_list=None, **kwargs) -> NodeResult: embedding_model = EmbeddingModel.get_embedding_model() @@ -33,7 +33,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode): dataset_id__in=dataset_id_list, is_active=False)] embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list, - exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode)) + exclude_paragraph_id_list, True, dataset_setting.get('top_n'), + dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode'))) if embedding_list is None: return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {}) paragraph_list = self.list_paragraph(embedding_list, vector) @@ -71,3 +72,11 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode): if not exist_paragraph_list.__contains__(paragraph_id): vector.delete_by_paragraph_id(paragraph_id) return paragraph_list + + def get_details(self, index: int, **kwargs): + return { + "index": index, + 'run_time': self.context.get('run_time'), + 'paragraph_list': self.context.get('paragraph_list'), + 'type': self.node.type + } diff --git a/apps/application/flow/step_node/start_node/__init__.py b/apps/application/flow/step_node/start_node/__init__.py new file mode 100644 index 000000000..98a1afcd9 --- /dev/null +++ 
b/apps/application/flow/step_node/start_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/start_node/impl/__init__.py b/apps/application/flow/step_node/start_node/impl/__init__.py new file mode 100644 index 000000000..b68a92d02 --- /dev/null +++ b/apps/application/flow/step_node/start_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:36 + @desc: +""" +from .base_start_node import BaseStartStepNode diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index 77a721ad8..f75a03e68 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -18,3 +18,11 @@ class BaseStartStepNode(IStarNode): 开始节点 初始化全局变量 """ return NodeResult({'question': question}, {'time': time.time()}) + + def get_details(self, index: int, **kwargs): + return { + "index": index, + "question": self.context.get('question'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type + } diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 69e8d3176..12f4a5a06 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -10,10 +10,8 @@ from typing import List, Dict from langchain_core.prompts import PromptTemplate -from application.flow.i_step_node import INode, WorkFlowPostHandler -from application.flow.step_node.ai_chat_step_node.impl.base_chat_node import BaseChatNode -from application.flow.step_node.search_dataset_node.impl.base_search_dataset_node import BaseSearchDatasetNode -from application.flow.step_node.start_node.impl.base_start_node import BaseStartStepNode +from 
application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult +from application.flow.step_node import get_node class Edge: @@ -52,41 +50,36 @@ class Flow: return Flow(nodes, edges) -flow_node_dict = { - 'start-node': BaseStartStepNode, - 'search-dataset-node': BaseSearchDatasetNode, - 'chat-node': BaseChatNode -} - - class WorkflowManage: def __init__(self, flow: Flow, params): self.params = params self.flow = flow self.context = {} - self.node_dict = {} - self.runtime_nodes = [] + self.node_context = [] self.current_node = None + self.current_result = None def run(self): """ 运行工作流 """ - while self.has_next_node(): + while self.has_next_node(self.current_result): self.current_node = self.get_next_node() - self.node_dict[self.current_node.id] = self.current_node - result = self.current_node.run() - if self.has_next_node(): - result.write_context(self.current_node, self) + self.node_context.append(self.current_node) + self.current_result = self.current_node.run() + if self.has_next_node(self.current_result): + self.current_result.write_context(self.current_node, self) else: - r = result.to_response(self.params['chat_id'], self.params['chat_record_id'], self.current_node, self, - WorkFlowPostHandler(client_id=self.params['client_id'], chat_info=None, - client_type='ss')) + r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'], + self.current_node, self, + WorkFlowPostHandler(client_id=self.params['client_id'], + chat_info=None, + client_type='APPLICATION_ACCESS_TOKEN')) for row in r: print(row) print(self) - def has_next_node(self): + def has_next_node(self, node_result: NodeResult | None): """ 是否有下一个可运行的节点 """ @@ -94,13 +87,24 @@ class WorkflowManage: if self.get_start_node() is not None: return True else: - for edge in self.flow.edges: - if edge.sourceNodeId == self.current_node.id: - return True + if node_result is not None and node_result.is_assertion_result(): + for edge in self.flow.edges: + if (edge.sourceNodeId 
== self.current_node.id and + f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): + return True + else: + for edge in self.flow.edges: + if edge.sourceNodeId == self.current_node.id: + return True return False def get_runtime_details(self): - return {} + details_result = {} + for index in range(len(self.node_context)): + node = self.node_context[index] + details = node.get_details(index)  # get_details takes the integer position, not a dict + details_result[node.id] = details + return details_result def get_next_node(self): """ @@ -108,8 +112,7 @@ class WorkflowManage: """ if self.current_node is None: node = self.get_start_node() - node_instance = flow_node_dict[node.type](node.id, node.properties.get('node_data'), - self.params, self.context) + node_instance = get_node(node.type)(node, self.params, self.context) return node_instance for edge in self.flow.edges: if edge.sourceNodeId == self.current_node.id: @@ -138,8 +141,9 @@ class WorkflowManage: context = { 'global': self.context, } - for key in self.node_dict: - context[key] = self.node_dict[key].context + + for node in self.node_context: + context[node.id] = node.context value = prompt_template.format(context=context) return value @@ -148,18 +152,22 @@ class WorkflowManage: 获取启动节点 @return: """ - return self.flow.nodes[0] + start_node_list = [node for node in self.flow.nodes if node.type == 'start-node'] + return start_node_list[0] def get_node_cls_by_id(self, node_id): for node in self.flow.nodes: if node.id == node_id: - node_instance = flow_node_dict[node.type](node.id, node.properties.get('node_data'), - self.params, self) + node_instance = get_node(node.type)(node, + self.params, self) return node_instance return None def get_node_by_id(self, node_id): - return self.node_dict[node_id] + for node in self.node_context: + if node.id == node_id: + return node + return None def get_node_reference(self, reference_address: Dict): node = self.get_node_by_id(reference_address.get('node_id'))