From 2f2f74fdab23ed297025e5704a28487b94d2dc39 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Mon, 1 Jul 2024 09:45:59 +0800
Subject: [PATCH] feat: support workflow (#671)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
apps/application/flow/__init__.py | 8 +
apps/application/flow/default_workflow.json | 426 +++++++++++++++
apps/application/flow/i_step_node.py | 190 +++++++
apps/application/flow/step_node/__init__.py | 23 +
.../step_node/ai_chat_step_node/__init__.py | 9 +
.../ai_chat_step_node/i_chat_node.py | 37 ++
.../ai_chat_step_node/impl/__init__.py | 9 +
.../ai_chat_step_node/impl/base_chat_node.py | 195 +++++++
.../flow/step_node/condition_node/__init__.py | 9 +
.../condition_node/compare/__init__.py | 23 +
.../condition_node/compare/compare.py | 20 +
.../condition_node/compare/contain_compare.py | 23 +
.../condition_node/compare/equal_compare.py | 21 +
.../condition_node/compare/ge_compare.py | 24 +
.../condition_node/compare/gt_compare.py | 24 +
.../compare/is_not_null_compare.py | 21 +
.../condition_node/compare/is_null_compare.py | 21 +
.../condition_node/compare/le_compare.py | 24 +
.../compare/len_equal_compare.py | 24 +
.../condition_node/compare/len_ge_compare.py | 24 +
.../condition_node/compare/len_gt_compare.py | 24 +
.../condition_node/compare/len_le_compare.py | 24 +
.../condition_node/compare/len_lt_compare.py | 24 +
.../condition_node/compare/lt_compare.py | 24 +
.../compare/not_contain_compare.py | 23 +
.../condition_node/i_condition_node.py | 39 ++
.../step_node/condition_node/impl/__init__.py | 9 +
.../impl/base_condition_node.py | 50 ++
.../step_node/direct_reply_node/__init__.py | 9 +
.../direct_reply_node/i_reply_node.py | 46 ++
.../direct_reply_node/impl/__init__.py | 9 +
.../direct_reply_node/impl/base_reply_node.py | 90 ++++
.../flow/step_node/question_node/__init__.py | 9 +
.../question_node/i_question_node.py | 37 ++
.../step_node/question_node/impl/__init__.py | 9 +
.../question_node/impl/base_question_node.py | 196 +++++++
.../step_node/search_dataset_node/__init__.py | 9 +
.../i_search_dataset_node.py | 61 +++
.../search_dataset_node/impl/__init__.py | 9 +
.../impl/base_search_dataset_node.py | 93 ++++
.../flow/step_node/start_node/__init__.py | 9 +
.../flow/step_node/start_node/i_start_node.py | 26 +
.../step_node/start_node/impl/__init__.py | 9 +
.../start_node/impl/base_start_node.py | 33 ++
apps/application/flow/tools.py | 87 +++
apps/application/flow/workflow_manage.py | 282 ++++++++++
...ion_type_application_work_flow_and_more.py | 38 ++
apps/application/models/application.py | 32 +-
.../serializers/application_serializers.py | 101 +++-
.../serializers/chat_message_serializers.py | 82 ++-
.../serializers/chat_serializers.py | 56 +-
.../swagger_api/application_api.py | 35 +-
apps/application/swagger_api/chat_api.py | 11 +
apps/application/urls.py | 2 +
apps/application/views/application_views.py | 26 +-
apps/application/views/chat_views.py | 12 +
apps/common/db/compiler.py | 114 +++-
apps/common/response/result.py | 1 +
.../model/qian_fan_chat_model.py | 56 --
package-lock.json | 6 +
pyproject.toml | 4 +-
ui/index.html | 2 +-
ui/package.json | 16 +-
ui/src/App.vue | 4 +-
ui/src/api/application.ts | 29 +-
ui/src/api/type/application.ts | 2 +
ui/src/assets/MaxKB-logo.svg | 1 +
ui/src/assets/icon_condition.svg | 3 +
ui/src/assets/icon_globe_color.svg | 1 +
ui/src/assets/icon_hi.svg | 5 +
ui/src/assets/icon_reply.svg | 3 +
ui/src/assets/icon_setting.svg | 3 +
ui/src/assets/icon_start.svg | 4 +
.../ai-chat/ExecutionDetailDialog.vue | 209 +++++++
ui/src/components/ai-chat/KnowledgeSource.vue | 79 +++
.../ai-chat/ParagraphSourceDialog.vue | 54 +-
.../ai-chat/component/ParagraphCard.vue | 58 ++
ui/src/components/ai-chat/index.vue | 90 ++--
ui/src/components/app-avatar/index.vue | 2 +
ui/src/components/card-add/index.vue | 4 +-
.../items/complex/ArrayObjectCard.vue | 4 +-
.../items/table/ProgressTableItem.vue | 8 +-
ui/src/components/icons/index.ts | 160 +++++-
ui/src/components/index.ts | 6 +-
ui/src/components/layout-container/index.vue | 5 +-
ui/src/components/login-container/index.vue | 15 +-
ui/src/components/markdown-editor/index.vue | 56 --
ui/src/components/markdown-renderer/index.vue | 66 ---
ui/src/components/markdown/MdEditor.vue | 14 +
ui/src/components/markdown/MdPreview.vue | 8 +
.../MdRenderer.vue | 3 +-
.../assets/markdown-iconfont.js} | 0
ui/src/directives/clickoutside.ts | 7 +
ui/src/directives/resize.ts | 31 ++
ui/src/enums/application.ts | 5 +
ui/src/enums/workflow.ts | 9 +
ui/src/layout/components/breadcrumb/index.vue | 38 +-
.../layout/components/sidebar/SidebarItem.vue | 33 +-
ui/src/layout/components/sidebar/index.vue | 2 +-
.../components/top-bar/avatar/AboutDialog.vue | 24 +-
ui/src/layout/components/top-bar/index.vue | 18 +-
.../locales/lang/zh_CN/views/application.ts | 4 +-
ui/src/main.ts | 40 +-
ui/src/router/modules/application.ts | 16 +-
ui/src/router/routes.ts | 8 +
ui/src/styles/app.scss | 46 +-
ui/src/styles/element-plus.scss | 2 +-
ui/src/styles/index.scss | 3 +-
ui/src/utils/application.ts | 4 +
ui/src/utils/utils.ts | 2 +-
.../component/APIKeyDialog.vue | 52 +-
.../component/EmbedDialog.vue | 4 +-
.../component/LimitDialog.vue | 42 +-
ui/src/views/application-workflow/index.vue | 312 +++++++++++
.../views/application/ApplicationSetting.vue | 508 ++++++++++++++++++
ui/src/views/application/CreateAndSetting.vue | 21 +-
.../AddDatasetDialog.vue | 22 +-
.../component/CreateApplicationDialog.vue | 191 +++++++
.../ParamSettingDialog.vue | 106 ++--
ui/src/views/application/index.vue | 31 +-
ui/src/views/chat/base/index.vue | 3 +-
ui/src/views/chat/embed/index.vue | 14 +-
ui/src/views/chat/pc/index.vue | 3 +-
ui/src/views/dataset/DatasetSetting.vue | 67 ++-
.../dataset/component/UploadComponent.vue | 4 +-
ui/src/views/dataset/index.vue | 2 +-
ui/src/views/dataset/step/ResultSuccess.vue | 8 +-
ui/src/views/hit-test/index.vue | 7 +-
.../paragraph/component/ParagraphForm.vue | 6 +-
.../team/component/PermissionSetting.vue | 9 +-
.../template/component/CreateModelDialog.vue | 1 +
.../component/SelectProviderDialog.vue | 1 +
ui/src/workflow/common/CustomLine.vue | 37 ++
ui/src/workflow/common/NodeCascader.vue | 104 ++++
ui/src/workflow/common/NodeContainer.vue | 187 +++++++
ui/src/workflow/common/NodeControl.vue | 34 ++
ui/src/workflow/common/app-node.ts | 235 ++++++++
ui/src/workflow/common/data.ts | 172 ++++++
ui/src/workflow/common/edge.ts | 184 +++++++
ui/src/workflow/common/shortcut.ts | 136 +++++
ui/src/workflow/common/validate.ts | 137 +++++
ui/src/workflow/icons/ai-chat-node-icon.vue | 6 +
ui/src/workflow/icons/base-node-icon.vue | 6 +
ui/src/workflow/icons/condition-node-icon.vue | 6 +
ui/src/workflow/icons/global-icon.vue | 4 +
ui/src/workflow/icons/question-node-icon.vue | 6 +
ui/src/workflow/icons/reply-node-icon.vue | 6 +
.../icons/search-dataset-node-icon.vue | 6 +
ui/src/workflow/icons/start-node-icon.vue | 6 +
ui/src/workflow/icons/utils.ts | 5 +
ui/src/workflow/index.vue | 154 ++++++
ui/src/workflow/nodes/ai-chat-node/index.ts | 12 +
ui/src/workflow/nodes/ai-chat-node/index.vue | 241 +++++++++
ui/src/workflow/nodes/base-node/index.ts | 12 +
ui/src/workflow/nodes/base-node/index.vue | 92 ++++
ui/src/workflow/nodes/condition-node/index.ts | 67 +++
.../workflow/nodes/condition-node/index.vue | 296 ++++++++++
ui/src/workflow/nodes/question-node/index.ts | 12 +
ui/src/workflow/nodes/question-node/index.vue | 240 +++++++++
ui/src/workflow/nodes/reply-node/index.ts | 12 +
ui/src/workflow/nodes/reply-node/index.vue | 123 +++++
.../nodes/search-dataset-node/index.ts | 12 +
.../nodes/search-dataset-node/index.vue | 217 ++++++++
ui/src/workflow/nodes/start-node/index.ts | 12 +
ui/src/workflow/nodes/start-node/index.vue | 34 ++
165 files changed, 7791 insertions(+), 613 deletions(-)
create mode 100644 apps/application/flow/__init__.py
create mode 100644 apps/application/flow/default_workflow.json
create mode 100644 apps/application/flow/i_step_node.py
create mode 100644 apps/application/flow/step_node/__init__.py
create mode 100644 apps/application/flow/step_node/ai_chat_step_node/__init__.py
create mode 100644 apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
create mode 100644 apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py
create mode 100644 apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
create mode 100644 apps/application/flow/step_node/condition_node/__init__.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/__init__.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/contain_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/equal_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/ge_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/gt_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/is_null_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/le_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/len_equal_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/len_ge_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/len_gt_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/len_le_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/len_lt_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/lt_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/compare/not_contain_compare.py
create mode 100644 apps/application/flow/step_node/condition_node/i_condition_node.py
create mode 100644 apps/application/flow/step_node/condition_node/impl/__init__.py
create mode 100644 apps/application/flow/step_node/condition_node/impl/base_condition_node.py
create mode 100644 apps/application/flow/step_node/direct_reply_node/__init__.py
create mode 100644 apps/application/flow/step_node/direct_reply_node/i_reply_node.py
create mode 100644 apps/application/flow/step_node/direct_reply_node/impl/__init__.py
create mode 100644 apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py
create mode 100644 apps/application/flow/step_node/question_node/__init__.py
create mode 100644 apps/application/flow/step_node/question_node/i_question_node.py
create mode 100644 apps/application/flow/step_node/question_node/impl/__init__.py
create mode 100644 apps/application/flow/step_node/question_node/impl/base_question_node.py
create mode 100644 apps/application/flow/step_node/search_dataset_node/__init__.py
create mode 100644 apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py
create mode 100644 apps/application/flow/step_node/search_dataset_node/impl/__init__.py
create mode 100644 apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py
create mode 100644 apps/application/flow/step_node/start_node/__init__.py
create mode 100644 apps/application/flow/step_node/start_node/i_start_node.py
create mode 100644 apps/application/flow/step_node/start_node/impl/__init__.py
create mode 100644 apps/application/flow/step_node/start_node/impl/base_start_node.py
create mode 100644 apps/application/flow/tools.py
create mode 100644 apps/application/flow/workflow_manage.py
create mode 100644 apps/application/migrations/0009_application_type_application_work_flow_and_more.py
create mode 100644 package-lock.json
create mode 100644 ui/src/assets/MaxKB-logo.svg
create mode 100644 ui/src/assets/icon_condition.svg
create mode 100644 ui/src/assets/icon_globe_color.svg
create mode 100644 ui/src/assets/icon_hi.svg
create mode 100644 ui/src/assets/icon_reply.svg
create mode 100644 ui/src/assets/icon_setting.svg
create mode 100644 ui/src/assets/icon_start.svg
create mode 100644 ui/src/components/ai-chat/ExecutionDetailDialog.vue
create mode 100644 ui/src/components/ai-chat/KnowledgeSource.vue
create mode 100644 ui/src/components/ai-chat/component/ParagraphCard.vue
delete mode 100644 ui/src/components/markdown-editor/index.vue
delete mode 100644 ui/src/components/markdown-renderer/index.vue
create mode 100644 ui/src/components/markdown/MdEditor.vue
create mode 100644 ui/src/components/markdown/MdPreview.vue
rename ui/src/components/{markdown-renderer => markdown}/MdRenderer.vue (96%)
rename ui/src/components/{markdown-editor/assets/font_prouiefeic.js => markdown/assets/markdown-iconfont.js} (100%)
create mode 100644 ui/src/directives/clickoutside.ts
create mode 100644 ui/src/directives/resize.ts
create mode 100644 ui/src/enums/application.ts
create mode 100644 ui/src/enums/workflow.ts
create mode 100644 ui/src/views/application-workflow/index.vue
create mode 100644 ui/src/views/application/ApplicationSetting.vue
rename ui/src/views/application/{components => component}/AddDatasetDialog.vue (64%)
create mode 100644 ui/src/views/application/component/CreateApplicationDialog.vue
rename ui/src/views/application/{components => component}/ParamSettingDialog.vue (70%)
create mode 100644 ui/src/workflow/common/CustomLine.vue
create mode 100644 ui/src/workflow/common/NodeCascader.vue
create mode 100644 ui/src/workflow/common/NodeContainer.vue
create mode 100644 ui/src/workflow/common/NodeControl.vue
create mode 100644 ui/src/workflow/common/app-node.ts
create mode 100644 ui/src/workflow/common/data.ts
create mode 100644 ui/src/workflow/common/edge.ts
create mode 100644 ui/src/workflow/common/shortcut.ts
create mode 100644 ui/src/workflow/common/validate.ts
create mode 100644 ui/src/workflow/icons/ai-chat-node-icon.vue
create mode 100644 ui/src/workflow/icons/base-node-icon.vue
create mode 100644 ui/src/workflow/icons/condition-node-icon.vue
create mode 100644 ui/src/workflow/icons/global-icon.vue
create mode 100644 ui/src/workflow/icons/question-node-icon.vue
create mode 100644 ui/src/workflow/icons/reply-node-icon.vue
create mode 100644 ui/src/workflow/icons/search-dataset-node-icon.vue
create mode 100644 ui/src/workflow/icons/start-node-icon.vue
create mode 100644 ui/src/workflow/icons/utils.ts
create mode 100644 ui/src/workflow/index.vue
create mode 100644 ui/src/workflow/nodes/ai-chat-node/index.ts
create mode 100644 ui/src/workflow/nodes/ai-chat-node/index.vue
create mode 100644 ui/src/workflow/nodes/base-node/index.ts
create mode 100644 ui/src/workflow/nodes/base-node/index.vue
create mode 100644 ui/src/workflow/nodes/condition-node/index.ts
create mode 100644 ui/src/workflow/nodes/condition-node/index.vue
create mode 100644 ui/src/workflow/nodes/question-node/index.ts
create mode 100644 ui/src/workflow/nodes/question-node/index.vue
create mode 100644 ui/src/workflow/nodes/reply-node/index.ts
create mode 100644 ui/src/workflow/nodes/reply-node/index.vue
create mode 100644 ui/src/workflow/nodes/search-dataset-node/index.ts
create mode 100644 ui/src/workflow/nodes/search-dataset-node/index.vue
create mode 100644 ui/src/workflow/nodes/start-node/index.ts
create mode 100644 ui/src/workflow/nodes/start-node/index.vue
diff --git a/apps/application/flow/__init__.py b/apps/application/flow/__init__.py
new file mode 100644
index 000000000..328e8f8ec
--- /dev/null
+++ b/apps/application/flow/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/7 14:43
+ @desc:
+"""
diff --git a/apps/application/flow/default_workflow.json b/apps/application/flow/default_workflow.json
new file mode 100644
index 000000000..9c460f549
--- /dev/null
+++ b/apps/application/flow/default_workflow.json
@@ -0,0 +1,426 @@
+{
+ "nodes": [
+ {
+ "id": "base-node",
+ "type": "base-node",
+ "x": 440,
+ "y": 3350,
+ "properties": {
+ "config": {},
+ "height": 517,
+ "stepName": "基本信息",
+ "node_data": {
+ "desc": "",
+ "name": "",
+ "prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?"
+ }
+ }
+ },
+ {
+ "id": "start-node",
+ "type": "start-node",
+ "x": 440,
+ "y": 3710,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "globalFields": [
+ {
+ "value": "time",
+ "label": "当前时间"
+ }
+ ]
+ },
+ "fields": [
+ {
+ "label": "用户问题",
+ "value": "question"
+ }
+ ],
+ "height": 268.533,
+ "stepName": "开始",
+ "globalFields": [
+ {
+ "label": "当前时间",
+ "value": "time"
+ }
+ ]
+ }
+ },
+ {
+ "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "type": "search-dataset-node",
+ "x": 830,
+ "y": 3470,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "检索结果的分段列表",
+ "value": "paragraph_list"
+ },
+ {
+ "label": "满足直接回答的分段列表",
+ "value": "is_hit_handling_method_list"
+ },
+ {
+ "label": "检索结果",
+ "value": "data"
+ },
+ {
+ "label": "满足直接回答的分段内容",
+ "value": "directly_return"
+ }
+ ]
+ },
+ "height": 754.8,
+ "stepName": "知识库检索",
+ "node_data": {
+ "dataset_id_list": [],
+ "dataset_setting": {
+ "top_n": 3,
+ "similarity": 0.6,
+ "search_mode": "embedding",
+ "max_paragraph_char_number": 5000
+ },
+ "question_reference_address": [
+ "start-node",
+ "question"
+ ]
+ }
+ }
+ },
+ {
+ "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "type": "condition-node",
+ "x": 1380,
+ "y": 3470,
+ "properties": {
+ "width": 600,
+ "config": {
+ "fields": [
+ {
+ "label": "分支名称",
+ "value": "branch_name"
+ }
+ ]
+ },
+ "height": 524.6669999999999,
+ "stepName": "判断器",
+ "node_data": {
+ "branch": [
+ {
+ "id": "1009",
+ "type": "IF",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "is_hit_handling_method_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "4908",
+ "type": "ELSE IF 1",
+ "condition": "and",
+ "conditions": [
+ {
+ "field": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "paragraph_list"
+ ],
+ "value": "1",
+ "compare": "len_ge"
+ }
+ ]
+ },
+ {
+ "id": "161",
+ "type": "ELSE",
+ "condition": "and",
+ "conditions": []
+ }
+ ]
+ },
+ "branch_condition_list": [
+ {
+ "index": 0,
+ "height": 116.133,
+ "id": "1009"
+ },
+ {
+ "index": 1,
+ "height": 116.133,
+ "id": "4908"
+ },
+ {
+ "index": 2,
+ "height": 40,
+ "id": "161"
+ }
+ ]
+ }
+ },
+ {
+ "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "type": "reply-node",
+ "x": 2090,
+ "y": 2820,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 312.267,
+ "stepName": "指定回复",
+ "node_data": {
+ "fields": [
+ "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "directly_return"
+ ],
+ "content": "",
+ "reply_type": "referencing"
+ }
+ }
+ },
+ {
+ "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "type": "ai-chat-node",
+ "x": 2090,
+ "y": 3460,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 681.4,
+ "stepName": "AI 对话",
+ "node_data": {
+ "prompt": "已知信息:\n{{知识库检索.data}}\n问题:\n{{开始.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0
+ }
+ }
+ },
+ {
+ "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "type": "ai-chat-node",
+ "x": 2090,
+ "y": 4180,
+ "properties": {
+ "config": {
+ "fields": [
+ {
+ "label": "AI 回答内容",
+ "value": "answer"
+ }
+ ]
+ },
+ "height": 681.4,
+ "stepName": "AI 对话1",
+ "node_data": {
+ "prompt": "{{开始.question}}",
+ "system": "",
+ "model_id": "",
+ "dialogue_number": 0
+ }
+ }
+ }
+ ],
+ "edges": [
+ {
+ "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73",
+ "type": "app-edge",
+ "sourceNodeId": "start-node",
+ "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "startPoint": {
+ "x": 600,
+ "y": 3710
+ },
+ "endPoint": {
+ "x": 670,
+ "y": 3470
+ },
+ "properties": {},
+ "pointsList": [
+ {
+ "x": 600,
+ "y": 3710
+ },
+ {
+ "x": 710,
+ "y": 3710
+ },
+ {
+ "x": 560,
+ "y": 3470
+ },
+ {
+ "x": 670,
+ "y": 3470
+ }
+ ],
+ "sourceAnchorId": "start-node_right",
+ "targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left"
+ },
+ {
+ "id": "35cb86dd-f328-429e-a973-12fd7218b696",
+ "type": "app-edge",
+ "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5",
+ "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "startPoint": {
+ "x": 990,
+ "y": 3470
+ },
+ "endPoint": {
+ "x": 1090,
+ "y": 3470
+ },
+ "properties": {},
+ "pointsList": [
+ {
+ "x": 990,
+ "y": 3470
+ },
+ {
+ "x": 1100,
+ "y": 3470
+ },
+ {
+ "x": 980,
+ "y": 3470
+ },
+ {
+ "x": 1090,
+ "y": 3470
+ }
+ ],
+ "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right",
+ "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left"
+ },
+ {
+ "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26",
+ "startPoint": {
+ "x": 1670,
+ "y": 3340.733
+ },
+ "endPoint": {
+ "x": 1930,
+ "y": 2820
+ },
+ "properties": {},
+ "pointsList": [
+ {
+ "x": 1670,
+ "y": 3340.733
+ },
+ {
+ "x": 1780,
+ "y": 3340.733
+ },
+ {
+ "x": 1820,
+ "y": 2820
+ },
+ {
+ "x": 1930,
+ "y": 2820
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right",
+ "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left"
+ },
+ {
+ "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb",
+ "startPoint": {
+ "x": 1670,
+ "y": 3464.866
+ },
+ "endPoint": {
+ "x": 1930,
+ "y": 3460
+ },
+ "properties": {},
+ "pointsList": [
+ {
+ "x": 1670,
+ "y": 3464.866
+ },
+ {
+ "x": 1780,
+ "y": 3464.866
+ },
+ {
+ "x": 1820,
+ "y": 3460
+ },
+ {
+ "x": 1930,
+ "y": 3460
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right",
+ "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left"
+ },
+ {
+ "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5",
+ "type": "app-edge",
+ "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b",
+ "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7",
+ "startPoint": {
+ "x": 1670,
+ "y": 3550.9325000000003
+ },
+ "endPoint": {
+ "x": 1930,
+ "y": 4180
+ },
+ "properties": {},
+ "pointsList": [
+ {
+ "x": 1670,
+ "y": 3550.9325000000003
+ },
+ {
+ "x": 1780,
+ "y": 3550.9325000000003
+ },
+ {
+ "x": 1820,
+ "y": 4180
+ },
+ {
+ "x": 1930,
+ "y": 4180
+ }
+ ],
+ "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right",
+ "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py
new file mode 100644
index 000000000..0aa620e26
--- /dev/null
+++ b/apps/application/flow/i_step_node.py
@@ -0,0 +1,190 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_step_node.py
+ @date:2024/6/3 14:57
+ @desc:
+"""
+import time
+from abc import abstractmethod
+from typing import Type, Dict, List
+
+from django.db.models import QuerySet
+from rest_framework import serializers
+
+from application.models import ChatRecord
+from application.models.api_key_model import ApplicationPublicAccessClient
+from common.constants.authentication_type import AuthenticationType
+from common.field.common import InstanceField
+from common.util.field_message import ErrMessage
+from django.core import cache
+
+chat_cache = cache.caches['model_cache']
+
+
+def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
+ if step_variable is not None:
+ for key in step_variable:
+ node.context[key] = step_variable[key]
+ if global_variable is not None:
+ for key in global_variable:
+ workflow.context[key] = global_variable[key]
+
+
+class WorkFlowPostHandler:
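+    """Post-processing for a finished workflow run: persists the chat record and updates client access counters."""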
+ def __init__(self, chat_info, client_id, client_type):
+ self.chat_info = chat_info
+ self.client_id = client_id
+ self.client_type = client_type
+
+ def handler(self, chat_id,
+ chat_record_id,
+ answer,
+ workflow):
+ question = workflow.params['question']
+ details = workflow.get_runtime_details()
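+        # Aggregate token usage across all executed nodes; nodes that report no token counts are skipped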
+ message_tokens = sum([row.get('message_tokens') for row in details.values() if
+ 'message_tokens' in row and row.get('message_tokens') is not None])
+ answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
+ 'answer_tokens' in row and row.get('answer_tokens') is not None])
+ chat_record = ChatRecord(id=chat_record_id,
+ chat_id=chat_id,
+ problem_text=question,
+ answer_text=answer,
+ details=details,
+ message_tokens=message_tokens,
+ answer_tokens=answer_tokens,
+ run_time=time.time() - workflow.context['start_time'],
+ index=0)
+ self.chat_info.append_chat_record(chat_record, self.client_id)
+        # Re-cache the updated chat info
+ chat_cache.set(chat_id,
+ self.chat_info, timeout=60 * 30)
+ if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
+ application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.client_id).first()
+ if application_public_access_client is not None:
+ application_public_access_client.access_num = application_public_access_client.access_num + 1
+ application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
+ application_public_access_client.save()
+
+
+class NodeResult:
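+    """Result of a single node run: node-level variables, workflow-level variables, and optional context/response writers."""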
+ def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context):
+ self._write_context = _write_context
+ self.node_variable = node_variable
+ self.workflow_variable = workflow_variable
+ self._to_response = _to_response
+
+ def write_context(self, node, workflow):
+ self._write_context(self.node_variable, self.workflow_variable, node, workflow)
+
+ def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler):
+ return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow,
+ post_handler)
+
+ def is_assertion_result(self):
+ return 'branch_id' in self.node_variable
+
+
+class ReferenceAddressSerializer(serializers.Serializer):
+ node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
+ fields = serializers.ListField(
+ child=serializers.CharField(required=True, error_messages=ErrMessage.char("节点字段")), required=True,
+ error_messages=ErrMessage.list("节点字段数组"))
+
+
+class FlowParamsSerializer(serializers.Serializer):
+    # Historical Q&A records
+ history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
+ error_messages=ErrMessage.list("历史对答"))
+
+ question = serializers.CharField(required=True, error_messages=ErrMessage.list("用户问题"))
+
+ chat_id = serializers.CharField(required=True, error_messages=ErrMessage.list("对话id"))
+
+ chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id"))
+
+ stream = serializers.BooleanField(required=True, error_messages=ErrMessage.base("流式输出"))
+
+ client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端id"))
+
+ client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
+
+
+class INode:
+ def __init__(self, node, workflow_params, workflow_manage):
+        # Per-step context used to store information produced while this step runs
+ self.status = 200
+ self.err_message = ''
+ self.node = node
+ self.node_params = node.properties.get('node_data')
+ self.workflow_manage = workflow_manage
+ self.node_params_serializer = None
+ self.flow_params_serializer = None
+ self.context = {}
+ self.id = node.id
+ self.valid_args(self.node_params, workflow_params)
+
+ def valid_args(self, node_params, flow_params):
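+        # Validate workflow-level and node-level parameters with their respective serializers (raises on invalid input)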
+ flow_params_serializer_class = self.get_flow_params_serializer_class()
+ node_params_serializer_class = self.get_node_params_serializer_class()
+ if flow_params_serializer_class is not None and flow_params is not None:
+ self.flow_params_serializer = flow_params_serializer_class(data=flow_params)
+ self.flow_params_serializer.is_valid(raise_exception=True)
+ if node_params_serializer_class is not None:
+ self.node_params_serializer = node_params_serializer_class(data=node_params)
+ self.node_params_serializer.is_valid(raise_exception=True)
+
+ def get_reference_field(self, fields: List[str]):
+ return self.get_field(self.context, fields)
+
+ @staticmethod
+ def get_field(obj, fields: List[str]):
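+        # Walk the nested structure along the given field path, returning None as soon as a segment is missing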
+ for field in fields:
+ value = obj.get(field)
+ if value is None:
+ return None
+ else:
+ obj = value
+ return obj
+
+ @abstractmethod
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ pass
+
+ def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return FlowParamsSerializer
+
+ def get_write_error_context(self, e):
+ self.status = 500
+ self.err_message = str(e)
+
+ def write_error_context(answer, status=200):
+ pass
+
+ return write_error_context
+
+ def run(self) -> NodeResult:
+ """
+        :return: execution result
+ """
+ start_time = time.time()
+ self.context['start_time'] = start_time
+ result = self._run()
+ self.context['run_time'] = time.time() - start_time
+ return result
+
+ def _run(self):
+ result = self.execute()
+ return result
+
+ def execute(self, **kwargs) -> NodeResult:
+ pass
+
+ def get_details(self, index: int, **kwargs):
+ """
+        Run details
+        :return: details of this step
+ """
+ return {}
diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py
new file mode 100644
index 000000000..b3692fa8b
--- /dev/null
+++ b/apps/application/flow/step_node/__init__.py
@@ -0,0 +1,23 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/7 14:43
+ @desc:
+"""
+from .ai_chat_step_node import *
+from .condition_node import *
+from .question_node import *
+from .search_dataset_node import *
+from .start_node import *
+from .direct_reply_node import *
+
+node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode]
+
+
+def get_node(node_type):
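+    # Resolve a node implementation class by its `type` attribute; returns None for unknown node types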
+ find_list = [node for node in node_list if node.type == node_type]
+ if len(find_list) > 0:
+ return find_list[0]
+ return None
diff --git a/apps/application/flow/step_node/ai_chat_step_node/__init__.py b/apps/application/flow/step_node/ai_chat_step_node/__init__.py
new file mode 100644
index 000000000..1929ae2af
--- /dev/null
+++ b/apps/application/flow/step_node/ai_chat_step_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:29
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
new file mode 100644
index 000000000..d0dfbaef9
--- /dev/null
+++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py
@@ -0,0 +1,37 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_chat_node.py
+ @date:2024/6/4 13:58
+ @desc:
+"""
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+
+
+class ChatNodeSerializer(serializers.Serializer):
+ model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
+ system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+ error_messages=ErrMessage.char("角色设定"))
+ prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
+    # Number of historical dialogue rounds to include in the prompt
+ dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
+
+
+class IChatNode(INode):
+ type = 'ai-chat-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return ChatNodeSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
+ **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py b/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py
new file mode 100644
index 000000000..79051a999
--- /dev/null
+++ b/apps/application/flow/step_node/ai_chat_step_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:34
+ @desc:
+"""
+from .base_chat_node import BaseChatNode
diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
new file mode 100644
index 000000000..48df8a463
--- /dev/null
+++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
@@ -0,0 +1,195 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_chat_node.py
+ @date:2024/6/4 14:30
+ @desc:
+"""
+import json
+import time
+from functools import reduce
+from typing import List, Dict
+
+from django.db.models import QuerySet
+from langchain.schema import HumanMessage, SystemMessage
+from langchain_core.messages import BaseMessage
+
+from application.flow import tools
+from application.flow.i_step_node import NodeResult, INode
+from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
+from common.util.rsa_util import rsa_long_decrypt
+from setting.models import Model
+from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
+
+
+def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+    Write context data (streaming)
+    @param node_variable: node variables
+    @param workflow_variable: global variables
+    @param node: node instance
+    @param workflow: workflow manager
+ """
+ response = node_variable.get('result')
+ answer = ''
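+    # Consume the streamed chunks and concatenate them into the final answer before computing token counts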
+ for chunk in response:
+ answer += chunk.content
+ chat_model = node_variable.get('chat_model')
+ message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+ answer_tokens = chat_model.get_num_tokens(answer)
+ node.context['message_tokens'] = message_tokens
+ node.context['answer_tokens'] = answer_tokens
+ node.context['answer'] = answer
+ node.context['history_message'] = node_variable['history_message']
+ node.context['question'] = node_variable['question']
+ node.context['run_time'] = time.time() - node.context['start_time']
+
+
+def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+    Write context data
+    @param node_variable: node variables
+    @param workflow_variable: global variables
+    @param node: node instance
+    @param workflow: workflow manager
+ """
+ response = node_variable.get('result')
+ chat_model = node_variable.get('chat_model')
+ answer = response.content
+ message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+ answer_tokens = chat_model.get_num_tokens(answer)
+ node.context['message_tokens'] = message_tokens
+ node.context['answer_tokens'] = answer_tokens
+ node.context['answer'] = answer
+ node.context['history_message'] = node_variable['history_message']
+ node.context['question'] = node_variable['question']
+
+
+def get_to_response_write_context(node_variable: Dict, node: INode):
+ def _write_context(answer, status=200):
+ chat_model = node_variable.get('chat_model')
+
+ if status == 200:
+ answer_tokens = chat_model.get_num_tokens(answer)
+ message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+ else:
+ answer_tokens = 0
+ message_tokens = 0
+ node.err_message = answer
+ node.status = status
+ node.context['message_tokens'] = message_tokens
+ node.context['answer_tokens'] = answer_tokens
+ node.context['answer'] = answer
+ node.context['run_time'] = time.time() - node.context['start_time']
+
+ return _write_context
+
+
+def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
+ post_handler):
+ """
+    Convert streaming data into a streaming response
+    @param chat_id: chat id
+    @param chat_record_id: chat record id
+    @param node_variable: node variables
+    @param workflow_variable: workflow variables
+    @param node: node instance
+    @param workflow: workflow manager
+    @param post_handler: post handler, executed after the result has been produced
+    @return: streaming response
+ """
+ response = node_variable.get('result')
+ _write_context = get_to_response_write_context(node_variable, node)
+ return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
+
+
+def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
+ post_handler):
+ """
+    Convert the result into a response
+    @param chat_id: chat id
+    @param chat_record_id: chat record id
+    @param node_variable: node variables
+    @param workflow_variable: workflow variables
+    @param node: node instance
+    @param workflow: workflow manager
+    @param post_handler: post handler
+    @return: response
+ """
+ response = node_variable.get('result')
+ _write_context = get_to_response_write_context(node_variable, node)
+ return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
+
+
+class BaseChatNode(IChatNode):
+ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
+ **kwargs) -> NodeResult:
+ model = QuerySet(Model).filter(id=model_id).first()
+ chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
+ json.loads(
+ rsa_long_decrypt(model.credential)),
+ streaming=True)
+ history_message = self.get_history_message(history_chat_record, dialogue_number)
+ self.context['history_message'] = history_message
+ question = self.generate_prompt_question(prompt)
+ self.context['question'] = question.content
+ message_list = self.generate_message_list(system, prompt, history_message)
+ self.context['message_list'] = message_list
+ if stream:
+ r = chat_model.stream(message_list)
+ return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context_stream,
+ _to_response=to_stream_response)
+ else:
+ r = chat_model.invoke(message_list)
+ return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context, _to_response=to_response)
+
+ @staticmethod
+ def get_history_message(history_chat_record, dialogue_number):
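+        # Keep only the last `dialogue_number` rounds; each round contributes its human and AI messages in order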
+ start_index = len(history_chat_record) - dialogue_number
+ history_message = reduce(lambda x, y: [*x, *y], [
+ [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
+ return history_message
+
+ def generate_prompt_question(self, prompt):
+ return HumanMessage(self.workflow_manage.generate_prompt(prompt))
+
+ def generate_message_list(self, system: str, prompt: str, history_message):
+ if system is not None and len(system) > 0:
+ return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
+ HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+ else:
+ return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+
+ @staticmethod
+ def reset_message_list(message_list: List[BaseMessage], answer_text):
+ result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+ message
+ in
+ message_list]
+ result.append({'role': 'ai', 'content': answer_text})
+ return result
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'system': self.node_params.get('system'),
+ 'history_message': [{'content': message.content, 'role': message.type} for message in
+ (self.context.get('history_message') if self.context.get(
+ 'history_message') is not None else [])],
+ 'question': self.context.get('question'),
+ 'answer': self.context.get('answer'),
+ 'type': self.node.type,
+ 'message_tokens': self.context.get('message_tokens'),
+ 'answer_tokens': self.context.get('answer_tokens'),
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
diff --git a/apps/application/flow/step_node/condition_node/__init__.py b/apps/application/flow/step_node/condition_node/__init__.py
new file mode 100644
index 000000000..57638504c
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/7 14:43
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/condition_node/compare/__init__.py b/apps/application/flow/step_node/condition_node/compare/__init__.py
new file mode 100644
index 000000000..b2c464b41
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/__init__.py
@@ -0,0 +1,23 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/7 14:43
+ @desc:
+"""
+
+from .contain_compare import *
+from .equal_compare import *
+from .gt_compare import *
+from .ge_compare import *
+from .le_compare import *
+from .lt_compare import *
+from .len_ge_compare import *
+from .len_gt_compare import *
+from .len_le_compare import *
+from .len_lt_compare import *
+from .len_equal_compare import *
+from .is_null_compare import *
+from .is_not_null_compare import *
+from .not_contain_compare import *
+
+# Register one instance of every comparator so condition nodes can resolve any supported operator
+compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
+                       LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare(),
+                       IsNullCompare(), IsNotNullCompare(), NotContainCompare()]
diff --git a/apps/application/flow/step_node/condition_node/compare/compare.py b/apps/application/flow/step_node/condition_node/compare/compare.py
new file mode 100644
index 000000000..6cbb4af07
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/compare.py
@@ -0,0 +1,20 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: compare.py
+ @date:2024/6/7 14:37
+ @desc:
+"""
+from abc import abstractmethod
+from typing import List
+
+
+class Compare:
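+    # Base class for condition comparators: support() decides whether this handler applies, compare() evaluates the condition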
+ @abstractmethod
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ pass
+
+ @abstractmethod
+ def compare(self, source_value, compare, target_value):
+ pass
diff --git a/apps/application/flow/step_node/condition_node/compare/contain_compare.py b/apps/application/flow/step_node/condition_node/compare/contain_compare.py
new file mode 100644
index 000000000..6073131a5
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/contain_compare.py
@@ -0,0 +1,23 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: contain_compare.py
+ @date:2024/6/11 10:02
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class ContainCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'contain':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ if isinstance(source_value, str):
+ return str(target_value) in source_value
+ return any([str(item) == str(target_value) for item in source_value])
diff --git a/apps/application/flow/step_node/condition_node/compare/equal_compare.py b/apps/application/flow/step_node/condition_node/compare/equal_compare.py
new file mode 100644
index 000000000..0061a82f6
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/equal_compare.py
@@ -0,0 +1,21 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: equal_compare.py
+ @date:2024/6/7 14:44
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class EqualCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'eq':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ return str(source_value) == str(target_value)
diff --git a/apps/application/flow/step_node/condition_node/compare/ge_compare.py b/apps/application/flow/step_node/condition_node/compare/ge_compare.py
new file mode 100644
index 000000000..d4e22cbd6
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/ge_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: ge_compare.py
+ @date:2024/6/11 9:52
+ @desc: Greater-than-or-equal comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class GECompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'ge':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return float(source_value) >= float(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/gt_compare.py b/apps/application/flow/step_node/condition_node/compare/gt_compare.py
new file mode 100644
index 000000000..80942abb2
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/gt_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: gt_compare.py
+ @date:2024/6/11 9:52
+ @desc: Greater-than comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class GTCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'gt':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return float(source_value) > float(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py b/apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py
new file mode 100644
index 000000000..9c281e381
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/is_not_null_compare.py
@@ -0,0 +1,21 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: is_not_null_compare.py
+ @date:2024/6/28 10:45
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare import Compare
+
+
+class IsNotNullCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'is_not_null':
+ return True
+
+ def compare(self, source_value, compare, target_value=None):
+ return source_value is not None
diff --git a/apps/application/flow/step_node/condition_node/compare/is_null_compare.py b/apps/application/flow/step_node/condition_node/compare/is_null_compare.py
new file mode 100644
index 000000000..6d49de605
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/is_null_compare.py
@@ -0,0 +1,21 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: is_null_compare.py
+ @date:2024/6/28 10:45
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare import Compare
+
+
+class IsNullCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'is_null':
+ return True
+
+ def compare(self, source_value, compare, target_value=None):
+ return source_value is None
diff --git a/apps/application/flow/step_node/condition_node/compare/le_compare.py b/apps/application/flow/step_node/condition_node/compare/le_compare.py
new file mode 100644
index 000000000..77a0bca0f
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/le_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: le_compare.py
+ @date:2024/6/11 9:52
+ @desc: Less-than-or-equal comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LECompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'le':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return float(source_value) <= float(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_equal_compare.py b/apps/application/flow/step_node/condition_node/compare/len_equal_compare.py
new file mode 100644
index 000000000..f2b0764c5
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_equal_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_equal_compare.py
+ @date:2024/6/7 14:44
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenEqualCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_eq':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) == int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_ge_compare.py b/apps/application/flow/step_node/condition_node/compare/len_ge_compare.py
new file mode 100644
index 000000000..87f11eb2c
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_ge_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_ge_compare.py
+ @date:2024/6/11 9:52
+ @desc: Length greater-than-or-equal comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenGECompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_ge':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) >= int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_gt_compare.py b/apps/application/flow/step_node/condition_node/compare/len_gt_compare.py
new file mode 100644
index 000000000..0532d353d
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_gt_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_gt_compare.py
+ @date:2024/6/11 9:52
+ @desc: Length greater-than comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenGTCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_gt':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) > int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_le_compare.py b/apps/application/flow/step_node/condition_node/compare/len_le_compare.py
new file mode 100644
index 000000000..d315a754a
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_le_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_le_compare.py
+ @date:2024/6/11 9:52
+ @desc: Length less-than-or-equal comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenLECompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_le':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) <= int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/len_lt_compare.py b/apps/application/flow/step_node/condition_node/compare/len_lt_compare.py
new file mode 100644
index 000000000..c89638cd7
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/len_lt_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: len_lt_compare.py
+ @date:2024/6/11 9:52
+ @desc: Length less-than comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LenLTCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'len_lt':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return len(source_value) < int(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/lt_compare.py b/apps/application/flow/step_node/condition_node/compare/lt_compare.py
new file mode 100644
index 000000000..d2d5be748
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/lt_compare.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: lt_compare.py
+ @date:2024/6/11 9:52
+ @desc: Less-than comparator
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class LTCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'lt':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ try:
+ return float(source_value) < float(target_value)
+ except Exception as e:
+ return False
diff --git a/apps/application/flow/step_node/condition_node/compare/not_contain_compare.py b/apps/application/flow/step_node/condition_node/compare/not_contain_compare.py
new file mode 100644
index 000000000..cfa0063a5
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/compare/not_contain_compare.py
@@ -0,0 +1,23 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: not_contain_compare.py
+ @date:2024/6/11 10:02
+ @desc:
+"""
+from typing import List
+
+from application.flow.step_node.condition_node.compare.compare import Compare
+
+
+class NotContainCompare(Compare):
+
+ def support(self, node_id, fields: List[str], source_value, compare, target_value):
+ if compare == 'not_contain':
+ return True
+
+ def compare(self, source_value, compare, target_value):
+ if isinstance(source_value, str):
+ return str(target_value) not in source_value
+ return not any([str(item) == str(target_value) for item in source_value])
diff --git a/apps/application/flow/step_node/condition_node/i_condition_node.py b/apps/application/flow/step_node/condition_node/i_condition_node.py
new file mode 100644
index 000000000..ffb975a98
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/i_condition_node.py
@@ -0,0 +1,39 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_condition_node.py
+ @date:2024/6/7 9:54
+ @desc:
+"""
+import json
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode
+from common.util.field_message import ErrMessage
+
+
+class ConditionSerializer(serializers.Serializer):
+ compare = serializers.CharField(required=True, error_messages=ErrMessage.char("比较器"))
+ value = serializers.CharField(required=True, error_messages=ErrMessage.char(""))
+ field = serializers.ListField(required=True, error_messages=ErrMessage.char("字段"))
+
+
+class ConditionBranchSerializer(serializers.Serializer):
+ id = serializers.CharField(required=True, error_messages=ErrMessage.char("分支id"))
+ type = serializers.CharField(required=True, error_messages=ErrMessage.char("分支类型"))
+ condition = serializers.CharField(required=True, error_messages=ErrMessage.char("条件or|and"))
+ conditions = ConditionSerializer(many=True)
+
+
+class ConditionNodeParamsSerializer(serializers.Serializer):
+ branch = ConditionBranchSerializer(many=True)
+
+
+class IConditionNode(INode):
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return ConditionNodeParamsSerializer
+
+ type = 'condition-node'
diff --git a/apps/application/flow/step_node/condition_node/impl/__init__.py b/apps/application/flow/step_node/condition_node/impl/__init__.py
new file mode 100644
index 000000000..c21cd3ebb
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:35
+ @desc:
+"""
+from .base_condition_node import BaseConditionNode
diff --git a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py
new file mode 100644
index 000000000..3164bb9fe
--- /dev/null
+++ b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py
@@ -0,0 +1,50 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_condition_node.py
+ @date:2024/6/7 11:29
+ @desc:
+"""
+from typing import List
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.condition_node.compare import compare_handle_list
+from application.flow.step_node.condition_node.i_condition_node import IConditionNode
+
+
+class BaseConditionNode(IConditionNode):
+ def execute(self, **kwargs) -> NodeResult:
+ branch_list = self.node_params_serializer.data['branch']
+ branch = self._execute(branch_list)
+ r = NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {})
+ return r
+
+ def _execute(self, branch_list: List):
+ for branch in branch_list:
+ if self.branch_assertion(branch):
+ return branch
+
+ def branch_assertion(self, branch):
+ condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
+ branch.get('conditions')]
+ condition = branch.get('condition')
+ return all(condition_list) if condition == 'and' else any(condition_list)
+
+ def assertion(self, field_list: List[str], compare: str, value):
+ field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
+ for compare_handler in compare_handle_list:
+ if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value):
+ return compare_handler.compare(field_value, compare, value)
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'branch_id': self.context.get('branch_id'),
+ 'branch_name': self.context.get('branch_name'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
diff --git a/apps/application/flow/step_node/direct_reply_node/__init__.py b/apps/application/flow/step_node/direct_reply_node/__init__.py
new file mode 100644
index 000000000..cf360f956
--- /dev/null
+++ b/apps/application/flow/step_node/direct_reply_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 17:50
+ @desc:
+"""
+from .impl import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/direct_reply_node/i_reply_node.py b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py
new file mode 100644
index 000000000..1d5256ac5
--- /dev/null
+++ b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py
@@ -0,0 +1,46 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_reply_node.py
+ @date:2024/6/11 16:25
+ @desc:
+"""
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.exception.app_exception import AppApiException
+from common.util.field_message import ErrMessage
+
+
+class ReplyNodeParamsSerializer(serializers.Serializer):
+ reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char("回复类型"))
+ fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
+ content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+ error_messages=ErrMessage.char("直接回答内容"))
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+ if self.data.get('reply_type') == 'referencing':
+ if 'fields' not in self.data:
+ raise AppApiException(500, "引用字段不能为空")
+ if len(self.data.get('fields')) < 2:
+ raise AppApiException(500, "引用字段错误")
+ else:
+ if 'content' not in self.data or self.data.get('content') is None:
+ raise AppApiException(500, "内容不能为空")
+
+
+class IReplyNode(INode):
+ type = 'reply-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return ReplyNodeParamsSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
+ pass
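
A small sketch of the two parameter shapes the validator accepts (run with the project settings loaded). The reply_type value for the non-referencing case is illustrative, since the check only distinguishes 'referencing' from everything else:

    from application.flow.step_node.direct_reply_node.i_reply_node import ReplyNodeParamsSerializer

    # fixed text reply
    ReplyNodeParamsSerializer(data={'reply_type': 'content', 'content': '你好'}).is_valid(raise_exception=True)
    # reply that references another node's output: [node_id, field]
    ReplyNodeParamsSerializer(data={'reply_type': 'referencing',
                                    'fields': ['ai-chat-node-id', 'answer']}).is_valid(raise_exception=True)
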
diff --git a/apps/application/flow/step_node/direct_reply_node/impl/__init__.py b/apps/application/flow/step_node/direct_reply_node/impl/__init__.py
new file mode 100644
index 000000000..3307e9089
--- /dev/null
+++ b/apps/application/flow/step_node/direct_reply_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 17:49
+ @desc:
+"""
+from .base_reply_node import *
\ No newline at end of file
diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py
new file mode 100644
index 000000000..d266265be
--- /dev/null
+++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py
@@ -0,0 +1,90 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_reply_node.py
+ @date:2024/6/11 17:25
+ @desc:
+"""
+from typing import List, Dict
+
+from langchain_core.messages import AIMessage, AIMessageChunk
+
+from application.flow import tools
+from application.flow.i_step_node import NodeResult, INode
+from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode
+
+
+def get_to_response_write_context(node_variable: Dict, node: INode):
+ def _write_context(answer, status=200):
+ node.context['answer'] = answer
+
+ return _write_context
+
+
+def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
+ post_handler):
+ """
+ 将流式数据 转换为 流式响应
+ @param chat_id: 会话id
+ @param chat_record_id: 对话记录id
+ @param node_variable: 节点数据
+ @param workflow_variable: 工作流数据
+ @param node: 节点
+ @param workflow: 工作流管理器
+ @param post_handler: 后置处理器 输出结果后执行
+ @return: 流式响应
+ """
+ response = node_variable.get('result')
+ _write_context = get_to_response_write_context(node_variable, node)
+ return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
+
+
+def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
+ post_handler):
+ """
+ 将结果转换
+ @param chat_id: 会话id
+ @param chat_record_id: 对话记录id
+ @param node_variable: 节点数据
+ @param workflow_variable: 工作流数据
+ @param node: 节点
+ @param workflow: 工作流管理器
+ @param post_handler: 后置处理器
+ @return: 响应
+ """
+ response = node_variable.get('result')
+ _write_context = get_to_response_write_context(node_variable, node)
+ return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
+
+
+class BaseReplyNode(IReplyNode):
+ def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
+ if reply_type == 'referencing':
+ result = self.get_reference_content(fields)
+ else:
+ result = self.generate_reply_content(content)
+ if stream:
+ return NodeResult({'result': iter([AIMessageChunk(content=result)])}, {},
+ _to_response=to_stream_response)
+ else:
+ return NodeResult({'result': AIMessage(content=result)}, {}, _to_response=to_response)
+
+ def generate_reply_content(self, prompt):
+ return self.workflow_manage.generate_prompt(prompt)
+
+ def get_reference_content(self, fields: List[str]):
+ return str(self.workflow_manage.get_reference_field(
+ fields[0],
+ fields[1:]))
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'answer': self.context.get('answer'),
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
diff --git a/apps/application/flow/step_node/question_node/__init__.py b/apps/application/flow/step_node/question_node/__init__.py
new file mode 100644
index 000000000..98a1afcd9
--- /dev/null
+++ b/apps/application/flow/step_node/question_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:30
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py
new file mode 100644
index 000000000..ede120def
--- /dev/null
+++ b/apps/application/flow/step_node/question_node/i_question_node.py
@@ -0,0 +1,37 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_question_node.py
+ @date:2024/6/4 13:58
+ @desc:
+"""
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+
+
+class QuestionNodeSerializer(serializers.Serializer):
+ model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
+ system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+ error_messages=ErrMessage.char("角色设定"))
+ prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
+ # 多轮对话数量
+ dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
+
+
+class IQuestionNode(INode):
+ type = 'question-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return QuestionNodeSerializer
+
+ def _run(self):
+ return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+
+ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
+ **kwargs) -> NodeResult:
+ pass
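
A hedged example of the node parameters this serializer expects (run with the project settings loaded); the values are placeholders, and model_id would normally be the primary key of an existing Model row:

    import uuid

    from application.flow.step_node.question_node.i_question_node import QuestionNodeSerializer

    params = {
        'model_id': str(uuid.uuid1()),       # placeholder id
        'system': '你是一个问题优化助手',
        'prompt': '请优化下面的问题',
        'dialogue_number': 3,
    }
    QuestionNodeSerializer(data=params).is_valid(raise_exception=True)
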
diff --git a/apps/application/flow/step_node/question_node/impl/__init__.py b/apps/application/flow/step_node/question_node/impl/__init__.py
new file mode 100644
index 000000000..d85aa8724
--- /dev/null
+++ b/apps/application/flow/step_node/question_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:35
+ @desc:
+"""
+from .base_question_node import BaseQuestionNode
diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py
new file mode 100644
index 000000000..65fc52c32
--- /dev/null
+++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py
@@ -0,0 +1,196 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_question_node.py
+ @date:2024/6/4 14:30
+ @desc:
+"""
+import json
+import time
+from functools import reduce
+from typing import List, Dict
+
+from django.db.models import QuerySet
+from langchain.schema import HumanMessage, SystemMessage
+from langchain_core.messages import BaseMessage
+
+from application.flow import tools
+from application.flow.i_step_node import NodeResult, INode
+from application.flow.step_node.question_node.i_question_node import IQuestionNode
+from common.util.rsa_util import rsa_long_decrypt
+from setting.models import Model
+from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
+
+
+def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据 (流式)
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result')
+ answer = ''
+ for chunk in response:
+ answer += chunk.content
+ chat_model = node_variable.get('chat_model')
+ message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+ answer_tokens = chat_model.get_num_tokens(answer)
+ node.context['message_tokens'] = message_tokens
+ node.context['answer_tokens'] = answer_tokens
+ node.context['answer'] = answer
+ node.context['history_message'] = node_variable['history_message']
+ node.context['question'] = node_variable['question']
+ node.context['run_time'] = time.time() - node.context['start_time']
+
+
+def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
+ """
+ 写入上下文数据
+ @param node_variable: 节点数据
+ @param workflow_variable: 全局数据
+ @param node: 节点实例对象
+ @param workflow: 工作流管理器
+ """
+ response = node_variable.get('result')
+ chat_model = node_variable.get('chat_model')
+ answer = response.content
+ message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+ answer_tokens = chat_model.get_num_tokens(answer)
+ node.context['message_tokens'] = message_tokens
+ node.context['answer_tokens'] = answer_tokens
+ node.context['answer'] = answer
+ node.context['history_message'] = node_variable['history_message']
+ node.context['question'] = node_variable['question']
+
+
+def get_to_response_write_context(node_variable: Dict, node: INode):
+ def _write_context(answer, status=200):
+ chat_model = node_variable.get('chat_model')
+
+ if status == 200:
+ answer_tokens = chat_model.get_num_tokens(answer)
+ message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
+ else:
+ answer_tokens = 0
+ message_tokens = 0
+ node.err_message = answer
+ node.status = status
+ node.context['message_tokens'] = message_tokens
+ node.context['answer_tokens'] = answer_tokens
+ node.context['answer'] = answer
+ node.context['run_time'] = time.time() - node.context['start_time']
+
+ return _write_context
+
+
+def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
+ post_handler):
+ """
+ 将流式数据 转换为 流式响应
+ @param chat_id: 会话id
+ @param chat_record_id: 对话记录id
+ @param node_variable: 节点数据
+ @param workflow_variable: 工作流数据
+ @param node: 节点
+ @param workflow: 工作流管理器
+ @param post_handler: 后置处理器 输出结果后执行
+ @return: 流式响应
+ """
+ response = node_variable.get('result')
+ _write_context = get_to_response_write_context(node_variable, node)
+ return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
+
+
+def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow,
+ post_handler):
+ """
+ 将结果转换
+ @param chat_id: 会话id
+ @param chat_record_id: 对话记录id
+ @param node_variable: 节点数据
+ @param workflow_variable: 工作流数据
+ @param node: 节点
+ @param workflow: 工作流管理器
+ @param post_handler: 后置处理器
+ @return: 响应
+ """
+ response = node_variable.get('result')
+ _write_context = get_to_response_write_context(node_variable, node)
+ return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler)
+
+
+class BaseQuestionNode(IQuestionNode):
+ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
+ **kwargs) -> NodeResult:
+ model = QuerySet(Model).filter(id=model_id).first()
+ chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
+ json.loads(
+ rsa_long_decrypt(model.credential)),
+ streaming=True)
+ history_message = self.get_history_message(history_chat_record, dialogue_number)
+ self.context['history_message'] = history_message
+ question = self.generate_prompt_question(prompt)
+ self.context['question'] = question.content
+ message_list = self.generate_message_list(system, prompt, history_message)
+ self.context['message_list'] = message_list
+ if stream:
+ r = chat_model.stream(message_list)
+ return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
+ 'get_to_response_write_context': get_to_response_write_context,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context_stream,
+ _to_response=to_stream_response)
+ else:
+ r = chat_model.invoke(message_list)
+ return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
+ 'history_message': history_message, 'question': question.content}, {},
+ _write_context=write_context, _to_response=to_response)
+
+ @staticmethod
+ def get_history_message(history_chat_record, dialogue_number):
+ start_index = len(history_chat_record) - dialogue_number
+ history_message = reduce(lambda x, y: [*x, *y], [
+ [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
+ return history_message
+
+ def generate_prompt_question(self, prompt):
+ return HumanMessage(self.workflow_manage.generate_prompt(prompt))
+
+ def generate_message_list(self, system: str, prompt: str, history_message):
+        if system is not None and len(system) > 0:
+ return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
+ HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+ else:
+ return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
+
+ @staticmethod
+ def reset_message_list(message_list: List[BaseMessage], answer_text):
+ result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+ message
+ in
+ message_list]
+ result.append({'role': 'ai', 'content': answer_text})
+ return result
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'system': self.node_params.get('system'),
+ 'history_message': [{'content': message.content, 'role': message.type} for message in
+ (self.context.get('history_message') if self.context.get(
+ 'history_message') is not None else [])],
+ 'question': self.context.get('question'),
+ 'answer': self.context.get('answer'),
+ 'type': self.node.type,
+ 'message_tokens': self.context.get('message_tokens'),
+ 'answer_tokens': self.context.get('answer_tokens'),
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
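
get_history_message keeps only the last dialogue_number rounds and flattens them; a standalone sketch using stand-in records (FakeRecord is a stub for ChatRecord, run in a Django shell so the module imports resolve):

    from application.flow.step_node.question_node.impl.base_question_node import BaseQuestionNode

    class FakeRecord:
        def __init__(self, q, a):
            self.q, self.a = q, a

        def get_human_message(self):
            return ('human', self.q)

        def get_ai_message(self):
            return ('ai', self.a)

    records = [FakeRecord(f'q{i}', f'a{i}') for i in range(5)]
    window = BaseQuestionNode.get_history_message(records, dialogue_number=2)
    assert [m[1] for m in window] == ['q3', 'a3', 'q4', 'a4']  # two most recent rounds, flattened
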
diff --git a/apps/application/flow/step_node/search_dataset_node/__init__.py b/apps/application/flow/step_node/search_dataset_node/__init__.py
new file mode 100644
index 000000000..98a1afcd9
--- /dev/null
+++ b/apps/application/flow/step_node/search_dataset_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:30
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py
new file mode 100644
index 000000000..0a134527c
--- /dev/null
+++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_search_dataset_node.py
+ @date:2024/6/3 17:52
+ @desc:
+"""
+import re
+from typing import Type
+
+from django.core import validators
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+
+
+class DatasetSettingSerializer(serializers.Serializer):
+ # 需要查询的条数
+ top_n = serializers.IntegerField(required=True,
+ error_messages=ErrMessage.integer("引用分段数"))
+ # 相似度 0-1之间
+    similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
+                                        error_messages=ErrMessage.float("相似度"))
+    search_mode = serializers.CharField(required=True, validators=[
+        validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
+                                  message="检索模式只支持embedding|keywords|blend", code=500)
+    ], error_messages=ErrMessage.char("检索模式"))
+    max_paragraph_char_number = serializers.IntegerField(required=True,
+                                                         error_messages=ErrMessage.integer("最大引用分段字数"))
+
+
+class SearchDatasetStepNodeSerializer(serializers.Serializer):
+ # 需要查询的数据集id列表
+ dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
+ error_messages=ErrMessage.list("数据集id列表"))
+ dataset_setting = DatasetSettingSerializer(required=True)
+
+    question_reference_address = serializers.ListField(required=True)
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+
+
+class ISearchDatasetStepNode(INode):
+ type = 'search-dataset-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
+ return SearchDatasetStepNodeSerializer
+
+ def _run(self):
+ question = self.workflow_manage.get_reference_field(
+ self.node_params_serializer.data.get('question_reference_address')[0],
+ self.node_params_serializer.data.get('question_reference_address')[1:])
+ return self.execute(**self.node_params_serializer.data, question=str(question), exclude_paragraph_id_list=[])
+
+ def execute(self, dataset_id_list, dataset_setting, question,
+ exclude_paragraph_id_list=None,
+ **kwargs) -> NodeResult:
+ pass
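
A sketch of a valid node_data payload for this serializer (run with the project settings loaded); the dataset id is a random placeholder and question_reference_address points at the start node's question:

    import uuid

    from application.flow.step_node.search_dataset_node.i_search_dataset_node import SearchDatasetStepNodeSerializer

    params = {
        'dataset_id_list': [str(uuid.uuid1())],
        'dataset_setting': {'top_n': 3, 'similarity': 0.6,
                            'search_mode': 'embedding', 'max_paragraph_char_number': 5000},
        'question_reference_address': ['start-node', 'question'],
    }
    SearchDatasetStepNodeSerializer(data=params).is_valid(raise_exception=True)
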
diff --git a/apps/application/flow/step_node/search_dataset_node/impl/__init__.py b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py
new file mode 100644
index 000000000..a9cff0d09
--- /dev/null
+++ b/apps/application/flow/step_node/search_dataset_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:35
+ @desc:
+"""
+from .base_search_dataset_node import BaseSearchDatasetNode
diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py
new file mode 100644
index 000000000..20e0af9fc
--- /dev/null
+++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py
@@ -0,0 +1,93 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_search_dataset_node.py
+ @date:2024/6/4 11:56
+ @desc:
+"""
+import os
+from typing import List, Dict
+
+from django.db.models import QuerySet
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
+from common.config.embedding_config import EmbeddingModel, VectorStore
+from common.db.search import native_search
+from common.util.file_util import get_file_content
+from dataset.models import Document, Paragraph
+from embedding.models import SearchMode
+from smartdoc.conf import PROJECT_DIR
+
+
+class BaseSearchDatasetNode(ISearchDatasetStepNode):
+ def execute(self, dataset_id_list, dataset_setting, question,
+ exclude_paragraph_id_list=None,
+ **kwargs) -> NodeResult:
+ self.context['question'] = question
+ embedding_model = EmbeddingModel.get_embedding_model()
+ embedding_value = embedding_model.embed_query(question)
+ vector = VectorStore.get_embedding_vector()
+ exclude_document_id_list = [str(document.id) for document in
+ QuerySet(Document).filter(
+ dataset_id__in=dataset_id_list,
+ is_active=False)]
+ embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
+ exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
+ dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
+ if embedding_list is None:
+            return NodeResult({'paragraph_list': [], 'is_hit_handling_method_list': [],
+                               'data': '', 'directly_return': '', 'question': question}, {})
+ paragraph_list = self.list_paragraph(embedding_list, vector)
+ result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
+ return NodeResult({'paragraph_list': result,
+ 'is_hit_handling_method_list': [row for row in result if row.get('is_hit_handling_method')],
+ 'data': '\n'.join([paragraph.get('content') for paragraph in paragraph_list]),
+ 'directly_return': '\n'.join([paragraph.get('content') for paragraph in result if
+ paragraph.get('is_hit_handling_method')]),
+                           'question': question}, {})
+
+ @staticmethod
+ def reset_paragraph(paragraph: Dict, embedding_list: List):
+ filter_embedding_list = [embedding for embedding in embedding_list if
+ str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
+ if filter_embedding_list is not None and len(filter_embedding_list) > 0:
+ find_embedding = filter_embedding_list[-1]
+ return {
+ **paragraph,
+ 'similarity': find_embedding.get('similarity'),
+ 'is_hit_handling_method': find_embedding.get('similarity') > paragraph.get(
+ 'directly_return_similarity') and paragraph.get('hit_handling_method') == 'directly_return'
+ }
+
+ @staticmethod
+ def list_paragraph(embedding_list: List, vector):
+ paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
+ if paragraph_id_list is None or len(paragraph_id_list) == 0:
+ return []
+ paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
+ get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "application", 'sql',
+ 'list_dataset_paragraph_by_paragraph_id.sql')),
+ with_table_name=True)
+ # 如果向量库中存在脏数据 直接删除
+ if len(paragraph_list) != len(paragraph_id_list):
+ exist_paragraph_list = [row.get('id') for row in paragraph_list]
+ for paragraph_id in paragraph_id_list:
+                if paragraph_id not in exist_paragraph_list:
+ vector.delete_by_paragraph_id(paragraph_id)
+ return paragraph_list
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ 'question': self.context.get('question'),
+ "index": index,
+ 'run_time': self.context.get('run_time'),
+ 'paragraph_list': self.context.get('paragraph_list'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
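
reset_paragraph only flags a paragraph for direct return when its hit_handling_method is 'directly_return' and the retrieved similarity beats its threshold; a sketch with dict stand-ins (run in a Django shell so the module imports resolve):

    from application.flow.step_node.search_dataset_node.impl.base_search_dataset_node import BaseSearchDatasetNode

    hit = BaseSearchDatasetNode.reset_paragraph(
        {'id': 'p1', 'content': '...', 'directly_return_similarity': 0.9,
         'hit_handling_method': 'directly_return'},
        [{'paragraph_id': 'p1', 'similarity': 0.95}])
    assert hit['is_hit_handling_method'] is True
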
diff --git a/apps/application/flow/step_node/start_node/__init__.py b/apps/application/flow/step_node/start_node/__init__.py
new file mode 100644
index 000000000..98a1afcd9
--- /dev/null
+++ b/apps/application/flow/step_node/start_node/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:30
+ @desc:
+"""
+from .impl import *
diff --git a/apps/application/flow/step_node/start_node/i_start_node.py b/apps/application/flow/step_node/start_node/i_start_node.py
new file mode 100644
index 000000000..4c1ecfd2a
--- /dev/null
+++ b/apps/application/flow/step_node/start_node/i_start_node.py
@@ -0,0 +1,26 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_start_node.py
+ @date:2024/6/3 16:54
+ @desc:
+"""
+from typing import Type
+
+from rest_framework import serializers
+
+from application.flow.i_step_node import INode, NodeResult
+
+
+class IStarNode(INode):
+ type = 'start-node'
+
+ def get_node_params_serializer_class(self) -> Type[serializers.Serializer] | None:
+ return None
+
+ def _run(self):
+ return self.execute(**self.flow_params_serializer.data)
+
+ def execute(self, question, **kwargs) -> NodeResult:
+ pass
diff --git a/apps/application/flow/step_node/start_node/impl/__init__.py b/apps/application/flow/step_node/start_node/impl/__init__.py
new file mode 100644
index 000000000..b68a92d02
--- /dev/null
+++ b/apps/application/flow/step_node/start_node/impl/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py
+ @date:2024/6/11 15:36
+ @desc:
+"""
+from .base_start_node import BaseStartStepNode
diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py
new file mode 100644
index 000000000..7043e42eb
--- /dev/null
+++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py
@@ -0,0 +1,33 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_start_node.py
+ @date:2024/6/3 17:17
+ @desc:
+"""
+import time
+from datetime import datetime
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.start_node.i_start_node import IStarNode
+
+
+class BaseStartStepNode(IStarNode):
+ def execute(self, question, **kwargs) -> NodeResult:
+ """
+ 开始节点 初始化全局变量
+ """
+ return NodeResult({'question': question},
+ {'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time()})
+
+ def get_details(self, index: int, **kwargs):
+ return {
+ 'name': self.node.properties.get('stepName'),
+ "index": index,
+ "question": self.context.get('question'),
+ 'run_time': self.context.get('run_time'),
+ 'type': self.node.type,
+ 'status': self.status,
+ 'err_message': self.err_message
+ }
diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py
new file mode 100644
index 000000000..839aae8da
--- /dev/null
+++ b/apps/application/flow/tools.py
@@ -0,0 +1,87 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: tools.py
+ @date:2024/6/6 15:15
+ @desc:
+"""
+import json
+from typing import Iterator
+
+from django.http import StreamingHttpResponse
+from langchain_core.messages import BaseMessageChunk, BaseMessage
+
+from application.flow.i_step_node import WorkFlowPostHandler
+from common.response import result
+
+
+def event_content(chat_id, chat_record_id, response, workflow,
+ write_context,
+ post_handler: WorkFlowPostHandler):
+ """
+ 用于处理流式输出
+ @param chat_id: 会话id
+ @param chat_record_id: 对话记录id
+ @param response: 响应数据
+ @param workflow: 工作流管理器
+ @param write_context 写入节点上下文
+ @param post_handler: 后置处理器
+ """
+ answer = ''
+ try:
+ for chunk in response:
+ answer += chunk.content
+ yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': chunk.content, 'is_end': False}, ensure_ascii=False) + "\n\n"
+ write_context(answer, 200)
+ post_handler.handler(chat_id, chat_record_id, answer, workflow)
+ yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': '', 'is_end': True}, ensure_ascii=False) + "\n\n"
+ except Exception as e:
+ answer = str(e)
+ write_context(answer, 500)
+ post_handler.handler(chat_id, chat_record_id, answer, workflow)
+ yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': answer, 'is_end': True}, ensure_ascii=False) + "\n\n"
+
+
+def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
+ post_handler):
+ """
+ 将结果转换为服务流输出
+ @param chat_id: 会话id
+ @param chat_record_id: 对话记录id
+ @param response: 响应数据
+ @param workflow: 工作流管理器
+ @param write_context 写入节点上下文
+ @param post_handler: 后置处理器
+ @return: 响应
+ """
+ r = StreamingHttpResponse(
+ streaming_content=event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler),
+ content_type='text/event-stream;charset=utf-8',
+ charset='utf-8')
+
+ r['Cache-Control'] = 'no-cache'
+ return r
+
+
+def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_context,
+ post_handler: WorkFlowPostHandler):
+ """
+ 将结果转换为服务输出
+
+ @param chat_id: 会话id
+ @param chat_record_id: 对话记录id
+ @param response: 响应数据
+ @param workflow: 工作流管理器
+ @param write_context 写入节点上下文
+ @param post_handler: 后置处理器
+ @return: 响应
+ """
+ answer = response.content
+ write_context(answer)
+ post_handler.handler(chat_id, chat_record_id, answer, workflow)
+ return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': answer, 'is_end': True})
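
A sketch of the SSE frames event_content yields, using stub chunks and a no-op post handler (run with the project settings loaded); NoopPostHandler is a duck-typed stand-in for WorkFlowPostHandler, and the workflow argument is skipped here:

    from types import SimpleNamespace

    from application.flow import tools

    def write_context(answer, status=200):
        print('final answer:', answer, 'status:', status)

    class NoopPostHandler:
        def handler(self, chat_id, chat_record_id, answer, workflow):
            pass  # persistence is skipped in this sketch

    chunks = iter([SimpleNamespace(content=c) for c in ('你', '好')])
    for frame in tools.event_content('chat-1', 'record-1', chunks, None, write_context, NoopPostHandler()):
        print(frame, end='')  # one "data: {...}\n\n" per chunk, then a closing frame with is_end true
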
diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py
new file mode 100644
index 000000000..f99d9351b
--- /dev/null
+++ b/apps/application/flow/workflow_manage.py
@@ -0,0 +1,282 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: workflow_manage.py
+ @date:2024/1/9 17:40
+ @desc:
+"""
+from functools import reduce
+from typing import List, Dict
+
+from langchain_core.messages import AIMessageChunk, AIMessage
+from langchain_core.prompts import PromptTemplate
+
+from application.flow import tools
+from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
+from application.flow.step_node import get_node
+from common.exception.app_exception import AppApiException
+
+
+class Edge:
+ def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
+ self.id = _id
+ self.type = _type
+ self.sourceNodeId = sourceNodeId
+ self.targetNodeId = targetNodeId
+ for keyword in keywords:
+ self.__setattr__(keyword, keywords.get(keyword))
+
+
+class Node:
+ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
+ self.id = _id
+ self.type = _type
+ self.x = x
+ self.y = y
+ self.properties = properties
+ for keyword in kwargs:
+ self.__setattr__(keyword, kwargs.get(keyword))
+
+
+end_nodes = ['ai-chat-node', 'reply-node']
+
+
+class Flow:
+ def __init__(self, nodes: List[Node], edges: List[Edge]):
+ self.nodes = nodes
+ self.edges = edges
+
+ @staticmethod
+ def new_instance(flow_obj: Dict):
+ nodes = flow_obj.get('nodes')
+ edges = flow_obj.get('edges')
+ nodes = [Node(node.get('id'), node.get('type'), **node)
+ for node in nodes]
+ edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges]
+ return Flow(nodes, edges)
+
+ def get_start_node(self):
+ start_node_list = [node for node in self.nodes if node.id == 'start-node']
+ return start_node_list[0]
+
+ def is_valid(self):
+ """
+ 校验工作流数据
+ """
+ self.is_valid_start_node()
+ self.is_valid_base_node()
+ self.is_valid_work_flow()
+
+ @staticmethod
+ def is_valid_node_params(node: Node):
+ get_node(node.type)(node, None, None)
+
+ def is_valid_node(self, node: Node):
+ self.is_valid_node_params(node)
+ if node.type == 'condition-node':
+ branch_list = node.properties.get('node_data').get('branch')
+ for branch in branch_list:
+ source_anchor_id = f"{node.id}_{branch.get('id')}_right"
+ edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id]
+ if len(edge_list) == 0:
+ raise AppApiException(500,
+ f'{node.properties.get("stepName")} 节点的{branch.get("type")}分支需要连接')
+ elif len(edge_list) > 1:
+ raise AppApiException(500,
+                                          f'{node.properties.get("stepName")} 节点的{branch.get("type")}分支不能连接两个节点')
+
+ else:
+ edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
+            if len(edge_list) == 0 and node.type not in end_nodes:
+ raise AppApiException(500, f'{node.properties.get("stepName")} 节点不能当做结束节点')
+ elif len(edge_list) > 1:
+ raise AppApiException(500,
+                                      f'{node.properties.get("stepName")} 节点不能连接两个节点')
+
+ def get_next_nodes(self, node: Node):
+ edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
+ node_list = reduce(lambda x, y: [*x, *y],
+ [[node for node in self.nodes if node.id == edge.targetNodeId] for edge in edge_list],
+ [])
+        if len(node_list) == 0 and node.type not in end_nodes:
+            raise AppApiException(500, '不存在的下一个节点')
+ return node_list
+
+ def is_valid_work_flow(self, up_node=None):
+ if up_node is None:
+ up_node = self.get_start_node()
+ self.is_valid_node(up_node)
+ next_nodes = self.get_next_nodes(up_node)
+ for next_node in next_nodes:
+ self.is_valid_work_flow(next_node)
+
+ def is_valid_start_node(self):
+ start_node_list = [node for node in self.nodes if node.id == 'start-node']
+ if len(start_node_list) == 0:
+ raise AppApiException(500, '开始节点必填')
+ if len(start_node_list) > 1:
+ raise AppApiException(500, '开始节点只能有一个')
+
+ def is_valid_base_node(self):
+ base_node_list = [node for node in self.nodes if node.id == 'base-node']
+ if len(base_node_list) == 0:
+ raise AppApiException(500, '基本信息节点必填')
+ if len(base_node_list) > 1:
+ raise AppApiException(500, '基本信息节点只能有一个')
+
+
+class WorkflowManage:
+ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler):
+ self.params = params
+ self.flow = flow
+ self.context = {}
+ self.node_context = []
+ self.work_flow_post_handler = work_flow_post_handler
+ self.current_node = None
+ self.current_result = None
+
+ def run(self):
+ """
+ 运行工作流
+ """
+ try:
+ while self.has_next_node(self.current_result):
+ self.current_node = self.get_next_node()
+ self.node_context.append(self.current_node)
+ self.current_result = self.current_node.run()
+ if self.has_next_node(self.current_result):
+ self.current_result.write_context(self.current_node, self)
+ else:
+ r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
+ self.current_node, self,
+ self.work_flow_post_handler)
+ return r
+ except Exception as e:
+ if self.params.get('stream'):
+ return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'],
+ iter([AIMessageChunk(str(e))]), self,
+ self.current_node.get_write_error_context(e),
+ self.work_flow_post_handler)
+ else:
+ return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
+ AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
+ self.work_flow_post_handler)
+
+ def has_next_node(self, node_result: NodeResult | None):
+ """
+ 是否有下一个可运行的节点
+ """
+ if self.current_node is None:
+ if self.get_start_node() is not None:
+ return True
+ else:
+ if node_result is not None and node_result.is_assertion_result():
+ for edge in self.flow.edges:
+ if (edge.sourceNodeId == self.current_node.id and
+ f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
+ return True
+ else:
+ for edge in self.flow.edges:
+ if edge.sourceNodeId == self.current_node.id:
+ return True
+ return False
+
+ def get_runtime_details(self):
+ details_result = {}
+ for index in range(len(self.node_context)):
+ node = self.node_context[index]
+ details = node.get_details(index)
+ details_result[node.id] = details
+ return details_result
+
+ def get_next_node(self):
+ """
+ 获取下一个可运行的所有节点
+ """
+ if self.current_node is None:
+ node = self.get_start_node()
+ node_instance = get_node(node.type)(node, self.params, self.context)
+ return node_instance
+ if self.current_result is not None and self.current_result.is_assertion_result():
+ for edge in self.flow.edges:
+ if (edge.sourceNodeId == self.current_node.id and
+ f"{edge.sourceNodeId}_{self.current_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
+ return self.get_node_cls_by_id(edge.targetNodeId)
+ else:
+ for edge in self.flow.edges:
+ if edge.sourceNodeId == self.current_node.id:
+ return self.get_node_cls_by_id(edge.targetNodeId)
+
+ return None
+
+ def get_reference_field(self, node_id: str, fields: List[str]):
+ """
+
+ @param node_id: 节点id
+ @param fields: 字段
+ @return:
+ """
+ if node_id == 'global':
+ return INode.get_field(self.context, fields)
+ else:
+ return self.get_node_by_id(node_id).get_reference_field(fields)
+
+ def generate_prompt(self, prompt: str):
+ """
+ 格式化生成提示词
+ @param prompt: 提示词信息
+ @return: 格式化后的提示词
+ """
+ context = {
+ 'global': self.context,
+ }
+
+ for node in self.node_context:
+ properties = node.node.properties
+ node_config = properties.get('config')
+ if node_config is not None:
+ fields = node_config.get('fields')
+ if fields is not None:
+ for field in fields:
+ globeLabel = f"{properties.get('stepName')}.{field.get('value')}"
+ globeValue = f"context['{node.id}'].{field.get('value')}"
+ prompt = prompt.replace(globeLabel, globeValue)
+ global_fields = node_config.get('globalFields')
+ if global_fields is not None:
+ for field in global_fields:
+ globeLabel = f"全局变量.{field.get('value')}"
+ globeValue = f"context['global'].{field.get('value')}"
+ prompt = prompt.replace(globeLabel, globeValue)
+ context[node.id] = node.context
+ prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
+
+ value = prompt_template.format(context=context)
+ return value
+
+ def get_start_node(self):
+ """
+ 获取启动节点
+ @return:
+ """
+ start_node_list = [node for node in self.flow.nodes if node.type == 'start-node']
+ return start_node_list[0]
+
+ def get_node_cls_by_id(self, node_id):
+ for node in self.flow.nodes:
+ if node.id == node_id:
+ node_instance = get_node(node.type)(node,
+ self.params, self)
+ return node_instance
+ return None
+
+ def get_node_by_id(self, node_id):
+ for node in self.node_context:
+ if node.id == node_id:
+ return node
+ return None
+
+ def get_node_reference(self, reference_address: Dict):
+ node = self.get_node_by_id(reference_address.get('node_id'))
+ return node.context[reference_address.get('node_field')]
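
Flow.new_instance simply wraps the LogicFlow graph dict into Node/Edge objects; a minimal construction sketch run in a Django shell (the single-node graph below would not pass is_valid, it only shows the wrapping):

    from application.flow.workflow_manage import Flow

    flow = Flow.new_instance({
        'nodes': [{'id': 'start-node', 'type': 'start-node', 'x': 0, 'y': 0, 'properties': {}}],
        'edges': [],
    })
    print(flow.get_start_node().type)  # 'start-node'
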
diff --git a/apps/application/migrations/0009_application_type_application_work_flow_and_more.py b/apps/application/migrations/0009_application_type_application_work_flow_and_more.py
new file mode 100644
index 000000000..5d0bf0c9f
--- /dev/null
+++ b/apps/application/migrations/0009_application_type_application_work_flow_and_more.py
@@ -0,0 +1,38 @@
+# Generated by Django 4.1.13 on 2024-06-25 16:30
+
+from django.db import migrations, models
+import django.db.models.deletion
+import uuid
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('application', '0008_chat_is_deleted'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='application',
+ name='type',
+ field=models.CharField(choices=[('SIMPLE', '简易'), ('WORK_FLOW', '工作流')], default='SIMPLE', max_length=256, verbose_name='应用类型'),
+ ),
+ migrations.AddField(
+ model_name='application',
+ name='work_flow',
+ field=models.JSONField(default=dict, verbose_name='工作流数据'),
+ ),
+ migrations.CreateModel(
+ name='WorkFlowVersion',
+ fields=[
+ ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
+ ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
+ ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
+ ('work_flow', models.JSONField(default=dict, verbose_name='工作流数据')),
+ ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')),
+ ],
+ options={
+ 'db_table': 'application_work_flow_version',
+ },
+ ),
+ ]
diff --git a/apps/application/models/application.py b/apps/application/models/application.py
index 073e980c7..bdd0a672e 100644
--- a/apps/application/models/application.py
+++ b/apps/application/models/application.py
@@ -6,6 +6,8 @@
@date:2023/9/25 14:24
@desc:
"""
+import datetime
+import json
import uuid
from django.contrib.postgres.fields import ArrayField
@@ -18,6 +20,12 @@ from setting.models.model_management import Model
from users.models import User
+class ApplicationTypeChoices(models.TextChoices):
+ """订单类型"""
+ SIMPLE = 'SIMPLE', '简易'
+ WORK_FLOW = 'WORK_FLOW', '工作流'
+
+
def get_dataset_setting_dict():
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding',
'no_references_setting': {
@@ -42,6 +50,9 @@ class Application(AppModelMixin):
model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
icon = models.CharField(max_length=256, verbose_name="应用icon", default="/ui/favicon.ico")
+ work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
+ type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices.choices,
+ default=ApplicationTypeChoices.SIMPLE, max_length=256)
@staticmethod
def get_default_model_prompt():
@@ -61,6 +72,15 @@ class Application(AppModelMixin):
db_table = "application"
+class WorkFlowVersion(AppModelMixin):
+ id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
+ application = models.ForeignKey(Application, on_delete=models.CASCADE)
+ work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
+
+ class Meta:
+ db_table = "application_work_flow_version"
+
+
class ApplicationDatasetMapping(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
application = models.ForeignKey(Application, on_delete=models.CASCADE)
@@ -88,6 +108,16 @@ class VoteChoices(models.TextChoices):
TRAMPLE = 1, '反对'
+class DateEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, uuid.UUID):
+ return str(obj)
+ if isinstance(obj, datetime.datetime):
+ return obj.strftime("%Y-%m-%d %H:%M:%S")
+ else:
+ return json.JSONEncoder.default(self, obj)
+
+
class ChatRecord(AppModelMixin):
"""
对话日志 详情
@@ -101,7 +131,7 @@ class ChatRecord(AppModelMixin):
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
const = models.IntegerField(verbose_name="总费用", default=0)
- details = models.JSONField(verbose_name="对话详情", default=dict)
+ details = models.JSONField(verbose_name="对话详情", default=dict, encoder=DateEncoder)
improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表",
base_field=models.UUIDField(max_length=128, blank=True)
, default=list)
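
DateEncoder is presumably here so that workflow runtime details, which carry UUIDs and datetimes, serialize cleanly into the details JSONField; a quick behavioural sketch run in a Django shell:

    import datetime
    import json
    import uuid

    from application.models.application import DateEncoder

    payload = {'chat_record_id': uuid.uuid1(),
               'time': datetime.datetime(2024, 6, 25, 16, 30)}
    print(json.dumps(payload, cls=DateEncoder, ensure_ascii=False))
    # {"chat_record_id": "<uuid string>", "time": "2024-06-25 16:30:00"}
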
diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py
index 47ecb4446..40e447d48 100644
--- a/apps/application/serializers/application_serializers.py
+++ b/apps/application/serializers/application_serializers.py
@@ -7,6 +7,7 @@
@desc:
"""
import hashlib
+import json
import os
import re
import uuid
@@ -22,7 +23,8 @@ from django.http import HttpResponse
from django.template import Template, Context
from rest_framework import serializers
-from application.models import Application, ApplicationDatasetMapping
+from application.flow.workflow_manage import Flow
+from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.constants.authentication_type import AuthenticationType
@@ -105,6 +107,47 @@ class ModelSettingSerializer(serializers.Serializer):
prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
+class ApplicationWorkflowSerializer(serializers.Serializer):
+ name = serializers.CharField(required=True, max_length=64, min_length=1, error_messages=ErrMessage.char("应用名称"))
+ desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
+ max_length=256, min_length=1,
+ error_messages=ErrMessage.char("应用描述"))
+ prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096,
+ error_messages=ErrMessage.char("开场白"))
+
+ @staticmethod
+ def to_application_model(user_id: str, application: Dict):
+
+ default_workflow_json = get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "application", 'flow', 'default_workflow.json'))
+ default_workflow = json.loads(default_workflow_json)
+ for node in default_workflow.get('nodes'):
+ if node.get('id') == 'base-node':
+ node.get('properties')['node_data'] = {"desc": application.get('desc'),
+ "name": application.get('name'),
+ "prologue": application.get('prologue')}
+ return Application(id=uuid.uuid1(),
+ name=application.get('name'),
+ desc=application.get('desc'),
+ prologue="",
+ dialogue_number=0,
+ user_id=user_id, model_id=None,
+ dataset_setting={},
+ model_setting={},
+ problem_optimization=False,
+ type=ApplicationTypeChoices.WORK_FLOW,
+ work_flow=default_workflow
+ )
+
+
+def get_base_node_work_flow(work_flow):
+ node_list = work_flow.get('nodes')
+ base_node_list = [node for node in node_list if node.get('id') == 'base-node']
+ if len(base_node_list) > 0:
+ return base_node_list[-1]
+ return None
+
+
class ApplicationSerializer(serializers.Serializer):
name = serializers.CharField(required=True, max_length=64, min_length=1, error_messages=ErrMessage.char("应用名称"))
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
@@ -123,6 +166,13 @@ class ApplicationSerializer(serializers.Serializer):
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
+ # 应用类型
+ type = serializers.CharField(required=True, error_messages=ErrMessage.char("应用类型"),
+ validators=[
+                                   validators.RegexValidator(regex=re.compile("^(SIMPLE|WORK_FLOW)$"),
+ message="应用类型只支持SIMPLE|WORK_FLOW", code=500)
+ ]
+ )
def is_valid(self, *, user_id=None, raise_exception=False):
super().is_valid(raise_exception=True)
@@ -281,6 +331,24 @@ class ApplicationSerializer(serializers.Serializer):
@transaction.atomic
def insert(self, application: Dict):
+ application_type = application.get('type')
+ if 'WORK_FLOW' == application_type:
+ return self.insert_workflow(application)
+ else:
+ return self.insert_simple(application)
+
+ def insert_workflow(self, application: Dict):
+ self.is_valid(raise_exception=True)
+ user_id = self.data.get('user_id')
+ ApplicationWorkflowSerializer(data=application).is_valid(raise_exception=True)
+ application_model = ApplicationWorkflowSerializer.to_application_model(user_id, application)
+ application_model.save()
+ # 插入认证信息
+ ApplicationAccessToken(application_id=application_model.id,
+ access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save()
+ return ApplicationSerializerModel(application_model).data
+
+ def insert_simple(self, application: Dict):
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
ApplicationSerializer(data=application).is_valid(user_id=user_id, raise_exception=True)
@@ -296,7 +364,7 @@ class ApplicationSerializer(serializers.Serializer):
access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save()
# 插入关联数据
QuerySet(ApplicationDatasetMapping).bulk_create(application_dataset_mapping_model_list)
- return True
+ return ApplicationSerializerModel(application_model).data
@staticmethod
def to_application_model(user_id: str, application: Dict):
@@ -306,7 +374,9 @@ class ApplicationSerializer(serializers.Serializer):
user_id=user_id, model_id=application.get('model_id'),
dataset_setting=application.get('dataset_setting'),
model_setting=application.get('model_setting'),
- problem_optimization=application.get('problem_optimization')
+ problem_optimization=application.get('problem_optimization'),
+ type=ApplicationTypeChoices.SIMPLE,
+ work_flow={}
)
@staticmethod
@@ -420,7 +490,7 @@ class ApplicationSerializer(serializers.Serializer):
class ApplicationModel(serializers.ModelSerializer):
class Meta:
model = Application
- fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number', 'icon']
+ fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number', 'icon', 'type']
class IconOperate(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
@@ -463,6 +533,27 @@ class ApplicationSerializer(serializers.Serializer):
QuerySet(Application).filter(id=self.data.get('application_id')).delete()
return True
+ def publish(self, instance, with_valid=True):
+ if with_valid:
+ self.is_valid()
+ application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
+ work_flow = instance.get('work_flow')
+ if work_flow is None:
+ raise AppApiException(500, "work_flow是必填字段")
+ Flow.new_instance(work_flow).is_valid()
+ base_node = get_base_node_work_flow(work_flow)
+ if base_node is not None:
+ node_data = base_node.get('properties').get('node_data')
+ if node_data is not None:
+ application.name = node_data.get('name')
+ application.desc = node_data.get('desc')
+ application.prologue = node_data.get('prologue')
+ application.work_flow = work_flow
+ application.save()
+ work_flow_version = WorkFlowVersion(work_flow=work_flow, application=application)
+ work_flow_version.save()
+ return True
+
def one(self, with_valid=True):
if with_valid:
self.is_valid()
@@ -507,7 +598,7 @@ class ApplicationSerializer(serializers.Serializer):
raise AppApiException(500, "模型不存在")
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'dataset_setting', 'model_setting', 'problem_optimization',
- 'api_key_is_active', 'icon']
+ 'api_key_is_active', 'icon', 'work_flow']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'multiple_rounds_dialogue':
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index cc658f71d..41f19bc0d 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -7,6 +7,7 @@
@desc:
"""
import json
+import uuid
from typing import List
from uuid import UUID
@@ -22,7 +23,10 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera
BaseGenerateHumanMessageStep
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
-from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
+from application.flow.i_step_node import WorkFlowPostHandler
+from application.flow.workflow_manage import WorkflowManage, Flow
+from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping, ApplicationTypeChoices, \
+ WorkFlowVersion
from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
@@ -39,10 +43,11 @@ chat_cache = caches['model_cache']
class ChatInfo:
def __init__(self,
chat_id: str,
- chat_model: BaseChatModel,
+ chat_model: BaseChatModel | None,
dataset_id_list: List[str],
exclude_document_id_list: list[str],
- application: Application):
+ application: Application,
+ work_flow_version: WorkFlowVersion = None):
"""
:param chat_id: 对话id
:param chat_model: 对话模型
@@ -56,6 +61,7 @@ class ChatInfo:
self.dataset_id_list = dataset_id_list
self.exclude_document_id_list = exclude_document_id_list
self.chat_record_list: List[ChatRecord] = []
+ self.work_flow_version = work_flow_version
def to_base_pipeline_manage_params(self):
dataset_setting = self.application.dataset_setting
@@ -146,8 +152,10 @@ class ChatMessageSerializer(serializers.Serializer):
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
- def is_valid(self, *, raise_exception=False):
- super().is_valid(raise_exception=True)
+ def is_valid_application_workflow(self, *, raise_exception=False):
+ self.is_valid_intraday_access_num()
+
+ def is_valid_intraday_access_num(self):
if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
if access_client is None:
@@ -161,12 +169,9 @@ class ChatMessageSerializer(serializers.Serializer):
application_id=self.data.get('application_id')).first()
if application_access_token.access_num <= access_client.intraday_access_num:
raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")
- chat_id = self.data.get('chat_id')
- chat_info: ChatInfo = chat_cache.get(chat_id)
- if chat_info is None:
- chat_info = self.re_open_chat(chat_id)
- chat_cache.set(chat_id,
- chat_info, timeout=60 * 30)
+
+ def is_valid_application_simple(self, *, chat_info: ChatInfo, raise_exception=False):
+ self.is_valid_intraday_access_num()
model = chat_info.application.model
if model is None:
return chat_info
@@ -179,8 +184,7 @@ class ChatMessageSerializer(serializers.Serializer):
raise AppApiException(500, "模型正在下载中,请稍后再发起对话")
return chat_info
- def chat(self):
- chat_info = self.is_valid(raise_exception=True)
+ def chat_simple(self, chat_info: ChatInfo):
message = self.data.get('message')
re_chat = self.data.get('re_chat')
stream = self.data.get('stream')
@@ -211,14 +215,54 @@ class ChatMessageSerializer(serializers.Serializer):
pipeline_message.run(params)
return pipeline_message.context['chat_result']
- @staticmethod
- def re_open_chat(chat_id: str):
+ def chat_work_flow(self, chat_info: ChatInfo):
+ message = self.data.get('message')
+ re_chat = self.data.get('re_chat')
+ stream = self.data.get('stream')
+ client_id = self.data.get('client_id')
+ client_type = self.data.get('client_type')
+ work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
+ {'history_chat_record': chat_info.chat_record_list, 'question': message,
+ 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
+ 'stream': stream,
+ 're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type))
+ r = work_flow_manage.run()
+ return r
+
+ def chat(self):
+ super().is_valid(raise_exception=True)
+ chat_info = self.get_chat_info()
+ if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
+            self.is_valid_application_simple(raise_exception=True, chat_info=chat_info)
+ return self.chat_simple(chat_info)
+ else:
+ self.is_valid_application_workflow(raise_exception=True)
+ return self.chat_work_flow(chat_info)
+
+ def get_chat_info(self):
+ self.is_valid(raise_exception=True)
+ chat_id = self.data.get('chat_id')
+ chat_info: ChatInfo = chat_cache.get(chat_id)
+ if chat_info is None:
+ chat_info: ChatInfo = self.re_open_chat(chat_id)
+ chat_cache.set(chat_id,
+ chat_info, timeout=60 * 30)
+ return chat_info
+
+ def re_open_chat(self, chat_id: str):
chat = QuerySet(Chat).filter(id=chat_id).first()
if chat is None:
raise AppApiException(500, "会话不存在")
application = QuerySet(Application).filter(id=chat.application_id).first()
if application is None:
raise AppApiException(500, "应用不存在")
+ if application.type == ApplicationTypeChoices.SIMPLE:
+ return self.re_open_chat_simple(chat_id, application)
+ else:
+ return self.re_open_chat_work_flow(chat_id, application)
+
+ @staticmethod
+ def re_open_chat_simple(chat_id, application):
model = QuerySet(Model).filter(id=application.model_id).first()
chat_model = None
if model is not None:
@@ -238,3 +282,11 @@ class ChatMessageSerializer(serializers.Serializer):
dataset_id__in=dataset_id_list,
is_active=False)]
return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application)
+
+ @staticmethod
+ def re_open_chat_work_flow(chat_id, application):
+ work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by(
+ '-create_time')[0:1].first()
+ if work_flow_version is None:
+ raise AppApiException(500, "应用未发布,请发布后再使用")
+ return ChatInfo(chat_id, None, [], [], application, work_flow_version)
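
The chat_message_serializers.py changes separate validation, cache lookup, and routing into distinct steps. Below is a minimal, self-contained sketch of that lookup-then-route flow; it is illustrative only and not part of the patch. The names ApplicationTypeChoices, ChatInfo, get_chat_info, chat_simple and chat_work_flow mirror the diff above, while the dataclass fields and the in-memory cache are stand-ins for the real models and chat_cache.

    # Illustrative sketch of the new chat() flow; names mirror the diff above,
    # everything else is a stand-in.
    from dataclasses import dataclass
    from enum import Enum


    class ApplicationTypeChoices(str, Enum):
        SIMPLE = "SIMPLE"
        WORK_FLOW = "WORK_FLOW"


    @dataclass
    class ChatInfo:
        chat_id: str
        application_type: ApplicationTypeChoices


    chat_cache: dict = {}  # stand-in for the 30-minute model_cache used above


    def get_chat_info(chat_id: str) -> ChatInfo:
        chat_info = chat_cache.get(chat_id)
        if chat_info is None:
            # re_open_chat(): rebuild from the database, choosing the simple or the
            # work-flow variant based on application.type, then repopulate the cache.
            chat_info = ChatInfo(chat_id, ApplicationTypeChoices.WORK_FLOW)
            chat_cache[chat_id] = chat_info
        return chat_info


    def chat(chat_id: str) -> str:
        chat_info = get_chat_info(chat_id)
        if chat_info.application_type == ApplicationTypeChoices.SIMPLE:
            return "chat_simple"      # original single-model pipeline
        return "chat_work_flow"       # WorkflowManage.run() over the latest WorkFlowVersion
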
diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py
index a0d0b7690..402eff691 100644
--- a/apps/application/serializers/chat_serializers.py
+++ b/apps/application/serializers/chat_serializers.py
@@ -22,7 +22,9 @@ from django.db.models import QuerySet, Q
from django.http import HttpResponse
from rest_framework import serializers
-from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
+from application.flow.workflow_manage import Flow
+from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord, WorkFlowVersion, \
+ ApplicationTypeChoices
from application.models.api_key_model import ApplicationAccessToken
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer
@@ -45,6 +47,11 @@ from smartdoc.conf import PROJECT_DIR
chat_cache = caches['model_cache']
+class WorkFlowSerializers(serializers.Serializer):
+ nodes = serializers.ListSerializer(child=serializers.DictField(), error_messages=ErrMessage.uuid("节点"))
+ edges = serializers.ListSerializer(child=serializers.DictField(), error_messages=ErrMessage.uuid("连线"))
+
+
class ChatSerializers(serializers.Serializer):
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
@@ -207,6 +214,27 @@ class ChatSerializers(serializers.Serializer):
self.is_valid(raise_exception=True)
application_id = self.data.get('application_id')
application = QuerySet(Application).get(id=application_id)
+ if application.type == ApplicationTypeChoices.SIMPLE:
+ return self.open_simple(application)
+ else:
+ return self.open_work_flow(application)
+
+ def open_work_flow(self, application):
+ self.is_valid(raise_exception=True)
+ application_id = self.data.get('application_id')
+ chat_id = str(uuid.uuid1())
+ work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application_id).order_by(
+ '-create_time')[0:1].first()
+ if work_flow_version is None:
+ raise AppApiException(500, "应用未发布,请发布后再使用")
+ chat_cache.set(chat_id,
+ ChatInfo(chat_id, None, [],
+ [],
+ application, work_flow_version), timeout=60 * 30)
+ return chat_id
+
+ def open_simple(self, application):
+ application_id = self.data.get('application_id')
model = QuerySet(Model).filter(id=application.model_id).first()
dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter(
@@ -229,6 +257,27 @@ class ChatSerializers(serializers.Serializer):
application), timeout=60 * 30)
return chat_id
+ class OpenWorkFlowChat(serializers.Serializer):
+ work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
+
+ def open(self):
+ self.is_valid(raise_exception=True)
+ work_flow = self.data.get('work_flow')
+ Flow.new_instance(work_flow).is_valid()
+ chat_id = str(uuid.uuid1())
+ application = Application(id=None, dialogue_number=3, model=None,
+ dataset_setting={},
+ model_setting={},
+ problem_optimization=None,
+ type=ApplicationTypeChoices.WORK_FLOW
+ )
+ work_flow_version = WorkFlowVersion(work_flow=work_flow)
+ chat_cache.set(chat_id,
+ ChatInfo(chat_id, None, [],
+ [],
+ application, work_flow_version), timeout=60 * 30)
+ return chat_id
+
class OpenTempChat(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
@@ -329,7 +378,7 @@ class ChatRecordSerializer(serializers.Serializer):
chat_info: ChatInfo = chat_cache.get(chat_id)
if chat_info is not None:
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
- chat_record.id == uuid.UUID(chat_record_id)]
+ str(chat_record.id) == str(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
@@ -377,7 +426,8 @@ class ChatRecordSerializer(serializers.Serializer):
'padding_problem_text': chat_record.details.get('problem_padding').get(
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
'dataset_list': dataset_list,
- 'paragraph_list': paragraph_list
+ 'paragraph_list': paragraph_list,
+ 'execution_details': [chat_record.details[key] for key in chat_record.details]
}
def page(self, current_page: int, page_size: int, with_valid=True):
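
The new OpenWorkFlowChat serializer accepts the in-editor graph directly, so an unsaved flow can be debugged before publishing. A hedged example of the body it expects follows: only the nodes/edges envelope is defined by WorkFlowSerializers above; the individual node and edge fields are assumptions for illustration, not guaranteed by this patch.

    # Example payload for ChatSerializers.OpenWorkFlowChat.open(); only the
    # {"work_flow": {"nodes": [...], "edges": [...]}} envelope is guaranteed by the
    # serializer above; the individual node/edge keys below are illustrative.
    work_flow_body = {
        'work_flow': {
            'nodes': [
                {'id': 'start', 'type': 'start-node', 'properties': {}},       # assumed shape
                {'id': 'chat', 'type': 'ai-chat-node', 'properties': {}},      # assumed shape
            ],
            'edges': [
                {'id': 'e1', 'sourceNodeId': 'start', 'targetNodeId': 'chat'}  # assumed shape
            ],
        }
    }
    # open() validates the graph via Flow.new_instance(work_flow).is_valid(), then caches a
    # ChatInfo built from an unsaved Application(type=WORK_FLOW) and a WorkFlowVersion for 30 minutes.
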
diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py
index 4bacc5831..6e46931a6 100644
--- a/apps/application/swagger_api/application_api.py
+++ b/apps/application/swagger_api/application_api.py
@@ -161,7 +161,25 @@ class ApplicationApi(ApiMixin):
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
description="是否开启问题优化", default=True),
'icon': openapi.Schema(type=openapi.TYPE_STRING, title="icon",
- description="icon", default="/ui/favicon.ico")
+ description="icon", default="/ui/favicon.ico"),
+ 'work_flow': ApplicationApi.WorkFlow.get_request_body_api()
+
+ }
+ )
+
+ class WorkFlow(ApiMixin):
+ @staticmethod
+ def get_request_body_api():
+ return openapi.Schema(
+ type=openapi.TYPE_OBJECT,
+ required=[],
+ properties={
+ 'nodes': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_OBJECT),
+ title="节点列表", description="节点列表",
+ default=[]),
+ 'edges': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_OBJECT),
+ title='连线列表', description="连线列表",
+ default=[]),
}
)
@@ -219,6 +237,17 @@ class ApplicationApi(ApiMixin):
}
)
+ class Publish(ApiMixin):
+ @staticmethod
+ def get_request_body_api():
+ return openapi.Schema(
+ type=openapi.TYPE_OBJECT,
+ required=[],
+ properties={
+ 'work_flow': ApplicationApi.WorkFlow.get_request_body_api()
+ }
+ )
+
class Create(ApiMixin):
@staticmethod
def get_request_body_api():
@@ -239,7 +268,9 @@ class ApplicationApi(ApiMixin):
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
- description="是否开启问题优化", default=True)
+ description="是否开启问题优化", default=True),
+ 'type': openapi.Schema(type=openapi.TYPE_STRING, title="应用类型",
+ description="应用类型 简易:SIMPLE|工作流:WORK_FLOW")
}
)
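
With the new schemas, create/edit requests can carry both the application type and the flow graph. A hedged example limited to the keys visible in the fragments above (the remaining application fields are omitted, and the values are illustrative):

    # Hedged example of a request body using the new schema fields; only keys shown in
    # the fragments above are included.
    application_body = {
        'dataset_setting': {},
        'model_setting': {},
        'problem_optimization': True,
        'icon': '/ui/favicon.ico',
        'type': 'WORK_FLOW',                       # SIMPLE | WORK_FLOW (new field)
        'work_flow': {'nodes': [], 'edges': []},   # ApplicationApi.WorkFlow body, also reused by Publish
    }
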
diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py
index 9c56cd21e..2ff8f8ac5 100644
--- a/apps/application/swagger_api/chat_api.py
+++ b/apps/application/swagger_api/chat_api.py
@@ -82,6 +82,17 @@ class ChatApi(ApiMixin):
]
+ class OpenWorkFlowTemp(ApiMixin):
+ @staticmethod
+ def get_request_body_api():
+ return openapi.Schema(
+ type=openapi.TYPE_OBJECT,
+ required=[],
+ properties={
+ 'work_flow': ApplicationApi.WorkFlow.get_request_body_api()
+ }
+ )
+
class OpenTempChat(ApiMixin):
@staticmethod
def get_request_body_api():
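
Both ChatApi.OpenWorkFlowTemp and ApplicationApi.Publish reuse the shared ApplicationApi.WorkFlow schema, so the graph is documented once. A hedged sketch of how such a body is typically attached to a DRF view with drf-yasg; the view class, summary, and import path here are illustrative, and the real view wiring lives in the views modules touched elsewhere in this patch.

    # Sketch only: attaching the shared work_flow request body to a view with drf-yasg.
    from drf_yasg.utils import swagger_auto_schema
    from rest_framework.views import APIView

    from application.swagger_api.chat_api import ChatApi  # import path assumed from the app layout


    class OpenWorkFlowTemp(APIView):
        @swagger_auto_schema(operation_summary="Open a work-flow debug chat",
                             request_body=ChatApi.OpenWorkFlowTemp.get_request_body_api())
        def post(self, request):
            ...
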
diff --git a/apps/application/urls.py b/apps/application/urls.py
index 4fcbbbf0c..335205d37 100644
--- a/apps/application/urls.py
+++ b/apps/application/urls.py
@@ -8,6 +8,7 @@ urlpatterns = [
path('application/profile', views.Application.Profile.as_view(), name='application/profile'),
path('application/embed', views.Application.Embed.as_view()),
path('application/authentication', views.Application.Authentication.as_view()),
+ path('application/
[The remaining UI hunks of this patch are garbled in this excerpt. Recoverable fragments: Vue templates adding execution-detail labels for the workflow run view (参数输入 "parameter input", 检索内容 "retrieved content", 检索结果 "retrieval results", 判断结果 "judgment result", 角色设定 (System) "role setting", 历史记录 "history", 本次对话 "current dialogue", AI 回答 "AI answer", 回复内容 "reply content", 错误日志 "error log"); a label change between 简单配置 "simple configuration" and 高级编排 "advanced orchestration" in the application form; and whitespace-only reformatting of the $t('views.application.applicationForm.dialogues.*') interpolations (vectorSearch, fullTextSearch, hybridSearch, continueQuestioning, provideAnswer).]