mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
Compare commits
156 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
847755b1c2 | ||
|
|
b57455d0ee | ||
|
|
2a257edff9 | ||
|
|
d47699331c | ||
|
|
90c64d77dd | ||
|
|
e1ada3ffe2 | ||
|
|
b62c79fda6 | ||
|
|
3d9e7dd4b1 | ||
|
|
8ff15865a7 | ||
|
|
48899d55d1 | ||
|
|
1cc4107bfe | ||
|
|
b13cd03706 | ||
|
|
69f024492b | ||
|
|
a9c46cd7e0 | ||
|
|
a9e9f5b085 | ||
|
|
e12b1fe14e | ||
|
|
7ce66a7bf3 | ||
|
|
decd3395db | ||
|
|
9d7a383348 | ||
|
|
8f7d91798b | ||
|
|
81a3af2c8b | ||
|
|
2ec0d22b14 | ||
|
|
27a77dc657 | ||
|
|
187e9c1e4e | ||
|
|
e5bab10824 | ||
|
|
347f4a0b03 | ||
|
|
289ebf42a6 | ||
|
|
71fdce08d7 | ||
|
|
adc5af9cef | ||
|
|
ce2ab322f6 | ||
|
|
a7e31b94c7 | ||
|
|
8498687794 | ||
|
|
190ca3e198 | ||
|
|
c1ddec1a61 | ||
|
|
1ba8077e95 | ||
|
|
9a42bd2302 | ||
|
|
35b662a52d | ||
|
|
a4faf52261 | ||
|
|
a30316d87a | ||
|
|
a4d10cbe3b | ||
|
|
ceccf9f1fa | ||
|
|
3964db20dc | ||
|
|
57ada0708f | ||
|
|
a1a92a833a | ||
|
|
5e0d8048f9 | ||
|
|
8903b35aec | ||
|
|
fa4f7e99fd | ||
|
|
b0630b3ddd | ||
|
|
b2bf69740c | ||
|
|
8cf66b9eca | ||
|
|
1db8577ca6 | ||
|
|
949e4dea9e | ||
|
|
c12988bc8a | ||
|
|
2728453f6c | ||
|
|
ccf43bbcd9 | ||
|
|
8d8de53e38 | ||
|
|
7faf556771 | ||
|
|
76ec8ad6f6 | ||
|
|
0cf05a76a0 | ||
|
|
357edbfbe0 | ||
|
|
0609a9afc8 | ||
|
|
96e59a018f | ||
|
|
79b2de8893 | ||
|
|
0c7cca035e | ||
|
|
00a3e5ddc3 | ||
|
|
a6533c0db7 | ||
|
|
704077d066 | ||
|
|
59ee0c1270 | ||
|
|
b37cc3ba1c | ||
|
|
dfe6d0a91b | ||
|
|
5813eedd4f | ||
|
|
363150380d | ||
|
|
47f9c04664 | ||
|
|
17cd88edda | ||
|
|
d33b620dc8 | ||
|
|
b95cb20704 | ||
|
|
4ff1944b60 | ||
|
|
1c6b0f8a86 | ||
|
|
d85801fe58 | ||
|
|
e79e7d505d | ||
|
|
33b1cd65b0 | ||
|
|
8d503c8bf8 | ||
|
|
df7f922013 | ||
|
|
a0541203e4 | ||
|
|
fb64731cd8 | ||
|
|
d960a18711 | ||
|
|
ee35cc21e9 | ||
|
|
e4a60daa17 | ||
|
|
7bcb770ee5 | ||
|
|
b5b09dc8b4 | ||
|
|
b1f6092620 | ||
|
|
c8441cfd73 | ||
|
|
7e4b147576 | ||
|
|
9cd082089a | ||
|
|
d4541e23f9 | ||
|
|
5e7e91cced | ||
|
|
d9787bb548 | ||
|
|
131b5b3bbe | ||
|
|
e58f95832b | ||
|
|
f24337d5f3 | ||
|
|
d6f1d25b59 | ||
|
|
0ec198fa43 | ||
|
|
77e96624ee | ||
|
|
5e02809db2 | ||
|
|
6fe001fcf8 | ||
|
|
f646102262 | ||
|
|
b97f4e16ba | ||
|
|
0c14306889 | ||
|
|
f1d043f67b | ||
|
|
c39a6e81d7 | ||
|
|
9c56c7e198 | ||
|
|
6484fef8ea | ||
|
|
8a194481ac | ||
|
|
072b817792 | ||
|
|
d32f7d36a6 | ||
|
|
54c9d4e725 | ||
|
|
d2637c3de2 | ||
|
|
2550324003 | ||
|
|
bf52dd8174 | ||
|
|
2ecec57d2f | ||
|
|
b5fda0e020 | ||
|
|
45a60cd9a7 | ||
|
|
2b82675853 | ||
|
|
1d3bf1ca73 | ||
|
|
04cb9c96fe | ||
|
|
45d8ac2eee | ||
|
|
eb60331c88 | ||
|
|
8fc9b0a22d | ||
|
|
ec5fd9e343 | ||
|
|
1e49939c38 | ||
|
|
5a7b23aa00 | ||
|
|
791505b7b8 | ||
|
|
da10649adb | ||
|
|
a025e3960d | ||
|
|
98ed348de9 | ||
|
|
b9dcc36b31 | ||
|
|
7a3d3844ae | ||
|
|
6ea2cf149a | ||
|
|
c30677d8b0 | ||
|
|
a1a2fb5628 | ||
|
|
f9cb0e24d6 | ||
|
|
2dc42183cb | ||
|
|
c781c11d26 | ||
|
|
e178cfe5c0 | ||
|
|
3b24373cd0 | ||
|
|
125ed8aa7a | ||
|
|
0b60a03e5d | ||
|
|
bb3f17ebfe | ||
|
|
1695710cbe | ||
|
|
ebe8506c67 | ||
|
|
3c3bd9884f | ||
|
|
f188383fea | ||
|
|
bbab359813 | ||
|
|
8381ca5287 | ||
|
|
f5282bf1e7 | ||
|
|
c0ffc0aaf5 |
|
|
@ -1,4 +1,2 @@
|
|||
.git*
|
||||
.idea*
|
||||
*.md
|
||||
.venv/
|
||||
|
|
@ -6,4 +6,12 @@ updates:
|
|||
interval: "weekly"
|
||||
timezone: "Asia/Shanghai"
|
||||
day: "friday"
|
||||
target-branch: "v3"
|
||||
target-branch: "v2"
|
||||
groups:
|
||||
python-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
# ignore:
|
||||
# - dependency-name: "pymupdf"
|
||||
# versions: ["*"]
|
||||
|
||||
|
|
|
|||
|
|
@ -14,19 +14,33 @@ on:
|
|||
- linux/amd64,linux/arm64
|
||||
jobs:
|
||||
build-and-push-python-pg-to-ghcr:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check Disk Space
|
||||
run: df -h
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
tool-cache: true
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
docker-images: true
|
||||
swap-storage: true
|
||||
- name: Check Disk Space
|
||||
run: df -h
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
ref: main
|
||||
- name: Prepare
|
||||
id: prepare
|
||||
run: |
|
||||
DOCKER_IMAGE=ghcr.io/1panel-dev/maxkb-base
|
||||
DOCKER_IMAGE=ghcr.io/1panel-dev/maxkb-python-pg
|
||||
DOCKER_PLATFORMS=${{ github.event.inputs.architecture }}
|
||||
TAG_NAME=python3.11-pg17.6
|
||||
DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}"
|
||||
TAG_NAME=python3.11-pg15.8
|
||||
DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest"
|
||||
echo ::set-output name=docker_image::${DOCKER_IMAGE}
|
||||
echo ::set-output name=version::${TAG_NAME}
|
||||
echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --no-cache \
|
||||
|
|
@ -37,7 +51,8 @@ jobs:
|
|||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
with:
|
||||
cache-image: false
|
||||
# Until https://github.com/tonistiigi/binfmt/issues/215
|
||||
image: tonistiigi/binfmt:qemu-v7.0.0-28
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to GitHub Container Registry
|
||||
|
|
@ -48,4 +63,4 @@ jobs:
|
|||
password: ${{ secrets.GH_TOKEN }}
|
||||
- name: Docker Buildx (build-and-push)
|
||||
run: |
|
||||
docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile-base
|
||||
docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile-python-pg
|
||||
|
|
@ -5,7 +5,7 @@ on:
|
|||
inputs:
|
||||
dockerImageTag:
|
||||
description: 'Docker Image Tag'
|
||||
default: 'v2.0.3'
|
||||
default: 'v1.0.1'
|
||||
required: true
|
||||
architecture:
|
||||
description: 'Architecture'
|
||||
|
|
@ -19,12 +19,26 @@ on:
|
|||
|
||||
jobs:
|
||||
build-and-push-vector-model-to-ghcr:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check Disk Space
|
||||
run: df -h
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
tool-cache: true
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
docker-images: true
|
||||
swap-storage: true
|
||||
- name: Check Disk Space
|
||||
run: df -h
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
ref: main
|
||||
- name: Prepare
|
||||
id: prepare
|
||||
run: |
|
||||
|
|
@ -42,7 +56,8 @@ jobs:
|
|||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
with:
|
||||
cache-image: false
|
||||
# Until https://github.com/tonistiigi/binfmt/issues/215
|
||||
image: tonistiigi/binfmt:qemu-v7.0.0-28
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to GitHub Container Registry
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
name: build-and-push
|
||||
|
||||
run-name: 构建镜像并推送仓库 ${{ github.event.inputs.dockerImageTag }} (${{ github.event.inputs.registry }}) (${{ github.event.inputs.architecture }})
|
||||
run-name: 构建镜像并推送仓库 ${{ github.event.inputs.dockerImageTag }} (${{ github.event.inputs.registry }})
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dockerImageTag:
|
||||
description: 'Image Tag'
|
||||
default: 'v2.3.0-dev'
|
||||
default: 'v1.10.7-dev'
|
||||
required: true
|
||||
dockerImageTagWithLatest:
|
||||
description: '是否发布latest tag(正式发版时选择,测试版本切勿选择)'
|
||||
|
|
@ -38,10 +38,20 @@ jobs:
|
|||
if: ${{ contains(github.event.inputs.registry, 'fit2cloud') }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clear Work Dir
|
||||
run: |
|
||||
ls -la
|
||||
rm -rf -- ./* ./.??*
|
||||
- name: Check Disk Space
|
||||
run: df -h
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
tool-cache: true
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
docker-images: true
|
||||
swap-storage: true
|
||||
- name: Check Disk Space
|
||||
run: df -h
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
|
@ -54,17 +64,15 @@ jobs:
|
|||
TAG_NAME=${{ github.event.inputs.dockerImageTag }}
|
||||
TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }}
|
||||
if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then
|
||||
DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*} --tag ${DOCKER_IMAGE}:latest"
|
||||
DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*}"
|
||||
else
|
||||
DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}"
|
||||
fi
|
||||
echo "buildx_args=--platform ${DOCKER_PLATFORMS} --memory-swap -1 \
|
||||
echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \
|
||||
--build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \
|
||||
${DOCKER_IMAGE_TAGS} ." >> $GITHUB_OUTPUT
|
||||
${DOCKER_IMAGE_TAGS} .
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
with:
|
||||
cache-image: false
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to GitHub Container Registry
|
||||
|
|
@ -79,12 +87,6 @@ jobs:
|
|||
registry: ${{ secrets.FIT2CLOUD_REGISTRY_HOST }}
|
||||
username: ${{ secrets.FIT2CLOUD_REGISTRY_USERNAME }}
|
||||
password: ${{ secrets.FIT2CLOUD_REGISTRY_PASSWORD }}
|
||||
- name: Build Web
|
||||
run: |
|
||||
docker buildx build --no-cache --target web-build --output type=local,dest=./web-build-output . -f installer/Dockerfile
|
||||
rm -rf ./ui
|
||||
cp -r ./web-build-output/ui ./
|
||||
rm -rf ./web-build-output
|
||||
- name: Docker Buildx (build-and-push)
|
||||
run: |
|
||||
sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches && free -m
|
||||
|
|
@ -94,10 +96,20 @@ jobs:
|
|||
if: ${{ contains(github.event.inputs.registry, 'dockerhub') }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clear Work Dir
|
||||
run: |
|
||||
ls -la
|
||||
rm -rf -- ./* ./.??*
|
||||
- name: Check Disk Space
|
||||
run: df -h
|
||||
- name: Free Disk Space (Ubuntu)
|
||||
uses: jlumbroso/free-disk-space@main
|
||||
with:
|
||||
tool-cache: true
|
||||
android: true
|
||||
dotnet: true
|
||||
haskell: true
|
||||
large-packages: true
|
||||
docker-images: true
|
||||
swap-storage: true
|
||||
- name: Check Disk Space
|
||||
run: df -h
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
|
@ -110,17 +122,15 @@ jobs:
|
|||
TAG_NAME=${{ github.event.inputs.dockerImageTag }}
|
||||
TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }}
|
||||
if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then
|
||||
DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*} --tag ${DOCKER_IMAGE}:latest"
|
||||
DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*}"
|
||||
else
|
||||
DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}"
|
||||
fi
|
||||
echo "buildx_args=--platform ${DOCKER_PLATFORMS} --memory-swap -1 \
|
||||
echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \
|
||||
--build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \
|
||||
${DOCKER_IMAGE_TAGS} ." >> $GITHUB_OUTPUT
|
||||
${DOCKER_IMAGE_TAGS} .
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
with:
|
||||
cache-image: false
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to GitHub Container Registry
|
||||
|
|
@ -134,12 +144,6 @@ jobs:
|
|||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Build Web
|
||||
run: |
|
||||
docker buildx build --no-cache --target web-build --output type=local,dest=./web-build-output . -f installer/Dockerfile
|
||||
rm -rf ./ui
|
||||
cp -r ./web-build-output/ui ./
|
||||
rm -rf ./web-build-output
|
||||
- name: Docker Buildx (build-and-push)
|
||||
run: |
|
||||
sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches && free -m
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
name: Typos Check
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
|
|
@ -11,19 +12,7 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Actions Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
- name: Create config file
|
||||
run: |
|
||||
cat <<EOF > typo-check-config.toml
|
||||
[files]
|
||||
extend-exclude = [
|
||||
"**/*_svg",
|
||||
"**/migrations/**"
|
||||
]
|
||||
EOF
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Check spelling
|
||||
uses: crate-ci/typos@master
|
||||
with:
|
||||
config: ./typo-check-config.toml
|
||||
|
|
|
|||
|
|
@ -137,9 +137,9 @@ celerybeat.pid
|
|||
# Environments
|
||||
.env
|
||||
.venv
|
||||
# env/
|
||||
env/
|
||||
venv/
|
||||
# ENV/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
|
|
@ -183,10 +183,5 @@ apps/xpack
|
|||
data
|
||||
.dev
|
||||
poetry.lock
|
||||
uv.lock
|
||||
apps/models_provider/impl/*/icon/
|
||||
apps/models_provider/impl/tencent_model_provider/credential/stt.py
|
||||
apps/models_provider/impl/tencent_model_provider/model/stt.py
|
||||
tmp/
|
||||
config.yml
|
||||
.SANDBOX_BANNED_HOSTS
|
||||
apps/setting/models_provider/impl/*/icon/
|
||||
tmp/
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
[files]
|
||||
extend-exclude = [
|
||||
'apps/setting/models_provider/impl/*/icon/*'
|
||||
]
|
||||
|
|
@ -27,4 +27,4 @@ When reporting issues, always include:
|
|||
* Snapshots or log files if needed
|
||||
|
||||
Because the issues are open to the public, when submitting files, be sure to remove any sensitive information, e.g. user name, password, IP address, and company name. You can
|
||||
replace those parts with "REDACTED" or other strings like "****".
|
||||
replace those parts with "REDACTED" or other strings like "****".
|
||||
|
|
|
|||
2
LICENSE
2
LICENSE
|
|
@ -671,4 +671,4 @@ into proprietary programs. If your program is a subroutine library, you
|
|||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
|
|
|
|||
75
README.md
75
README.md
|
|
@ -14,7 +14,7 @@
|
|||
MaxKB = Max Knowledge Brain, it is an open-source platform for building enterprise-grade agents. MaxKB integrates Retrieval-Augmented Generation (RAG) pipelines, supports robust workflows, and provides advanced MCP tool-use capabilities. MaxKB is widely applied in scenarios such as intelligent customer service, corporate internal knowledge bases, academic research, and education.
|
||||
|
||||
- **RAG Pipeline**: Supports direct uploading of documents / automatic crawling of online documents, with features for automatic text splitting, vectorization. This effectively reduces hallucinations in large models, providing a superior smart Q&A interaction experience.
|
||||
- **Agentic Workflow**: Equipped with a powerful workflow engine, function library and MCP tool-use, enabling the orchestration of AI processes to meet the needs of complex business scenarios.
|
||||
- **Agentic Workflow**: Equipped with a powerful workflow engine, function library and MCP tool-use, enabling the orchestration of AI processes to meet the needs of complex business scenarios.
|
||||
- **Seamless Integration**: Facilitates zero-coding rapid integration into third-party business systems, quickly equipping existing systems with intelligent Q&A capabilities to enhance user satisfaction.
|
||||
- **Model-Agnostic**: Supports various large models, including private models (such as DeepSeek, Llama, Qwen, etc.) and public models (like OpenAI, Claude, Gemini, etc.).
|
||||
- **Multi Modal**: Native support for input and output text, image, audio and video.
|
||||
|
|
@ -24,7 +24,7 @@ MaxKB = Max Knowledge Brain, it is an open-source platform for building enterpri
|
|||
Execute the script below to start a MaxKB container using Docker:
|
||||
|
||||
```bash
|
||||
docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/opt/maxkb 1panel/maxkb
|
||||
docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages 1panel/maxkb
|
||||
```
|
||||
|
||||
Access MaxKB web interface at `http://your_server_ip:8080` with default admin credentials:
|
||||
|
|
@ -32,18 +32,18 @@ Access MaxKB web interface at `http://your_server_ip:8080` with default admin cr
|
|||
- username: admin
|
||||
- password: MaxKB@123..
|
||||
|
||||
中国用户如遇到 Docker 镜像 Pull 失败问题,请参照该 [离线安装文档](https://maxkb.cn/docs/v2/installation/offline_installtion/) 进行安装。
|
||||
中国用户如遇到 Docker 镜像 Pull 失败问题,请参照该 [离线安装文档](https://maxkb.cn/docs/installation/offline_installtion/) 进行安装。
|
||||
|
||||
## Screenshots
|
||||
|
||||
<table style="border-collapse: collapse; border: 1px solid black;">
|
||||
<tr>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/eb285512-a66a-4752-8941-c65ed1592238" alt="MaxKB Demo1" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/f732f1f5-472c-4fd2-93c1-a277eda83d04" alt="MaxKB Demo2" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://maxkb.hk/images/overview.png" alt="MaxKB Demo1" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://maxkb.hk/images/screenshot-models.png" alt="MaxKB Demo2" /></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/c927474a-9a23-4830-822f-5db26025c9b2" alt="MaxKB Demo3" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/e6268996-a46d-4e58-9f30-31139df78ad2" alt="MaxKB Demo4" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://maxkb.hk/images/screenshot-knowledge.png" alt="MaxKB Demo3" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://maxkb.hk/images/screenshot-function.png" alt="MaxKB Demo4" /></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
|
@ -54,6 +54,67 @@ Access MaxKB web interface at `http://your_server_ip:8080` with default admin cr
|
|||
- LLM Framework:[LangChain](https://www.langchain.com/)
|
||||
- Database:[PostgreSQL + pgvector](https://www.postgresql.org/)
|
||||
|
||||
## Feature Comparison
|
||||
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<th align="center">Feature</th>
|
||||
<th align="center">LangChain</th>
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">Flowise</th>
|
||||
<th align="center">MaxKB <br>(Built upon LangChain)</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Supported LLMs</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
<td align="center">Rich Variety</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">RAG Engine</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Workflow</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Observability</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">SSO/Access control</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅ (Pro)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">On-premise Deployment</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#1Panel-dev/MaxKB&Date)
|
||||
|
|
|
|||
20
README_CN.md
20
README_CN.md
|
|
@ -14,12 +14,12 @@
|
|||
</p>
|
||||
<hr/>
|
||||
|
||||
MaxKB = Max Knowledge Brain,是一个强大易用的企业级智能体平台,致力于解决企业 AI 落地面临的技术门槛高、部署成本高、迭代周期长等问题,助力企业在人工智能时代赢得先机。秉承“开箱即用,伴随成长”的设计理念,MaxKB 支持企业快速接入主流大模型,高效构建专属知识库,并提供从基础问答(RAG)、复杂流程自动化(工作流)到智能体(Agent)的渐进式升级路径,全面赋能智能客服、智能办公助手等多种应用场景。
|
||||
MaxKB = Max Knowledge Brain,是一款强大易用的企业级智能体平台,支持 RAG 检索增强生成、工作流编排、MCP 工具调用能力。MaxKB 支持对接各种主流大语言模型,广泛应用于智能客服、企业内部知识库问答、员工助手、学术研究与教育等场景。
|
||||
|
||||
- **RAG 检索增强生成**:高效搭建本地 AI 知识库,支持直接上传文档 / 自动爬取在线文档,支持文本自动拆分、向量化,有效减少大模型幻觉,提升问答效果;
|
||||
- **灵活编排**:内置强大的工作流引擎、函数库和 MCP 工具调用能力,支持编排 AI 工作过程,满足复杂业务场景下的需求;
|
||||
- **无缝嵌入**:支持零编码快速嵌入到第三方业务系统,让已有系统快速拥有智能问答能力,提高用户满意度;
|
||||
- **模型中立**:支持对接各种大模型,包括本地私有大模型(DeepSeek R1 / Qwen 3 等)、国内公共大模型(通义千问 / 腾讯混元 / 字节豆包 / 百度千帆 / 智谱 AI / Kimi 等)和国外公共大模型(OpenAI / Claude / Gemini 等)。
|
||||
- **模型中立**:支持对接各种大模型,包括本地私有大模型(DeepSeek R1 / Llama 3 / Qwen 2 等)、国内公共大模型(通义千问 / 腾讯混元 / 字节豆包 / 百度千帆 / 智谱 AI / Kimi 等)和国外公共大模型(OpenAI / Claude / Gemini 等)。
|
||||
|
||||
MaxKB 三分钟视频介绍:https://www.bilibili.com/video/BV18JypYeEkj/
|
||||
|
||||
|
|
@ -27,10 +27,10 @@ MaxKB 三分钟视频介绍:https://www.bilibili.com/video/BV18JypYeEkj/
|
|||
|
||||
```
|
||||
# Linux 机器
|
||||
docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/opt/maxkb registry.fit2cloud.com/maxkb/maxkb
|
||||
docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb
|
||||
|
||||
# Windows 机器
|
||||
docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/opt/maxkb registry.fit2cloud.com/maxkb/maxkb
|
||||
docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/var/lib/postgresql/data -v C:/python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb
|
||||
|
||||
# 用户名: admin
|
||||
# 密码: MaxKB@123..
|
||||
|
|
@ -38,8 +38,8 @@ docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/opt/maxkb
|
|||
|
||||
- 你也可以通过 [1Panel 应用商店](https://apps.fit2cloud.com/1panel) 快速部署 MaxKB;
|
||||
- 如果是内网环境,推荐使用 [离线安装包](https://community.fit2cloud.com/#/products/maxkb/downloads) 进行安装部署;
|
||||
- MaxKB 不同产品产品版本的对比请参见:[MaxKB 产品版本对比](https://maxkb.cn/price);
|
||||
- 如果您需要向团队介绍 MaxKB,可以使用这个 [官方 PPT 材料](https://fit2cloud.com/maxkb/download/introduce-maxkb_202507.pdf)。
|
||||
- MaxKB 产品版本分为社区版和专业版,详情请参见:[MaxKB 产品版本对比](https://maxkb.cn/pricing.html);
|
||||
- 如果您需要向团队介绍 MaxKB,可以使用这个 [官方 PPT 材料](https://maxkb.cn/download/introduce-maxkb_202503.pdf)。
|
||||
|
||||
如你有更多问题,可以查看使用手册,或者通过论坛与我们交流。
|
||||
|
||||
|
|
@ -54,12 +54,12 @@ docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/opt/maxkb
|
|||
|
||||
<table style="border-collapse: collapse; border: 1px solid black;">
|
||||
<tr>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/eb285512-a66a-4752-8941-c65ed1592238" alt="MaxKB Demo1" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/f732f1f5-472c-4fd2-93c1-a277eda83d04" alt="MaxKB Demo2" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/1Panel-dev/MaxKB/assets/52996290/d87395fa-a8d7-401c-82bf-c6e475d10ae9" alt="MaxKB Demo1" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/1Panel-dev/MaxKB/assets/52996290/47c35ee4-3a3b-4bd4-9f4f-ee20788b2b9a" alt="MaxKB Demo2" /></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/c927474a-9a23-4830-822f-5db26025c9b2" alt="MaxKB Demo3" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/e6268996-a46d-4e58-9f30-31139df78ad2" alt="MaxKB Demo4" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/9a1043cb-fa62-4f71-b9a3-0b46fa59a70e" alt="MaxKB Demo3" /></td>
|
||||
<td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/3407ce9a-779c-4eb4-858e-9441a2ddc664" alt="MaxKB Demo4" /></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
|
|
|||
|
|
@ -36,4 +36,4 @@
|
|||
- [MaxKB 应用案例:重磅!陕西广电网络“秦岭云”平台实现DeepSeek本地化部署](https://mp.weixin.qq.com/s/ZKmEU_wWShK1YDomKJHQeA)
|
||||
- [MaxKB 应用案例:粤海集团完成DeepSeek私有化部署,助力集团智能化管理](https://mp.weixin.qq.com/s/2JbVp0-kr9Hfp-0whH4cvg)
|
||||
- [MaxKB 应用案例:建筑材料工业信息中心完成DeepSeek本地化部署,推动行业数智化转型新发展](https://mp.weixin.qq.com/s/HThGSnND3qDF8ySEqiM4jw)
|
||||
- [MaxKB 应用案例:一起DeepSeek!福建设计以AI大模型开启新篇章](https://mp.weixin.qq.com/s/m67e-H7iQBg3d24NM82UjA)
|
||||
- [MaxKB 应用案例:一起DeepSeek!福建设计以AI大模型开启新篇章](https://mp.weixin.qq.com/s/m67e-H7iQBg3d24NM82UjA)
|
||||
|
|
|
|||
|
|
@ -1,35 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: application_access_token.py
|
||||
@date:2025/6/9 17:46
|
||||
@desc:
|
||||
"""
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
|
||||
from application.serializers.application_access_token import AccessTokenEditSerializer
|
||||
from common.mixins.api_mixin import APIMixin
|
||||
|
||||
|
||||
class ApplicationAccessTokenAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
), OpenApiParameter(
|
||||
name="application_id",
|
||||
description="应用id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)]
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return AccessTokenEditSerializer
|
||||
|
|
@ -1,218 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: application.py
|
||||
@date:2025/5/26 16:59
|
||||
@desc:
|
||||
"""
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.serializers.application import ApplicationCreateSerializer, ApplicationListResponse, \
|
||||
ApplicationImportRequest, ApplicationEditSerializer, TextToSpeechRequest, SpeechToTextRequest, PlayDemoTextRequest
|
||||
from common.mixins.api_mixin import APIMixin
|
||||
from common.result import ResultSerializer, ResultPageSerializer, DefaultResultSerializer
|
||||
|
||||
|
||||
class ApplicationCreateRequest(ApplicationCreateSerializer.SimplateRequest):
|
||||
work_flow = serializers.DictField(required=True, label=_("Workflow Objects"))
|
||||
|
||||
|
||||
class ApplicationCreateResponse(ResultSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationCreateSerializer.ApplicationResponse()
|
||||
|
||||
|
||||
class ApplicationListResult(ResultSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationListResponse(many=True)
|
||||
|
||||
|
||||
class ApplicationPageResult(ResultPageSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationListResponse(many=True)
|
||||
|
||||
|
||||
class ApplicationQueryAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="current_page",
|
||||
description=_("Current page"),
|
||||
type=OpenApiTypes.INT,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="page_size",
|
||||
description=_("Page size"),
|
||||
type=OpenApiTypes.INT,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="folder_id",
|
||||
description=_("folder id"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='query',
|
||||
required=False,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="name",
|
||||
description=_("Application Name"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='query',
|
||||
required=False,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="desc",
|
||||
description=_("Application Description"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='query',
|
||||
required=False,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="user_id",
|
||||
description=_("User ID"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='query',
|
||||
required=False,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="publish_status",
|
||||
description=_("Publish status") + '(published|unpublished)',
|
||||
type=OpenApiTypes.STR,
|
||||
location='query',
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationListResult
|
||||
|
||||
@staticmethod
|
||||
def get_page_response():
|
||||
return ApplicationPageResult
|
||||
|
||||
|
||||
class ApplicationCreateAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ApplicationCreateRequest
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationCreateResponse
|
||||
|
||||
|
||||
class ApplicationImportAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
ApplicationCreateAPI.get_parameters()
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ApplicationImportRequest
|
||||
|
||||
|
||||
class ApplicationOperateAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="application_id",
|
||||
description="应用id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class ApplicationExportAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return ApplicationOperateAPI.get_parameters()
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return DefaultResultSerializer
|
||||
|
||||
|
||||
class ApplicationEditAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ApplicationEditSerializer
|
||||
|
||||
|
||||
class TextToSpeechAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return ApplicationOperateAPI.get_parameters()
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return TextToSpeechRequest
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return DefaultResultSerializer
|
||||
|
||||
|
||||
class SpeechToTextAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return ApplicationOperateAPI.get_parameters()
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return SpeechToTextRequest
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return DefaultResultSerializer
|
||||
|
||||
|
||||
class PlayDemoTextAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return ApplicationOperateAPI.get_parameters()
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return PlayDemoTextRequest
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return DefaultResultSerializer
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
|
||||
from application.serializers.application_api_key import EditApplicationKeySerializer, ApplicationKeySerializerModel
|
||||
from common.mixins.api_mixin import APIMixin
|
||||
from common.result import ResultSerializer
|
||||
|
||||
|
||||
class ApplicationKeyListResult(ResultSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationKeySerializerModel(many=True)
|
||||
|
||||
|
||||
class ApplicationKeyResult(ResultSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationKeySerializerModel()
|
||||
|
||||
|
||||
class ApplicationKeyAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="application_id",
|
||||
description="application ID",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationKeyResult
|
||||
|
||||
class List(APIMixin):
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationKeyListResult
|
||||
|
||||
class Operate(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [*ApplicationKeyAPI.get_parameters(), OpenApiParameter(
|
||||
name="api_key_id",
|
||||
description="ApiKeyId",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)]
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return EditApplicationKeySerializer
|
||||
|
|
@ -1,141 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: application_chat.py
|
||||
@date:2025/6/10 13:54
|
||||
@desc:
|
||||
"""
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
|
||||
from application.serializers.application_chat import ApplicationChatQuerySerializers, \
|
||||
ApplicationChatResponseSerializers, ApplicationChatRecordExportRequest
|
||||
from common.mixins.api_mixin import APIMixin
|
||||
from common.result import ResultSerializer, ResultPageSerializer
|
||||
|
||||
|
||||
class ApplicationChatListResponseSerializers(ResultSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationChatResponseSerializers(many=True)
|
||||
|
||||
|
||||
class ApplicationChatPageResponseSerializers(ResultPageSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationChatResponseSerializers(many=True)
|
||||
|
||||
|
||||
class ApplicationChatQueryAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ApplicationChatQuerySerializers
|
||||
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="application_id",
|
||||
description="application ID",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
), OpenApiParameter(
|
||||
name="start_time",
|
||||
description="start Time",
|
||||
type=OpenApiTypes.STR,
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="end_time",
|
||||
description="end Time",
|
||||
type=OpenApiTypes.STR,
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="abstract",
|
||||
description="summary",
|
||||
type=OpenApiTypes.STR,
|
||||
required=False,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="username",
|
||||
description="username",
|
||||
type=OpenApiTypes.STR,
|
||||
required=False,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="min_star",
|
||||
description=_("Minimum number of likes"),
|
||||
type=OpenApiTypes.INT,
|
||||
required=False,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="min_trample",
|
||||
description=_("Minimum number of clicks"),
|
||||
type=OpenApiTypes.INT,
|
||||
required=False,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="comparer",
|
||||
description=_("Comparator"),
|
||||
type=OpenApiTypes.STR,
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationChatListResponseSerializers
|
||||
|
||||
|
||||
class ApplicationChatQueryPageAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ApplicationChatQueryAPI.get_request()
|
||||
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
*ApplicationChatQueryAPI.get_parameters(),
|
||||
OpenApiParameter(
|
||||
name="current_page",
|
||||
description=_("Current page"),
|
||||
type=OpenApiTypes.INT,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="page_size",
|
||||
description=_("Page size"),
|
||||
type=OpenApiTypes.INT,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationChatPageResponseSerializers
|
||||
|
||||
|
||||
class ApplicationChatExportAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ApplicationChatRecordExportRequest
|
||||
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return ApplicationChatQueryAPI.get_parameters()
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return None
|
||||
|
|
@ -1,180 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: application_chat_record.py
|
||||
@date:2025/6/10 15:19
|
||||
@desc:
|
||||
"""
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
|
||||
from application.serializers.application_chat_record import ApplicationChatRecordAddKnowledgeSerializer, \
|
||||
ApplicationChatRecordImproveInstanceSerializer
|
||||
from common.mixins.api_mixin import APIMixin
|
||||
|
||||
|
||||
class ApplicationChatRecordQueryAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_response():
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="application_id",
|
||||
description="Application ID",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="chat_id",
|
||||
description=_("Chat ID"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="order_asc",
|
||||
description=_("Is it in order"),
|
||||
type=OpenApiTypes.BOOL,
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class ApplicationChatRecordPageQueryAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_response():
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [*ApplicationChatRecordQueryAPI.get_parameters(),
|
||||
OpenApiParameter(
|
||||
name="current_page",
|
||||
description=_("Current page"),
|
||||
type=OpenApiTypes.INT,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="page_size",
|
||||
description=_("Page size"),
|
||||
type=OpenApiTypes.INT,
|
||||
location='path',
|
||||
required=True,
|
||||
)]
|
||||
|
||||
|
||||
class ApplicationChatRecordImproveParagraphAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_response():
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ApplicationChatRecordImproveInstanceSerializer
|
||||
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="application_id",
|
||||
description="Application ID",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="chat_id",
|
||||
description=_("Chat ID"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="chat_record_id",
|
||||
description=_("Chat Record ID"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="knowledge_id",
|
||||
description=_("Knowledge ID"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="document_id",
|
||||
description=_("Document ID"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
class Operate(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [*ApplicationChatRecordImproveParagraphAPI.get_parameters(), OpenApiParameter(
|
||||
name="paragraph_id",
|
||||
description=_("Paragraph ID"),
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)]
|
||||
|
||||
|
||||
class ApplicationChatRecordAddKnowledgeAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ApplicationChatRecordAddKnowledgeSerializer
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="application_id",
|
||||
description="Application ID",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)]
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: application_stats.py
|
||||
@date:2025/6/9 20:45
|
||||
@desc:
|
||||
"""
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
|
||||
from application.serializers.application_stats import ApplicationStatsSerializer
|
||||
from common.mixins.api_mixin import APIMixin
|
||||
from common.result import ResultSerializer
|
||||
|
||||
|
||||
class ApplicationStatsResult(ResultSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationStatsSerializer(many=True)
|
||||
|
||||
|
||||
class ApplicationStatsAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="application_id",
|
||||
description="application ID",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="start_time",
|
||||
description="start Time",
|
||||
type=OpenApiTypes.STR,
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="end_time",
|
||||
description="end Time",
|
||||
type=OpenApiTypes.STR,
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationStatsResult
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: application_version.py
|
||||
@date:2025/6/4 17:33
|
||||
@desc:
|
||||
"""
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
|
||||
from application.serializers.application_version import ApplicationVersionModelSerializer
|
||||
from common.mixins.api_mixin import APIMixin
|
||||
from common.result import ResultSerializer, PageDataResponse, ResultPageSerializer
|
||||
|
||||
|
||||
class ApplicationListVersionResult(ResultSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationVersionModelSerializer(many=True)
|
||||
|
||||
|
||||
class ApplicationPageVersionResult(ResultPageSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationVersionModelSerializer(many=True)
|
||||
|
||||
|
||||
class ApplicationWorkflowVersionResult(ResultSerializer):
|
||||
def get_data(self):
|
||||
return ApplicationVersionModelSerializer()
|
||||
|
||||
|
||||
class ApplicationVersionAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="application_id",
|
||||
description="application ID",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class ApplicationVersionOperateAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="application_version_id",
|
||||
description="工作流版本id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
)
|
||||
, *ApplicationVersionAPI.get_parameters()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationWorkflowVersionResult
|
||||
|
||||
|
||||
class ApplicationVersionListAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="name",
|
||||
description="Version Name",
|
||||
type=OpenApiTypes.STR,
|
||||
required=False,
|
||||
)
|
||||
, *ApplicationVersionAPI.get_parameters()]
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationListVersionResult
|
||||
|
||||
|
||||
class ApplicationVersionPageAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return ApplicationVersionListAPI.get_parameters()
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ApplicationPageVersionResult
|
||||
|
|
@ -12,45 +12,42 @@ from typing import Type
|
|||
|
||||
from rest_framework import serializers
|
||||
|
||||
from knowledge.models import Paragraph
|
||||
from dataset.models import Paragraph
|
||||
|
||||
|
||||
class ParagraphPipelineModel:
|
||||
|
||||
def __init__(self, _id: str, document_id: str, knowledge_id: str, content: str, title: str, status: str,
|
||||
is_active: bool, comprehensive_score: float, similarity: float, knowledge_name: str,
|
||||
document_name: str,
|
||||
hit_handling_method: str, directly_return_similarity: float, knowledge_type, meta: dict = None):
|
||||
def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
|
||||
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
|
||||
hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
|
||||
self.id = _id
|
||||
self.document_id = document_id
|
||||
self.knowledge_id = knowledge_id
|
||||
self.dataset_id = dataset_id
|
||||
self.content = content
|
||||
self.title = title
|
||||
self.status = status,
|
||||
self.is_active = is_active
|
||||
self.comprehensive_score = comprehensive_score
|
||||
self.similarity = similarity
|
||||
self.knowledge_name = knowledge_name
|
||||
self.dataset_name = dataset_name
|
||||
self.document_name = document_name
|
||||
self.hit_handling_method = hit_handling_method
|
||||
self.directly_return_similarity = directly_return_similarity
|
||||
self.meta = meta
|
||||
self.knowledge_type = knowledge_type
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'document_id': self.document_id,
|
||||
'knowledge_id': self.knowledge_id,
|
||||
'dataset_id': self.dataset_id,
|
||||
'content': self.content,
|
||||
'title': self.title,
|
||||
'status': self.status,
|
||||
'is_active': self.is_active,
|
||||
'comprehensive_score': self.comprehensive_score,
|
||||
'similarity': self.similarity,
|
||||
'knowledge_name': self.knowledge_name,
|
||||
'dataset_name': self.dataset_name,
|
||||
'document_name': self.document_name,
|
||||
'knowledge_type': self.knowledge_type,
|
||||
'meta': self.meta,
|
||||
}
|
||||
|
||||
|
|
@ -60,8 +57,7 @@ class ParagraphPipelineModel:
|
|||
self.paragraph = {}
|
||||
self.comprehensive_score = None
|
||||
self.document_name = None
|
||||
self.knowledge_name = None
|
||||
self.knowledge_type = None
|
||||
self.dataset_name = None
|
||||
self.hit_handling_method = None
|
||||
self.directly_return_similarity = 0.9
|
||||
self.meta = {}
|
||||
|
|
@ -70,7 +66,7 @@ class ParagraphPipelineModel:
|
|||
if isinstance(paragraph, Paragraph):
|
||||
self.paragraph = {'id': paragraph.id,
|
||||
'document_id': paragraph.document_id,
|
||||
'knowledge_id': paragraph.knowledge_id,
|
||||
'dataset_id': paragraph.dataset_id,
|
||||
'content': paragraph.content,
|
||||
'title': paragraph.title,
|
||||
'status': paragraph.status,
|
||||
|
|
@ -80,12 +76,8 @@ class ParagraphPipelineModel:
|
|||
self.paragraph = paragraph
|
||||
return self
|
||||
|
||||
def add_knowledge_name(self, knowledge_name):
|
||||
self.knowledge_name = knowledge_name
|
||||
return self
|
||||
|
||||
def add_knowledge_type(self, knowledge_type):
|
||||
self.knowledge_type = knowledge_type
|
||||
def add_dataset_name(self, dataset_name):
|
||||
self.dataset_name = dataset_name
|
||||
return self
|
||||
|
||||
def add_document_name(self, document_name):
|
||||
|
|
@ -114,13 +106,12 @@ class ParagraphPipelineModel:
|
|||
|
||||
def build(self):
|
||||
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
|
||||
str(self.paragraph.get('knowledge_id')),
|
||||
str(self.paragraph.get('dataset_id')),
|
||||
self.paragraph.get('content'), self.paragraph.get('title'),
|
||||
self.paragraph.get('status'),
|
||||
self.paragraph.get('is_active'),
|
||||
self.comprehensive_score, self.similarity, self.knowledge_name,
|
||||
self.comprehensive_score, self.similarity, self.dataset_name,
|
||||
self.document_name, self.hit_handling_method, self.directly_return_similarity,
|
||||
self.knowledge_type,
|
||||
self.meta)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,14 +17,12 @@ from common.handle.impl.response.system_to_response import SystemToResponse
|
|||
|
||||
class PipelineManage:
|
||||
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
|
||||
base_to_response: BaseToResponse = SystemToResponse(),
|
||||
debug=False):
|
||||
base_to_response: BaseToResponse = SystemToResponse()):
|
||||
# 步骤执行器
|
||||
self.step_list = [step() for step in step_list]
|
||||
# 上下文
|
||||
self.context = {'message_tokens': 0, 'answer_tokens': 0}
|
||||
self.base_to_response = base_to_response
|
||||
self.debug = debug
|
||||
|
||||
def run(self, context: Dict = None):
|
||||
self.context['start_time'] = time.time()
|
||||
|
|
@ -46,7 +44,6 @@ class PipelineManage:
|
|||
def __init__(self):
|
||||
self.step_list: List[Type[IBaseChatPipelineStep]] = []
|
||||
self.base_to_response = SystemToResponse()
|
||||
self.debug = False
|
||||
|
||||
def append_step(self, step: Type[IBaseChatPipelineStep]):
|
||||
self.step_list.append(step)
|
||||
|
|
@ -56,9 +53,5 @@ class PipelineManage:
|
|||
self.base_to_response = base_to_response
|
||||
return self
|
||||
|
||||
def add_debug(self, debug):
|
||||
self.debug = debug
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response, debug=self.debug)
|
||||
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
|
||||
|
|
|
|||
|
|
@ -16,8 +16,9 @@ from rest_framework import serializers
|
|||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.serializers.application import NoReferencesSetting
|
||||
from application.serializers.application_serializers import NoReferencesSetting
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ModelField(serializers.Field):
|
||||
|
|
@ -44,7 +45,7 @@ class PostResponseHandler:
|
|||
@abstractmethod
|
||||
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
|
||||
answer_text,
|
||||
manage, step, padding_problem_text: str = None, **kwargs):
|
||||
manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -52,43 +53,35 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
class InstanceSerializer(serializers.Serializer):
|
||||
# 对话列表
|
||||
message_list = serializers.ListField(required=True, child=MessageField(required=True),
|
||||
label=_("Conversation list"))
|
||||
model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id"))
|
||||
error_messages=ErrMessage.list(_("Conversation list")))
|
||||
model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
|
||||
# 段落列表
|
||||
paragraph_list = serializers.ListField(label=_("Paragraph List"))
|
||||
paragraph_list = serializers.ListField(error_messages=ErrMessage.list(_("Paragraph List")))
|
||||
# 对话id
|
||||
chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
|
||||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("Conversation ID")))
|
||||
# 用户问题
|
||||
problem_text = serializers.CharField(required=True, label=_("User Questions"))
|
||||
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_("User Questions")))
|
||||
# 后置处理器
|
||||
post_response_handler = InstanceField(model_type=PostResponseHandler,
|
||||
label=_("Post-processor"))
|
||||
error_messages=ErrMessage.base(_("Post-processor")))
|
||||
# 补全问题
|
||||
padding_problem_text = serializers.CharField(required=False,
|
||||
label=_("Completion Question"))
|
||||
error_messages=ErrMessage.base(_("Completion Question")))
|
||||
# 是否使用流的形式输出
|
||||
stream = serializers.BooleanField(required=False, label=_("Streaming Output"))
|
||||
chat_user_id = serializers.CharField(required=True, label=_("Chat user id"))
|
||||
|
||||
chat_user_type = serializers.CharField(required=True, label=_("Chat user Type"))
|
||||
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
|
||||
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
|
||||
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
|
||||
# 未查询到引用分段
|
||||
no_references_setting = NoReferencesSetting(required=True,
|
||||
label=_("No reference segment settings"))
|
||||
error_messages=ErrMessage.base(_("No reference segment settings")))
|
||||
|
||||
workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
|
||||
|
||||
model_setting = serializers.DictField(required=True, allow_null=True,
|
||||
label=_("Model settings"))
|
||||
error_messages=ErrMessage.dict(_("Model settings")))
|
||||
|
||||
model_params_setting = serializers.DictField(required=False, allow_null=True,
|
||||
label=_("Model parameter settings"))
|
||||
mcp_enable = serializers.BooleanField(label="MCP否启用", required=False, default=False)
|
||||
mcp_tool_ids = serializers.JSONField(label="MCP工具ID列表", required=False, default=list)
|
||||
mcp_servers = serializers.JSONField(label="MCP服务列表", required=False, default=dict)
|
||||
mcp_source = serializers.CharField(label="MCP Source", required=False, default="referencing")
|
||||
tool_enable = serializers.BooleanField(label="工具是否启用", required=False, default=False)
|
||||
tool_ids = serializers.JSONField(label="工具ID列表", required=False, default=list)
|
||||
mcp_output_enable = serializers.BooleanField(label="MCP输出是否启用", required=False, default=True)
|
||||
error_messages=ErrMessage.dict(_("Model parameter settings")))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
|
@ -109,12 +102,9 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
chat_id, problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
model_id: str = None,
|
||||
workspace_id: str = None,
|
||||
user_id: str = None,
|
||||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
|
||||
no_references_setting=None, model_params_setting=None, model_setting=None,
|
||||
mcp_enable=False, mcp_tool_ids=None, mcp_servers='', mcp_source="referencing",
|
||||
tool_enable=False, tool_ids=None, mcp_output_enable=True,
|
||||
**kwargs):
|
||||
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
|
||||
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -6,11 +6,10 @@
|
|||
@date:2024/1/9 18:25
|
||||
@desc: 对话step Base实现
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
import uuid_utils.compat as uuid
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
|
@ -19,28 +18,22 @@ from django.utils.translation import gettext as _
|
|||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage, AIMessage
|
||||
from langchain_core.messages import AIMessageChunk, SystemMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from rest_framework import status
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
||||
from application.flow.tools import Reasoning, mcp_response_generator
|
||||
from application.models import ApplicationChatUserStats, ChatUserType
|
||||
from common.utils.logger import maxkb_logger
|
||||
from common.utils.rsa_util import rsa_long_decrypt
|
||||
from common.utils.tool_code import ToolExecutor
|
||||
from maxkb.const import CONFIG
|
||||
from models_provider.tools import get_model_instance_by_model_workspace_id
|
||||
from tools.models import Tool
|
||||
from application.flow.tools import Reasoning
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
|
||||
if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
|
||||
chat_user_type) and application_id is not None:
|
||||
application_public_access_client = (QuerySet(ApplicationChatUserStats).filter(chat_user_id=chat_user_id,
|
||||
chat_user_type=chat_user_type,
|
||||
application_id=application_id)
|
||||
def add_access_num(client_id=None, client_type=None, application_id=None):
|
||||
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None:
|
||||
application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id,
|
||||
application_id=application_id)
|
||||
.first())
|
||||
if application_public_access_client is not None:
|
||||
application_public_access_client.access_num = application_public_access_client.access_num + 1
|
||||
|
|
@ -59,7 +52,6 @@ def write_context(step, manage, request_token, response_token, all_text):
|
|||
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
|
||||
|
||||
|
||||
|
||||
def event_content(response,
|
||||
chat_id,
|
||||
chat_record_id,
|
||||
|
|
@ -71,7 +63,7 @@ def event_content(response,
|
|||
message_list: List[BaseMessage],
|
||||
problem_text: str,
|
||||
padding_problem_text: str = None,
|
||||
chat_user_id=None, chat_user_type=None,
|
||||
client_id=None, client_type=None,
|
||||
is_ai_chat: bool = None,
|
||||
model_setting=None):
|
||||
if model_setting is None:
|
||||
|
|
@ -132,24 +124,26 @@ def event_content(response,
|
|||
request_token = 0
|
||||
response_token = 0
|
||||
write_context(step, manage, request_token, response_token, all_text)
|
||||
asker = manage.context.get('form_data', {}).get('asker', None)
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
all_text, manage, step, padding_problem_text,
|
||||
reasoning_content=reasoning_content if reasoning_content_enable else '')
|
||||
all_text, manage, step, padding_problem_text, client_id,
|
||||
reasoning_content=reasoning_content if reasoning_content_enable else ''
|
||||
, asker=asker)
|
||||
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
|
||||
[], '', True,
|
||||
request_token, response_token,
|
||||
{'node_is_end': True, 'view_type': 'many_view',
|
||||
'node_type': 'ai-chat-node'})
|
||||
if not manage.debug:
|
||||
add_access_num(chat_user_id, chat_user_type, manage.context.get('application_id'))
|
||||
add_access_num(client_id, client_type, manage.context.get('application_id'))
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f'{str(e)}:{traceback.format_exc()}')
|
||||
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
|
||||
all_text = 'Exception:' + str(e)
|
||||
write_context(step, manage, 0, 0, all_text)
|
||||
asker = manage.context.get('form_data', {}).get('asker', None)
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
all_text, manage, step, padding_problem_text, reasoning_content='')
|
||||
if not manage.debug:
|
||||
add_access_num(chat_user_id, chat_user_type, manage.context.get('application_id'))
|
||||
all_text, manage, step, padding_problem_text, client_id, reasoning_content='',
|
||||
asker=asker)
|
||||
add_access_num(client_id, client_type, manage.context.get('application_id'))
|
||||
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
|
||||
[], all_text,
|
||||
False,
|
||||
|
|
@ -166,48 +160,30 @@ class BaseChatStep(IChatStep):
|
|||
problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
model_id: str = None,
|
||||
workspace_id: str = None,
|
||||
user_id: str = None,
|
||||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
stream: bool = True,
|
||||
chat_user_id=None, chat_user_type=None,
|
||||
client_id=None, client_type=None,
|
||||
no_references_setting=None,
|
||||
model_params_setting=None,
|
||||
model_setting=None,
|
||||
mcp_enable=False,
|
||||
mcp_tool_ids=None,
|
||||
mcp_servers='',
|
||||
mcp_source="referencing",
|
||||
tool_enable=False,
|
||||
tool_ids=None,
|
||||
mcp_output_enable=True,
|
||||
**kwargs):
|
||||
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
|
||||
**model_params_setting) if model_id is not None else None
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, user_id,
|
||||
**model_params_setting) if model_id is not None else None
|
||||
if stream:
|
||||
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||
paragraph_list,
|
||||
manage, padding_problem_text, chat_user_id, chat_user_type,
|
||||
no_references_setting,
|
||||
model_setting,
|
||||
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
|
||||
manage, padding_problem_text, client_id, client_type, no_references_setting,
|
||||
model_setting)
|
||||
else:
|
||||
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||
paragraph_list,
|
||||
manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
|
||||
model_setting,
|
||||
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
|
||||
manage, padding_problem_text, client_id, client_type, no_references_setting,
|
||||
model_setting)
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
# 删除临时生成的MCP代码文件
|
||||
if self.context.get('execute_ids'):
|
||||
executor = ToolExecutor(CONFIG.get('SANDBOX'))
|
||||
# 清理工具代码文件,延时删除,避免文件被占用
|
||||
for tool_id in self.context.get('execute_ids'):
|
||||
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
|
||||
if os.path.exists(code_path):
|
||||
os.remove(code_path)
|
||||
return {
|
||||
'step_type': 'chat_step',
|
||||
'run_time': self.context['run_time'],
|
||||
|
|
@ -221,72 +197,19 @@ class BaseChatStep(IChatStep):
|
|||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else (
|
||||
'system' if isinstance(message, SystemMessage) else 'ai'), 'content': message.content} for
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
|
||||
mcp_output_enable, chat_model, message_list):
|
||||
if not mcp_enable and not tool_enable:
|
||||
return None
|
||||
|
||||
mcp_servers_config = {}
|
||||
|
||||
# 迁移过来mcp_source是None
|
||||
if mcp_source is None:
|
||||
mcp_source = 'custom'
|
||||
if mcp_enable:
|
||||
# 兼容老数据
|
||||
if not mcp_tool_ids:
|
||||
mcp_tool_ids = []
|
||||
if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
|
||||
mcp_servers_config = json.loads(mcp_servers)
|
||||
elif mcp_tool_ids:
|
||||
mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values()
|
||||
for mcp_tool in mcp_tools:
|
||||
if mcp_tool and mcp_tool['is_active']:
|
||||
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
|
||||
|
||||
if tool_enable:
|
||||
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
|
||||
self.context['tool_ids'] = tool_ids
|
||||
self.context['execute_ids'] = []
|
||||
for tool_id in tool_ids:
|
||||
tool = QuerySet(Tool).filter(id=tool_id).first()
|
||||
if tool is None or tool.is_active is False:
|
||||
continue
|
||||
executor = ToolExecutor(CONFIG.get('SANDBOX'))
|
||||
if tool.init_params is not None:
|
||||
params = json.loads(rsa_long_decrypt(tool.init_params))
|
||||
else:
|
||||
params = {}
|
||||
_id, tool_config = executor.get_tool_mcp_config(tool.code, params)
|
||||
|
||||
self.context['execute_ids'].append(_id)
|
||||
mcp_servers_config[str(tool.id)] = tool_config
|
||||
|
||||
if len(mcp_servers_config) > 0:
|
||||
return mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_stream_result(self, message_list: List[BaseMessage],
|
||||
@staticmethod
|
||||
def get_stream_result(message_list: List[BaseMessage],
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
no_references_setting=None,
|
||||
problem_text=None,
|
||||
mcp_enable=False,
|
||||
mcp_tool_ids=None,
|
||||
mcp_servers='',
|
||||
mcp_source="referencing",
|
||||
tool_enable=False,
|
||||
tool_ids=None,
|
||||
mcp_output_enable=True):
|
||||
problem_text=None):
|
||||
if paragraph_list is None:
|
||||
paragraph_list = []
|
||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
|
||||
|
|
@ -302,12 +225,6 @@ class BaseChatStep(IChatStep):
|
|||
return iter([AIMessageChunk(
|
||||
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
|
||||
else:
|
||||
# 处理 MCP 请求
|
||||
mcp_result = self._handle_mcp_request(
|
||||
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model, message_list,
|
||||
)
|
||||
if mcp_result:
|
||||
return mcp_result, True
|
||||
return chat_model.stream(message_list), True
|
||||
|
||||
def execute_stream(self, message_list: List[BaseMessage],
|
||||
|
|
@ -318,43 +235,27 @@ class BaseChatStep(IChatStep):
|
|||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
chat_user_id=None, chat_user_type=None,
|
||||
client_id=None, client_type=None,
|
||||
no_references_setting=None,
|
||||
model_setting=None,
|
||||
mcp_enable=False,
|
||||
mcp_tool_ids=None,
|
||||
mcp_servers='',
|
||||
mcp_source="referencing",
|
||||
tool_enable=False,
|
||||
tool_ids=None,
|
||||
mcp_output_enable=True):
|
||||
model_setting=None):
|
||||
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
|
||||
no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
|
||||
mcp_output_enable)
|
||||
chat_record_id = uuid.uuid7()
|
||||
no_references_setting, problem_text)
|
||||
chat_record_id = uuid.uuid1()
|
||||
r = StreamingHttpResponse(
|
||||
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
||||
post_response_handler, manage, self, chat_model, message_list, problem_text,
|
||||
padding_problem_text, chat_user_id, chat_user_type, is_ai_chat,
|
||||
model_setting),
|
||||
padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
|
||||
content_type='text/event-stream;charset=utf-8')
|
||||
|
||||
r['Cache-Control'] = 'no-cache'
|
||||
return r
|
||||
|
||||
def get_block_result(self, message_list: List[BaseMessage],
|
||||
@staticmethod
|
||||
def get_block_result(message_list: List[BaseMessage],
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
no_references_setting=None,
|
||||
problem_text=None,
|
||||
mcp_enable=False,
|
||||
mcp_tool_ids=None,
|
||||
mcp_servers='',
|
||||
mcp_source="referencing",
|
||||
tool_enable=False,
|
||||
tool_ids=None,
|
||||
mcp_output_enable=True
|
||||
):
|
||||
problem_text=None):
|
||||
if paragraph_list is None:
|
||||
paragraph_list = []
|
||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
|
||||
|
|
@ -369,13 +270,6 @@ class BaseChatStep(IChatStep):
|
|||
return AIMessage(
|
||||
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
|
||||
else:
|
||||
# 处理 MCP 请求
|
||||
mcp_result = self._handle_mcp_request(
|
||||
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable,
|
||||
chat_model, message_list,
|
||||
)
|
||||
if mcp_result:
|
||||
return mcp_result, True
|
||||
return chat_model.invoke(message_list), True
|
||||
|
||||
def execute_block(self, message_list: List[BaseMessage],
|
||||
|
|
@ -386,25 +280,18 @@ class BaseChatStep(IChatStep):
|
|||
paragraph_list=None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
chat_user_id=None, chat_user_type=None, no_references_setting=None,
|
||||
model_setting=None,
|
||||
mcp_enable=False,
|
||||
mcp_tool_ids=None,
|
||||
mcp_servers='',
|
||||
mcp_source="referencing",
|
||||
tool_enable=False,
|
||||
tool_ids=None,
|
||||
mcp_output_enable=True):
|
||||
client_id=None, client_type=None, no_references_setting=None,
|
||||
model_setting=None):
|
||||
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
|
||||
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
|
||||
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
|
||||
reasoning = Reasoning(reasoning_content_start,
|
||||
reasoning_content_end)
|
||||
chat_record_id = uuid.uuid7()
|
||||
chat_record_id = uuid.uuid1()
|
||||
# 调用模型
|
||||
try:
|
||||
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
|
||||
no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
|
||||
no_references_setting, problem_text)
|
||||
if is_ai_chat:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
response_token = chat_model.get_num_tokens(chat_result.content)
|
||||
|
|
@ -420,11 +307,12 @@ class BaseChatStep(IChatStep):
|
|||
else:
|
||||
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get(
|
||||
'reasoning_content')
|
||||
asker = manage.context.get('form_data', {}).get('asker', None)
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
content, manage, self, padding_problem_text,
|
||||
reasoning_content=reasoning_content)
|
||||
if not manage.debug:
|
||||
add_access_num(chat_user_id, chat_user_type, manage.context.get('application_id'))
|
||||
content, manage, self, padding_problem_text, client_id,
|
||||
reasoning_content=reasoning_content if reasoning_content_enable else '',
|
||||
asker=asker)
|
||||
add_access_num(client_id, client_type, manage.context.get('application_id'))
|
||||
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
|
||||
content, True,
|
||||
request_token, response_token,
|
||||
|
|
@ -437,9 +325,10 @@ class BaseChatStep(IChatStep):
|
|||
except Exception as e:
|
||||
all_text = 'Exception:' + str(e)
|
||||
write_context(self, manage, 0, 0, all_text)
|
||||
asker = manage.context.get('form_data', {}).get('asker', None)
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
all_text, manage, self, padding_problem_text, reasoning_content='')
|
||||
if not manage.debug:
|
||||
add_access_num(chat_user_id, chat_user_type, manage.context.get('application_id'))
|
||||
all_text, manage, self, padding_problem_text, client_id, reasoning_content='',
|
||||
asker=asker)
|
||||
add_access_num(client_id, client_type, manage.context.get('application_id'))
|
||||
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0,
|
||||
0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
|
|
|||
|
|
@ -16,35 +16,34 @@ from rest_framework import serializers
|
|||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.models import ChatRecord
|
||||
from application.serializers.application import NoReferencesSetting
|
||||
from application.serializers.application_serializers import NoReferencesSetting
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class IGenerateHumanMessageStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 问题
|
||||
problem_text = serializers.CharField(required=True, label=_("question"))
|
||||
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
|
||||
# 段落列表
|
||||
paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True),
|
||||
label=_("Paragraph List"))
|
||||
error_messages=ErrMessage.list(_("Paragraph List")))
|
||||
# 历史对答
|
||||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
|
||||
label=_("History Questions"))
|
||||
error_messages=ErrMessage.list(_("History Questions")))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, label=_("Number of multi-round conversations"))
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
|
||||
# 最大携带知识库段落长度
|
||||
max_paragraph_char_number = serializers.IntegerField(required=True,
|
||||
label=_("Maximum length of the knowledge base paragraph"))
|
||||
max_paragraph_char_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
|
||||
_("Maximum length of the knowledge base paragraph")))
|
||||
# 模板
|
||||
prompt = serializers.CharField(required=True, label=_("Prompt word"))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
|
||||
system = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||
label=_("System prompt words (role)"))
|
||||
error_messages=ErrMessage.char(_("System prompt words (role)")))
|
||||
# 补齐问题
|
||||
padding_problem_text = serializers.CharField(required=False,
|
||||
label=_("Completion problem"))
|
||||
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Completion problem")))
|
||||
# 未查询到引用分段
|
||||
no_references_setting = NoReferencesSetting(required=True,
|
||||
label=_("No reference segment settings"))
|
||||
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings")))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineMode
|
|||
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
|
||||
IGenerateHumanMessageStep
|
||||
from application.models import ChatRecord
|
||||
from common.utils.common import flat_map
|
||||
from common.util.split_model import flat_map
|
||||
|
||||
|
||||
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
|
||||
|
|
|
|||
|
|
@ -16,20 +16,22 @@ from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
|||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.models import ChatRecord
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class IResetProblemStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 问题文本
|
||||
problem_text = serializers.CharField(required=True, label=_("question"))
|
||||
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.float(_("question")))
|
||||
# 历史对答
|
||||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
|
||||
label=_("History Questions"))
|
||||
error_messages=ErrMessage.list(_("History Questions")))
|
||||
# 大语言模型
|
||||
model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id"))
|
||||
workspace_id = serializers.CharField(required=True, label=_("User ID"))
|
||||
model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
|
||||
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
|
||||
label=_("Question completion prompt"))
|
||||
error_messages=ErrMessage.char(
|
||||
_("Question completion prompt")))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
|
@ -50,6 +52,6 @@ class IResetProblemStep(IBaseChatPipelineStep):
|
|||
@abstractmethod
|
||||
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
|
||||
problem_optimization_prompt=None,
|
||||
workspace_id=None,
|
||||
user_id=None,
|
||||
**kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,8 +13,8 @@ from langchain.schema import HumanMessage
|
|||
|
||||
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
|
||||
from application.models import ChatRecord
|
||||
from common.utils.split_model import flat_map
|
||||
from models_provider.tools import get_model_instance_by_model_workspace_id
|
||||
from common.util.split_model import flat_map
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
prompt = _(
|
||||
"() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag")
|
||||
|
|
@ -23,9 +23,9 @@ prompt = _(
|
|||
class BaseResetProblemStep(IResetProblemStep):
|
||||
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
|
||||
problem_optimization_prompt=None,
|
||||
workspace_id=None,
|
||||
user_id=None,
|
||||
**kwargs) -> str:
|
||||
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id) if model_id is not None else None
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
|
||||
if chat_model is None:
|
||||
return problem_text
|
||||
start_index = len(history_chat_record) - 3
|
||||
|
|
|
|||
|
|
@ -16,62 +16,62 @@ from rest_framework import serializers
|
|||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ISearchDatasetStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 原始问题文本
|
||||
problem_text = serializers.CharField(required=True, label=_("question"))
|
||||
problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
|
||||
# 系统补全问题文本
|
||||
padding_problem_text = serializers.CharField(required=False,
|
||||
label=_("System completes question text"))
|
||||
error_messages=ErrMessage.char(_("System completes question text")))
|
||||
# 需要查询的数据集id列表
|
||||
knowledge_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
label=_("Dataset id list"))
|
||||
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
error_messages=ErrMessage.list(_("Dataset id list")))
|
||||
# 需要排除的文档id
|
||||
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
label=_("List of document ids to exclude"))
|
||||
error_messages=ErrMessage.list(_("List of document ids to exclude")))
|
||||
# 需要排除向量id
|
||||
exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
|
||||
label=_("List of exclusion vector ids"))
|
||||
error_messages=ErrMessage.list(_("List of exclusion vector ids")))
|
||||
# 需要查询的条数
|
||||
top_n = serializers.IntegerField(required=True,
|
||||
label=_("Reference segment number"))
|
||||
error_messages=ErrMessage.integer(_("Reference segment number")))
|
||||
# 相似度 0-1之间
|
||||
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
||||
label=_("Similarity"))
|
||||
error_messages=ErrMessage.float(_("Similarity")))
|
||||
search_mode = serializers.CharField(required=True, validators=[
|
||||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||
message=_("The type only supports embedding|keywords|blend"), code=500)
|
||||
], label=_("Retrieval Mode"))
|
||||
workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
|
||||
], error_messages=ErrMessage.char(_("Retrieval Mode")))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PipelineManage):
|
||||
paragraph_list = self.execute(**self.context['step_args'], manage=manage)
|
||||
paragraph_list = self.execute(**self.context['step_args'])
|
||||
manage.context['paragraph_list'] = paragraph_list
|
||||
self.context['paragraph_list'] = paragraph_list
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
|
||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
search_mode: str = None,
|
||||
workspace_id=None,
|
||||
manage: PipelineManage = None,
|
||||
user_id=None,
|
||||
**kwargs) -> List[ParagraphPipelineModel]:
|
||||
"""
|
||||
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
||||
:param similarity: 相关性
|
||||
:param top_n: 查询多少条
|
||||
:param problem_text: 用户问题
|
||||
:param knowledge_id_list: 需要查询的数据集id列表
|
||||
:param dataset_id_list: 需要查询的数据集id列表
|
||||
:param exclude_document_id_list: 需要排除的文档id
|
||||
:param exclude_paragraph_id_list: 需要排除段落id
|
||||
:param padding_problem_text 补全问题
|
||||
:param search_mode 检索模式
|
||||
:param workspace_id 工作空间id
|
||||
:param user_id 用户id
|
||||
:return: 段落列表
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -16,59 +16,51 @@ from rest_framework.utils.formatting import lazy_format
|
|||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
||||
from common.config.embedding_config import VectorStore, ModelManage
|
||||
from common.constants.permission_constants import RoleConstants
|
||||
from common.database_model_manage.database_model_manage import DatabaseModelManage
|
||||
from common.db.search import native_search
|
||||
from common.utils.common import get_file_content
|
||||
from knowledge.models import Paragraph, Knowledge
|
||||
from knowledge.models import SearchMode
|
||||
from maxkb.conf import PROJECT_DIR
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model, get_model_by_id, get_model_default_params
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Paragraph, DataSet
|
||||
from embedding.models import SearchMode
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def reset_meta(meta):
|
||||
if not meta.get('allow_download', False):
|
||||
return {'allow_download': False}
|
||||
return meta
|
||||
def get_model_by_id(_id, user_id):
|
||||
model = QuerySet(Model).filter(id=_id).first()
|
||||
if model is None:
|
||||
raise Exception(_("Model does not exist"))
|
||||
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
||||
message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name)
|
||||
raise Exception(message)
|
||||
return model
|
||||
|
||||
|
||||
def get_embedding_id(knowledge_id_list):
|
||||
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
|
||||
if len(set([knowledge.embedding_model_id for knowledge in knowledge_list])) > 1:
|
||||
raise Exception(
|
||||
_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
|
||||
if len(knowledge_list) == 0:
|
||||
def get_embedding_id(dataset_id_list):
|
||||
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
|
||||
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
|
||||
raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
|
||||
if len(dataset_list) == 0:
|
||||
raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
|
||||
return knowledge_list[0].embedding_model_id
|
||||
return dataset_list[0].embedding_mode_id
|
||||
|
||||
|
||||
class BaseSearchDatasetStep(ISearchDatasetStep):
|
||||
|
||||
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
|
||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
search_mode: str = None,
|
||||
workspace_id=None,
|
||||
manage=None,
|
||||
user_id=None,
|
||||
**kwargs) -> List[ParagraphPipelineModel]:
|
||||
get_knowledge_list_of_authorized = DatabaseModelManage.get_model('get_knowledge_list_of_authorized')
|
||||
chat_user_type = manage.context.get('chat_user_type')
|
||||
if get_knowledge_list_of_authorized is not None and RoleConstants.CHAT_USER.value.name == chat_user_type:
|
||||
knowledge_id_list = get_knowledge_list_of_authorized(manage.context.get('chat_user_id'),
|
||||
knowledge_id_list)
|
||||
if len(knowledge_id_list) == 0:
|
||||
if len(dataset_id_list) == 0:
|
||||
return []
|
||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||
model_id = get_embedding_id(knowledge_id_list)
|
||||
model = get_model_by_id(model_id, workspace_id)
|
||||
if model.model_type != "EMBEDDING":
|
||||
raise Exception(_("Model does not exist"))
|
||||
model_id = get_embedding_id(dataset_id_list)
|
||||
model = get_model_by_id(model_id, user_id)
|
||||
self.context['model_name'] = model.name
|
||||
default_params = get_model_default_params(model)
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_value = embedding_model.embed_query(exec_problem_text)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, None, exclude_document_id_list,
|
||||
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
|
||||
if embedding_list is None:
|
||||
return []
|
||||
|
|
@ -86,12 +78,11 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
|
|||
.add_paragraph(paragraph)
|
||||
.add_similarity(find_embedding.get('similarity'))
|
||||
.add_comprehensive_score(find_embedding.get('comprehensive_score'))
|
||||
.add_knowledge_name(paragraph.get('knowledge_name'))
|
||||
.add_knowledge_type(paragraph.get('knowledge_type'))
|
||||
.add_dataset_name(paragraph.get('dataset_name'))
|
||||
.add_document_name(paragraph.get('document_name'))
|
||||
.add_hit_handling_method(paragraph.get('hit_handling_method'))
|
||||
.add_directly_return_similarity(paragraph.get('directly_return_similarity'))
|
||||
.add_meta(reset_meta(paragraph.get('meta')))
|
||||
.add_meta(paragraph.get('meta'))
|
||||
.build())
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -111,7 +102,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
|
|||
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
|
||||
get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
|
||||
'list_knowledge_paragraph_by_paragraph_id.sql')),
|
||||
'list_dataset_paragraph_by_paragraph_id.sql')),
|
||||
with_table_name=True)
|
||||
# 如果向量库中存在脏数据 直接删除
|
||||
if len(paragraph_list) != len(paragraph_id_list):
|
||||
|
|
|
|||
|
|
@ -7,22 +7,6 @@
|
|||
@desc:
|
||||
"""
|
||||
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework.exceptions import ErrorDetail, ValidationError
|
||||
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.utils.common import group_by
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model_credential
|
||||
from tools.models.tool import Tool
|
||||
|
||||
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
|
||||
'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node',
|
||||
'variable-assign-node']
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
|
||||
|
|
@ -58,207 +42,3 @@ class NodeChunk:
|
|||
|
||||
def is_end(self):
|
||||
return self.status == 200
|
||||
|
||||
|
||||
class Edge:
|
||||
def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
|
||||
self.id = _id
|
||||
self.type = _type
|
||||
self.sourceNodeId = sourceNodeId
|
||||
self.targetNodeId = targetNodeId
|
||||
for keyword in keywords:
|
||||
self.__setattr__(keyword, keywords.get(keyword))
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
|
||||
self.id = _id
|
||||
self.type = _type
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.properties = properties
|
||||
for keyword in kwargs:
|
||||
self.__setattr__(keyword, kwargs.get(keyword))
|
||||
|
||||
|
||||
class EdgeNode:
|
||||
edge: Edge
|
||||
node: Node
|
||||
|
||||
def __init__(self, edge, node):
|
||||
self.edge = edge
|
||||
self.node = node
|
||||
|
||||
|
||||
class Workflow:
|
||||
"""
|
||||
节点列表
|
||||
"""
|
||||
nodes: List[Node]
|
||||
"""
|
||||
线列表
|
||||
"""
|
||||
edges: List[Edge]
|
||||
"""
|
||||
节点id:node
|
||||
"""
|
||||
node_map: Dict[str, Node]
|
||||
"""
|
||||
节点id:当前节点id上面的所有节点
|
||||
"""
|
||||
up_node_map: Dict[str, List[EdgeNode]]
|
||||
"""
|
||||
节点id:当前节点id下面的所有节点
|
||||
"""
|
||||
next_node_map: Dict[str, List[EdgeNode]]
|
||||
|
||||
def __init__(self, nodes: List[Node], edges: List[Edge]):
|
||||
self.nodes = nodes
|
||||
self.edges = edges
|
||||
self.node_map = {node.id: node for node in nodes}
|
||||
|
||||
self.up_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.sourceNodeId)) for
|
||||
edge in edges] for
|
||||
key, edges in
|
||||
group_by(edges, key=lambda edge: edge.targetNodeId).items()}
|
||||
|
||||
self.next_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.targetNodeId)) for edge in edges] for
|
||||
key, edges in
|
||||
group_by(edges, key=lambda edge: edge.sourceNodeId).items()}
|
||||
|
||||
def get_node(self, node_id):
|
||||
"""
|
||||
根据node_id 获取节点信息
|
||||
@param node_id: node_id
|
||||
@return: 节点信息
|
||||
"""
|
||||
return self.node_map.get(node_id)
|
||||
|
||||
def get_up_edge_nodes(self, node_id) -> List[EdgeNode]:
|
||||
"""
|
||||
根据节点id 获取当前连接前置节点和连线
|
||||
@param node_id: 节点id
|
||||
@return: 节点连线列表
|
||||
"""
|
||||
return self.up_node_map.get(node_id)
|
||||
|
||||
def get_next_edge_nodes(self, node_id) -> List[EdgeNode]:
|
||||
"""
|
||||
根据节点id 获取当前连接目标节点和连线
|
||||
@param node_id: 节点id
|
||||
@return: 节点连线列表
|
||||
"""
|
||||
return self.next_node_map.get(node_id)
|
||||
|
||||
def get_up_nodes(self, node_id) -> List[Node]:
|
||||
"""
|
||||
根据节点id 获取当前连接前置节点
|
||||
@param node_id: 节点id
|
||||
@return: 节点列表
|
||||
"""
|
||||
return [en.node for en in self.up_node_map.get(node_id)]
|
||||
|
||||
def get_next_nodes(self, node_id) -> List[Node]:
|
||||
"""
|
||||
根据节点id 获取当前连接目标节点
|
||||
@param node_id: 节点id
|
||||
@return: 节点列表
|
||||
"""
|
||||
return [en.node for en in self.next_node_map.get(node_id, [])]
|
||||
|
||||
@staticmethod
|
||||
def new_instance(flow_obj: Dict):
|
||||
nodes = flow_obj.get('nodes')
|
||||
edges = flow_obj.get('edges')
|
||||
nodes = [Node(node.get('id'), node.get('type'), **node)
|
||||
for node in nodes]
|
||||
edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges]
|
||||
return Workflow(nodes, edges)
|
||||
|
||||
def get_start_node(self):
|
||||
return self.get_node('start-node')
|
||||
|
||||
def get_search_node(self):
|
||||
return [node for node in self.nodes if node.type == 'search-dataset-node']
|
||||
|
||||
def is_valid(self):
|
||||
"""
|
||||
校验工作流数据
|
||||
"""
|
||||
self.is_valid_model_params()
|
||||
self.is_valid_start_node()
|
||||
self.is_valid_base_node()
|
||||
self.is_valid_work_flow()
|
||||
|
||||
@staticmethod
|
||||
def is_valid_node_params(node: Node):
|
||||
from application.flow.step_node import get_node
|
||||
get_node(node.type)(node, None, None)
|
||||
|
||||
def is_valid_node(self, node: Node):
|
||||
self.is_valid_node_params(node)
|
||||
if node.type == 'condition-node':
|
||||
branch_list = node.properties.get('node_data').get('branch')
|
||||
for branch in branch_list:
|
||||
source_anchor_id = f"{node.id}_{branch.get('id')}_right"
|
||||
edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id]
|
||||
if len(edge_list) == 0:
|
||||
raise AppApiException(500,
|
||||
_('The branch {branch} of the {node} node needs to be connected').format(
|
||||
node=node.properties.get("stepName"), branch=branch.get("type")))
|
||||
|
||||
else:
|
||||
edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
|
||||
if len(edge_list) == 0 and not end_nodes.__contains__(node.type):
|
||||
raise AppApiException(500, _("{node} Nodes cannot be considered as end nodes").format(
|
||||
node=node.properties.get("stepName")))
|
||||
|
||||
def is_valid_work_flow(self, up_node=None):
|
||||
if up_node is None:
|
||||
up_node = self.get_start_node()
|
||||
self.is_valid_node(up_node)
|
||||
next_nodes = self.get_next_nodes(up_node)
|
||||
for next_node in next_nodes:
|
||||
self.is_valid_work_flow(next_node)
|
||||
|
||||
def is_valid_start_node(self):
|
||||
start_node_list = [node for node in self.nodes if node.id == 'start-node']
|
||||
if len(start_node_list) == 0:
|
||||
raise AppApiException(500, _('The starting node is required'))
|
||||
if len(start_node_list) > 1:
|
||||
raise AppApiException(500, _('There can only be one starting node'))
|
||||
|
||||
def is_valid_model_params(self):
|
||||
node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')]
|
||||
for node in node_list:
|
||||
model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first()
|
||||
if model is None:
|
||||
raise ValidationError(ErrorDetail(
|
||||
_('The node {node} model does not exist').format(node=node.properties.get("stepName"))))
|
||||
credential = get_model_credential(model.provider, model.model_type, model.model_name)
|
||||
model_params_setting = node.properties.get('node_data', {}).get('model_params_setting')
|
||||
model_params_setting_form = credential.get_model_params_setting_form(
|
||||
model.model_name)
|
||||
if model_params_setting is None:
|
||||
model_params_setting = model_params_setting_form.get_default_form_data()
|
||||
node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
|
||||
if node.properties.get('status', 200) != 200:
|
||||
raise ValidationError(
|
||||
ErrorDetail(_("Node {node} is unavailable").format(node.properties.get("stepName"))))
|
||||
node_list = [node for node in self.nodes if (node.type == 'function-lib-node')]
|
||||
for node in node_list:
|
||||
function_lib_id = node.properties.get('node_data', {}).get('function_lib_id')
|
||||
if function_lib_id is None:
|
||||
raise ValidationError(ErrorDetail(
|
||||
_('The library ID of node {node} cannot be empty').format(node=node.properties.get("stepName"))))
|
||||
f_lib = QuerySet(Tool).filter(id=function_lib_id).first()
|
||||
if f_lib is None:
|
||||
raise ValidationError(ErrorDetail(_("The function library for node {node} is not available").format(
|
||||
node=node.properties.get("stepName"))))
|
||||
|
||||
def is_valid_base_node(self):
|
||||
base_node_list = [node for node in self.nodes if node.id == 'base-node']
|
||||
if len(base_node_list) == 0:
|
||||
raise AppApiException(500, _('Basic information node is required'))
|
||||
if len(base_node_list) > 1:
|
||||
raise AppApiException(500, _('There can only be one basic information node'))
|
||||
|
|
|
|||
|
|
@ -1,22 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: start_with.py
|
||||
@date:2025/10/20 10:37
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
|
||||
|
||||
class EndWithCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'end_with':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
source_value = str(source_value)
|
||||
return source_value.endswith(str(target_value))
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: start_with.py
|
||||
@date:2025/10/20 10:37
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
|
||||
|
||||
class StartWithCompare(Compare):
|
||||
|
||||
def support(self, node_id, fields: List[str], source_value, compare, target_value):
|
||||
if compare == 'start_with':
|
||||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
source_value = str(source_value)
|
||||
return source_value.startswith(str(target_value))
|
||||
|
|
@ -18,11 +18,13 @@ from rest_framework import serializers
|
|||
from rest_framework.exceptions import ValidationError, ErrorDetail
|
||||
|
||||
from application.flow.common import Answer, NodeChunk
|
||||
from application.models import ApplicationChatUserStats
|
||||
from application.models import ChatRecord, ChatUserType
|
||||
from application.models import ChatRecord
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
chat_cache = cache
|
||||
chat_cache = cache.caches['chat_cache']
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
|
|
@ -44,14 +46,16 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict):
|
|||
|
||||
|
||||
class WorkFlowPostHandler:
|
||||
def __init__(self, chat_info):
|
||||
def __init__(self, chat_info, client_id, client_type):
|
||||
self.chat_info = chat_info
|
||||
self.client_id = client_id
|
||||
self.client_type = client_type
|
||||
|
||||
def handler(self, workflow):
|
||||
workflow_body = workflow.get_body()
|
||||
question = workflow_body.get('question')
|
||||
chat_record_id = workflow_body.get('chat_record_id')
|
||||
chat_id = workflow_body.get('chat_id')
|
||||
def handler(self, chat_id,
|
||||
chat_record_id,
|
||||
answer,
|
||||
workflow):
|
||||
question = workflow.params['question']
|
||||
details = workflow.get_runtime_details()
|
||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
||||
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||
|
|
@ -80,16 +84,15 @@ class WorkFlowPostHandler:
|
|||
answer_text_list=answer_text_list,
|
||||
run_time=time.time() - workflow.context['start_time'],
|
||||
index=0)
|
||||
|
||||
self.chat_info.append_chat_record(chat_record)
|
||||
self.chat_info.set_cache()
|
||||
|
||||
if not self.chat_info.debug and [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
|
||||
workflow_body.get('chat_user_type')):
|
||||
application_public_access_client = (QuerySet(ApplicationChatUserStats)
|
||||
.filter(chat_user_id=workflow_body.get('chat_user_id'),
|
||||
chat_user_type=workflow_body.get('chat_user_type'),
|
||||
application_id=self.chat_info.application_id).first())
|
||||
asker = workflow.context.get('asker', None)
|
||||
self.chat_info.append_chat_record(chat_record, self.client_id, asker)
|
||||
# 重新设置缓存
|
||||
chat_cache.set(chat_id,
|
||||
self.chat_info, timeout=60 * 30)
|
||||
if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
||||
application_public_access_client = (QuerySet(ApplicationPublicAccessClient)
|
||||
.filter(client_id=self.client_id,
|
||||
application_id=self.chat_info.application.id).first())
|
||||
if application_public_access_client is not None:
|
||||
application_public_access_client.access_num = application_public_access_client.access_num + 1
|
||||
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
|
||||
|
|
@ -120,36 +123,31 @@ class NodeResult:
|
|||
|
||||
|
||||
class ReferenceAddressSerializer(serializers.Serializer):
|
||||
node_id = serializers.CharField(required=True, label="节点id")
|
||||
node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id"))
|
||||
fields = serializers.ListField(
|
||||
child=serializers.CharField(required=True, label="节点字段"), required=True,
|
||||
label="节点字段数组")
|
||||
child=serializers.CharField(required=True, error_messages=ErrMessage.char("节点字段")), required=True,
|
||||
error_messages=ErrMessage.list("节点字段数组"))
|
||||
|
||||
|
||||
class FlowParamsSerializer(serializers.Serializer):
|
||||
# 历史对答
|
||||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
|
||||
label="历史对答")
|
||||
error_messages=ErrMessage.list("历史对答"))
|
||||
|
||||
question = serializers.CharField(required=True, label="用户问题")
|
||||
question = serializers.CharField(required=True, error_messages=ErrMessage.list("用户问题"))
|
||||
|
||||
chat_id = serializers.CharField(required=True, label="对话id")
|
||||
chat_id = serializers.CharField(required=True, error_messages=ErrMessage.list("对话id"))
|
||||
|
||||
chat_record_id = serializers.CharField(required=True, label="对话记录id")
|
||||
chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id"))
|
||||
|
||||
stream = serializers.BooleanField(required=True, label="流式输出")
|
||||
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("流式输出"))
|
||||
|
||||
chat_user_id = serializers.CharField(required=False, label="对话用户id")
|
||||
client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端id"))
|
||||
|
||||
chat_user_type = serializers.CharField(required=False, label="对话用户类型")
|
||||
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
|
||||
|
||||
workspace_id = serializers.CharField(required=True, label="工作空间id")
|
||||
|
||||
application_id = serializers.CharField(required=True, label="应用id")
|
||||
|
||||
re_chat = serializers.BooleanField(required=True, label="换个答案")
|
||||
|
||||
debug = serializers.BooleanField(required=True, label="是否debug")
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))
|
||||
|
||||
|
||||
class INode:
|
||||
|
|
@ -168,7 +166,7 @@ class INode:
|
|||
self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]
|
||||
|
||||
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
|
||||
get_node_params=lambda node: node.properties.get('node_data'), salt=None):
|
||||
get_node_params=lambda node: node.properties.get('node_data')):
|
||||
# 当前步骤上下文,用于存储当前步骤信息
|
||||
self.status = 200
|
||||
self.err_message = ''
|
||||
|
|
@ -188,8 +186,7 @@ class INode:
|
|||
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
|
||||
"".join([*sorted(up_node_id_list),
|
||||
node.id]))),
|
||||
"utf-8")).hexdigest() + (
|
||||
"__" + str(salt) if salt is not None else '')
|
||||
"utf-8")).hexdigest()
|
||||
|
||||
def valid_args(self, node_params, flow_params):
|
||||
flow_params_serializer_class = self.get_flow_params_serializer_class()
|
||||
|
|
|
|||
|
|
@ -1,193 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: workflow_manage.py
|
||||
@date:2024/1/9 17:40
|
||||
@desc:
|
||||
"""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
|
||||
from django.db import close_old_connections
|
||||
from django.utils.translation import get_language
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from application.flow.common import Workflow
|
||||
from application.flow.i_step_node import WorkFlowPostHandler, INode
|
||||
from application.flow.step_node import get_node
|
||||
from application.flow.workflow_manage import WorkflowManage
|
||||
from common.handle.base_to_response import BaseToResponse
|
||||
from common.handle.impl.response.system_to_response import SystemToResponse
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=200)
|
||||
|
||||
|
||||
class NodeResultFuture:
|
||||
def __init__(self, r, e, status=200):
|
||||
self.r = r
|
||||
self.e = e
|
||||
self.status = status
|
||||
|
||||
def result(self):
|
||||
if self.status == 200:
|
||||
return self.r
|
||||
else:
|
||||
raise self.e
|
||||
|
||||
|
||||
def await_result(result, timeout=1):
|
||||
try:
|
||||
result.result(timeout)
|
||||
return False
|
||||
except Exception as e:
|
||||
return True
|
||||
|
||||
|
||||
class NodeChunkManage:
|
||||
|
||||
def __init__(self, work_flow):
|
||||
self.node_chunk_list = []
|
||||
self.current_node_chunk = None
|
||||
self.work_flow = work_flow
|
||||
|
||||
def add_node_chunk(self, node_chunk):
|
||||
self.node_chunk_list.append(node_chunk)
|
||||
|
||||
def contains(self, node_chunk):
|
||||
return self.node_chunk_list.__contains__(node_chunk)
|
||||
|
||||
def pop(self):
|
||||
if self.current_node_chunk is None:
|
||||
try:
|
||||
current_node_chunk = self.node_chunk_list.pop(0)
|
||||
self.current_node_chunk = current_node_chunk
|
||||
except IndexError as e:
|
||||
pass
|
||||
if self.current_node_chunk is not None:
|
||||
try:
|
||||
chunk = self.current_node_chunk.chunk_list.pop(0)
|
||||
return chunk
|
||||
except IndexError as e:
|
||||
if self.current_node_chunk.is_end():
|
||||
self.current_node_chunk = None
|
||||
if self.work_flow.answer_is_not_empty():
|
||||
chunk = self.work_flow.base_to_response.to_stream_chunk_response(
|
||||
self.work_flow.params['chat_id'],
|
||||
self.work_flow.params['chat_record_id'],
|
||||
'\n\n', False, 0, 0)
|
||||
self.work_flow.append_answer('\n\n')
|
||||
return chunk
|
||||
return self.pop()
|
||||
return None
|
||||
|
||||
|
||||
class LoopWorkflowManage(WorkflowManage):
|
||||
|
||||
def __init__(self, flow: Workflow,
|
||||
params,
|
||||
work_flow_post_handler: WorkFlowPostHandler,
|
||||
parentWorkflowManage,
|
||||
loop_params,
|
||||
get_loop_context,
|
||||
base_to_response: BaseToResponse = SystemToResponse(),
|
||||
start_node_id=None,
|
||||
start_node_data=None, chat_record=None, child_node=None):
|
||||
self.parentWorkflowManage = parentWorkflowManage
|
||||
self.loop_params = loop_params
|
||||
self.get_loop_context = get_loop_context
|
||||
self.loop_field_list = []
|
||||
super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
|
||||
None,
|
||||
None, None, start_node_id, start_node_data, chat_record, child_node)
|
||||
|
||||
def get_node_cls_by_id(self, node_id, up_node_id_list=None,
|
||||
get_node_params=lambda node: node.properties.get('node_data')):
|
||||
for node in self.flow.nodes:
|
||||
if node.id == node_id:
|
||||
node_instance = get_node(node.type)(node,
|
||||
self.params, self, up_node_id_list,
|
||||
get_node_params,
|
||||
salt=self.get_index())
|
||||
return node_instance
|
||||
return None
|
||||
|
||||
def stream(self):
|
||||
close_old_connections()
|
||||
language = get_language()
|
||||
self.run_chain_async(self.start_node, None, language)
|
||||
return self.await_result()
|
||||
|
||||
def get_index(self):
|
||||
return self.loop_params.get('index')
|
||||
|
||||
def get_start_node(self):
|
||||
start_node_list = [node for node in self.flow.nodes if
|
||||
['loop-start-node'].__contains__(node.type)]
|
||||
return start_node_list[0]
|
||||
|
||||
def get_reference_field(self, node_id: str, fields: List[str]):
|
||||
"""
|
||||
@param node_id: 节点id
|
||||
@param fields: 字段
|
||||
@return:
|
||||
"""
|
||||
if node_id == 'global':
|
||||
return self.parentWorkflowManage.get_reference_field(node_id, fields)
|
||||
elif node_id == 'chat':
|
||||
return self.parentWorkflowManage.get_reference_field(node_id, fields)
|
||||
elif node_id == 'loop':
|
||||
loop_context = self.get_loop_context()
|
||||
return INode.get_field(loop_context, fields)
|
||||
else:
|
||||
node = self.get_node_by_id(node_id)
|
||||
if node:
|
||||
return node.get_reference_field(fields)
|
||||
return self.parentWorkflowManage.get_reference_field(node_id, fields)
|
||||
|
||||
def get_workflow_content(self):
|
||||
context = {
|
||||
'global': self.context,
|
||||
'chat': self.chat_context,
|
||||
'loop': self.get_loop_context(),
|
||||
}
|
||||
|
||||
for node in self.node_context:
|
||||
context[node.id] = node.context
|
||||
return context
|
||||
|
||||
def init_fields(self):
|
||||
super().init_fields()
|
||||
loop_field_list = []
|
||||
loop_start_node = self.flow.get_node('loop-start-node')
|
||||
loop_input_field_list = loop_start_node.properties.get('loop_input_field_list')
|
||||
node_name = loop_start_node.properties.get('stepName')
|
||||
node_id = loop_start_node.id
|
||||
if loop_input_field_list is not None:
|
||||
for f in loop_input_field_list:
|
||||
loop_field_list.append(
|
||||
{'label': f.get('label'), 'value': f.get('field'), 'node_id': node_id, 'node_name': node_name})
|
||||
self.loop_field_list = loop_field_list
|
||||
|
||||
def reset_prompt(self, prompt: str):
|
||||
prompt = super().reset_prompt(prompt)
|
||||
for field in self.loop_field_list:
|
||||
chatLabel = f"loop.{field.get('value')}"
|
||||
chatValue = f"context.get('loop').get('{field.get('value', '')}','')"
|
||||
prompt = prompt.replace(chatLabel, chatValue)
|
||||
|
||||
prompt = self.parentWorkflowManage.reset_prompt(prompt)
|
||||
return prompt
|
||||
|
||||
def generate_prompt(self, prompt: str):
|
||||
"""
|
||||
格式化生成提示词
|
||||
@param prompt: 提示词信息
|
||||
@return: 格式化后的提示词
|
||||
"""
|
||||
|
||||
context = {**self.get_workflow_content(), **self.parentWorkflowManage.get_workflow_content()}
|
||||
prompt = self.reset_prompt(prompt)
|
||||
prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
|
||||
value = prompt_template.format(context=context)
|
||||
return value
|
||||
|
|
@ -10,43 +10,29 @@ from .ai_chat_step_node import *
|
|||
from .application_node import BaseApplicationNode
|
||||
from .condition_node import *
|
||||
from .direct_reply_node import *
|
||||
from .document_extract_node import *
|
||||
from .form_node import *
|
||||
from .image_generate_step_node import *
|
||||
from .image_to_video_step_node import BaseImageToVideoNode
|
||||
from .image_understand_step_node import *
|
||||
from .intent_node import *
|
||||
from .loop_break_node import BaseLoopBreakNode
|
||||
from .loop_continue_node import BaseLoopContinueNode
|
||||
from .loop_node import *
|
||||
from .loop_start_node import *
|
||||
from .mcp_node import BaseMcpNode
|
||||
from .parameter_extraction_node import BaseParameterExtractionNode
|
||||
from .function_lib_node import *
|
||||
from .function_node import *
|
||||
from .question_node import *
|
||||
from .reranker_node import *
|
||||
from .search_document_node import BaseSearchDocumentNode
|
||||
from .search_knowledge_node import *
|
||||
|
||||
from .document_extract_node import *
|
||||
from .image_understand_step_node import *
|
||||
from .image_generate_step_node import *
|
||||
|
||||
from .search_dataset_node import *
|
||||
from .speech_to_text_step_node import BaseSpeechToTextNode
|
||||
from .start_node import *
|
||||
from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode
|
||||
from .text_to_video_step_node.impl.base_text_to_video_node import BaseTextToVideoNode
|
||||
from .tool_lib_node import *
|
||||
from .tool_node import *
|
||||
from .variable_aggregation_node.impl.base_variable_aggregation_node import BaseVariableAggregationNode
|
||||
from .variable_assign_node import BaseVariableAssignNode
|
||||
from .variable_splitting_node import BaseVariableSplittingNode
|
||||
from .video_understand_step_node import BaseVideoUnderstandNode
|
||||
from .mcp_node import BaseMcpNode
|
||||
|
||||
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseSearchDocumentNode, BaseQuestionNode,
|
||||
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode,
|
||||
BaseConditionNode, BaseReplyNode,
|
||||
BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
|
||||
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
|
||||
BaseDocumentExtractNode,
|
||||
BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
|
||||
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode, BaseTextToVideoNode, BaseImageToVideoNode,
|
||||
BaseVideoUnderstandNode,
|
||||
BaseIntentNode, BaseLoopNode, BaseLoopStartStepNode,
|
||||
BaseLoopContinueNode,
|
||||
BaseLoopBreakNode, BaseVariableSplittingNode, BaseParameterExtractionNode, BaseVariableAggregationNode]
|
||||
BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode]
|
||||
|
||||
|
||||
def get_node(node_type):
|
||||
|
|
|
|||
|
|
@ -12,35 +12,31 @@ from django.utils.translation import gettext_lazy as _
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ChatNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, label=_("Model id"))
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
|
||||
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
label=_("Role Setting"))
|
||||
prompt = serializers.CharField(required=True, label=_("Prompt word"))
|
||||
error_messages=ErrMessage.char(_("Role Setting")))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, label=_("Number of multi-round conversations"))
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
|
||||
_("Number of multi-round conversations")))
|
||||
|
||||
is_result = serializers.BooleanField(required=False,
|
||||
label=_('Whether to return content'))
|
||||
error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
model_params_setting = serializers.DictField(required=False,
|
||||
label=_("Model parameter settings"))
|
||||
error_messages=ErrMessage.dict(_("Model parameter settings")))
|
||||
model_setting = serializers.DictField(required=False,
|
||||
label='Model settings')
|
||||
error_messages=ErrMessage.dict('Model settings'))
|
||||
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
label=_("Context Type"))
|
||||
mcp_enable = serializers.BooleanField(required=False, label=_("Whether to enable MCP"))
|
||||
mcp_servers = serializers.JSONField(required=False, label=_("MCP Server"))
|
||||
mcp_tool_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Tool ID"))
|
||||
mcp_tool_ids = serializers.ListField(child=serializers.UUIDField(), required=False, allow_empty=True, label=_("MCP Tool IDs"), )
|
||||
mcp_source = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Source"))
|
||||
error_messages=ErrMessage.char(_("Context Type")))
|
||||
mcp_enable = serializers.BooleanField(required=False,
|
||||
error_messages=ErrMessage.boolean(_("Whether to enable MCP")))
|
||||
mcp_servers = serializers.JSONField(required=False, error_messages=ErrMessage.list(_("MCP Server")))
|
||||
|
||||
tool_enable = serializers.BooleanField(required=False, default=False, label=_("Whether to enable tools"))
|
||||
tool_ids = serializers.ListField(child=serializers.UUIDField(), required=False, allow_empty=True,
|
||||
label=_("Tool IDs"), )
|
||||
mcp_output_enable = serializers.BooleanField(required=False, default=True, label=_("Whether to enable MCP output"))
|
||||
|
||||
class IChatNode(INode):
|
||||
type = 'ai-chat-node'
|
||||
|
|
@ -58,11 +54,5 @@ class IChatNode(INode):
|
|||
model_setting=None,
|
||||
mcp_enable=False,
|
||||
mcp_servers=None,
|
||||
mcp_tool_id=None,
|
||||
mcp_tool_ids=None,
|
||||
mcp_source=None,
|
||||
tool_enable=False,
|
||||
tool_ids=None,
|
||||
mcp_output_enable=True,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -6,26 +6,39 @@
|
|||
@date:2024/6/4 14:30
|
||||
@desc:
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from functools import reduce
|
||||
from types import AsyncGeneratorType
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage, AIMessage
|
||||
from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
||||
from application.flow.tools import Reasoning, mcp_response_generator
|
||||
from common.utils.rsa_util import rsa_long_decrypt
|
||||
from common.utils.tool_code import ToolExecutor
|
||||
from maxkb.const import CONFIG
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
|
||||
from tools.models import Tool
|
||||
from application.flow.tools import Reasoning
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model_credential
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
tool_message_template = """
|
||||
<details>
|
||||
<summary>
|
||||
<strong>Called MCP Tool: <em>%s</em></strong>
|
||||
</summary>
|
||||
|
||||
```json
|
||||
%s
|
||||
```
|
||||
</details>
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
|
||||
|
|
@ -90,6 +103,39 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
|
|||
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
|
||||
|
||||
|
||||
async def _yield_mcp_response(chat_model, message_list, mcp_servers):
|
||||
async with MultiServerMCPClient(json.loads(mcp_servers)) as client:
|
||||
agent = create_react_agent(chat_model, client.get_tools())
|
||||
response = agent.astream({"messages": message_list}, stream_mode='messages')
|
||||
async for chunk in response:
|
||||
if isinstance(chunk[0], ToolMessage):
|
||||
content = tool_message_template % (chunk[0].name, chunk[0].content)
|
||||
chunk[0].content = content
|
||||
yield chunk[0]
|
||||
if isinstance(chunk[0], AIMessageChunk):
|
||||
yield chunk[0]
|
||||
|
||||
|
||||
def mcp_response_generator(chat_model, message_list, mcp_servers):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
|
||||
while True:
|
||||
try:
|
||||
chunk = loop.run_until_complete(anext_async(async_gen))
|
||||
yield chunk
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f'exception: {e}')
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def anext_async(agen):
|
||||
return await agen.__anext__()
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据
|
||||
|
|
@ -106,9 +152,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
|
|||
reasoning_result = reasoning.get_reasoning_content(response)
|
||||
reasoning_result_end = reasoning.get_end_reasoning_content()
|
||||
content = reasoning_result.get('content') + reasoning_result_end.get('content')
|
||||
meta = {**response.response_metadata, **response.additional_kwargs}
|
||||
if 'reasoning_content' in meta:
|
||||
reasoning_content = meta.get('reasoning_content', '')
|
||||
if 'reasoning_content' in response.response_metadata:
|
||||
reasoning_content = response.response_metadata.get('reasoning_content', '')
|
||||
else:
|
||||
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
|
||||
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
|
||||
|
|
@ -152,12 +197,6 @@ class BaseChatNode(IChatNode):
|
|||
model_setting=None,
|
||||
mcp_enable=False,
|
||||
mcp_servers=None,
|
||||
mcp_tool_id=None,
|
||||
mcp_tool_ids=None,
|
||||
mcp_source=None,
|
||||
tool_enable=False,
|
||||
tool_ids=None,
|
||||
mcp_output_enable=True,
|
||||
**kwargs) -> NodeResult:
|
||||
if dialogue_type is None:
|
||||
dialogue_type = 'WORKFLOW'
|
||||
|
|
@ -168,9 +207,8 @@ class BaseChatNode(IChatNode):
|
|||
model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
|
||||
'reasoning_content_start': '<think>'}
|
||||
self.context['model_setting'] = model_setting
|
||||
workspace_id = self.workflow_manage.get_body().get('workspace_id')
|
||||
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
|
||||
**model_params_setting)
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
|
||||
**model_params_setting)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
|
||||
self.runtime_node_id)
|
||||
self.context['history_message'] = history_message
|
||||
|
|
@ -181,100 +219,24 @@ class BaseChatNode(IChatNode):
|
|||
message_list = self.generate_message_list(system, prompt, history_message)
|
||||
self.context['message_list'] = message_list
|
||||
|
||||
# 处理 MCP 请求
|
||||
mcp_result = self._handle_mcp_request(
|
||||
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids, mcp_output_enable,
|
||||
chat_model, message_list, history_message, question
|
||||
)
|
||||
if mcp_result:
|
||||
return mcp_result
|
||||
if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers:
|
||||
r = mcp_response_generator(chat_model, message_list, mcp_servers)
|
||||
return NodeResult(
|
||||
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
|
||||
if stream:
|
||||
r = chat_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(history_message if history_message is not None else [])],
|
||||
'question': question.content}, {},
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
else:
|
||||
r = chat_model.invoke(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(history_message if history_message is not None else [])],
|
||||
'question': question.content}, {},
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context)
|
||||
|
||||
def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids,
|
||||
mcp_output_enable, chat_model, message_list, history_message, question):
|
||||
if not mcp_enable and not tool_enable:
|
||||
return None
|
||||
|
||||
mcp_servers_config = {}
|
||||
|
||||
# 迁移过来mcp_source是None
|
||||
if mcp_source is None:
|
||||
mcp_source = 'custom'
|
||||
if mcp_enable:
|
||||
# 兼容老数据
|
||||
if not mcp_tool_ids:
|
||||
mcp_tool_ids = []
|
||||
if mcp_tool_id:
|
||||
mcp_tool_ids = list(set(mcp_tool_ids + [mcp_tool_id]))
|
||||
if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
|
||||
mcp_servers_config = json.loads(mcp_servers)
|
||||
mcp_servers_config = self.handle_variables(mcp_servers_config)
|
||||
elif mcp_tool_ids:
|
||||
mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values()
|
||||
for mcp_tool in mcp_tools:
|
||||
if mcp_tool and mcp_tool['is_active']:
|
||||
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
|
||||
mcp_servers_config = self.handle_variables(mcp_servers_config)
|
||||
|
||||
if tool_enable:
|
||||
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
|
||||
self.context['tool_ids'] = tool_ids
|
||||
self.context['execute_ids'] = []
|
||||
for tool_id in tool_ids:
|
||||
tool = QuerySet(Tool).filter(id=tool_id).first()
|
||||
if not tool.is_active:
|
||||
continue
|
||||
executor = ToolExecutor(CONFIG.get('SANDBOX'))
|
||||
if tool.init_params is not None:
|
||||
params = json.loads(rsa_long_decrypt(tool.init_params))
|
||||
else:
|
||||
params = {}
|
||||
_id, tool_config = executor.get_tool_mcp_config(tool.code, params)
|
||||
|
||||
self.context['execute_ids'].append(_id)
|
||||
mcp_servers_config[str(tool.id)] = tool_config
|
||||
|
||||
if len(mcp_servers_config) > 0:
|
||||
r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable)
|
||||
return NodeResult(
|
||||
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(history_message if history_message is not None else [])],
|
||||
'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
|
||||
return None
|
||||
|
||||
def handle_variables(self, tool_params):
|
||||
# 处理参数中的变量
|
||||
for k, v in tool_params.items():
|
||||
if type(v) == str:
|
||||
tool_params[k] = self.workflow_manage.generate_prompt(tool_params[k])
|
||||
if type(v) == dict:
|
||||
self.handle_variables(v)
|
||||
if (type(v) == list) and (type(v[0]) == str):
|
||||
tool_params[k] = self.get_reference_content(v)
|
||||
return tool_params
|
||||
|
||||
def get_reference_content(self, fields: List[str]):
|
||||
return str(self.workflow_manage.get_reference_field(
|
||||
fields[0],
|
||||
fields[1:]))
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
|
|
@ -307,20 +269,14 @@ class BaseChatNode(IChatNode):
|
|||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
# 删除临时生成的MCP代码文件
|
||||
if self.context.get('execute_ids'):
|
||||
executor = ToolExecutor(CONFIG.get('SANDBOX'))
|
||||
# 清理工具代码文件,延时删除,避免文件被占用
|
||||
for tool_id in self.context.get('execute_ids'):
|
||||
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
|
||||
if os.path.exists(code_path):
|
||||
os.remove(code_path)
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'system': self.context.get('system'),
|
||||
'history_message': self.context.get('history_message'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(self.context.get('history_message') if self.context.get(
|
||||
'history_message') is not None else [])],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'reasoning_content': self.context.get('reasoning_content'),
|
||||
|
|
|
|||
|
|
@ -4,23 +4,24 @@ from typing import Type
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class ApplicationNodeSerializer(serializers.Serializer):
|
||||
application_id = serializers.CharField(required=True, label=_("Application ID"))
|
||||
application_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Application ID")))
|
||||
question_reference_address = serializers.ListField(required=True,
|
||||
label=_("User Questions"))
|
||||
api_input_field_list = serializers.ListField(required=False, label=_("API Input Fields"))
|
||||
error_messages=ErrMessage.list(_("User Questions")))
|
||||
api_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("API Input Fields")))
|
||||
user_input_field_list = serializers.ListField(required=False,
|
||||
label=_("User Input Fields"))
|
||||
image_list = serializers.ListField(required=False, label=_("picture"))
|
||||
document_list = serializers.ListField(required=False, label=_("document"))
|
||||
audio_list = serializers.ListField(required=False, label=_("Audio"))
|
||||
error_messages=ErrMessage.uuid(_("User Input Fields")))
|
||||
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))
|
||||
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
|
||||
audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Audio")))
|
||||
child_node = serializers.DictField(required=False, allow_null=True,
|
||||
label=_("Child Nodes"))
|
||||
node_data = serializers.DictField(required=False, allow_null=True, label=_("Form Data"))
|
||||
error_messages=ErrMessage.dict(_("Child Nodes")))
|
||||
node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
|
||||
|
||||
|
||||
class IApplicationNode(INode):
|
||||
|
|
@ -74,7 +75,7 @@ class IApplicationNode(INode):
|
|||
if 'file_id' not in audio:
|
||||
raise ValueError(
|
||||
_("Parameter value error: The uploaded audio lacks file_id, and the audio upload fails."))
|
||||
return self.execute(**{**self.flow_params_serializer.data, **self.node_params_serializer.data},
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
|
||||
app_document_list=app_document_list, app_image_list=app_image_list,
|
||||
app_audio_list=app_audio_list,
|
||||
message=str(question), **kwargs)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import re
|
|||
import time
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from application.flow.common import Answer
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.application_node.i_application_node import IApplicationNode
|
||||
|
|
@ -55,7 +55,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
|
|||
# 先把流转成字符串
|
||||
response_content = chunk.decode('utf-8')[6:]
|
||||
response_content = json.loads(response_content)
|
||||
content = (response_content.get('content', '') or '')
|
||||
content = response_content.get('content', '')
|
||||
runtime_node_id = response_content.get('runtime_node_id', '')
|
||||
chat_record_id = response_content.get('chat_record_id', '')
|
||||
child_node = response_content.get('child_node')
|
||||
|
|
@ -63,7 +63,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
|
|||
node_type = response_content.get('node_type')
|
||||
real_node_id = response_content.get('real_node_id')
|
||||
node_is_end = response_content.get('node_is_end', False)
|
||||
_reasoning_content = (response_content.get('reasoning_content', '') or '')
|
||||
_reasoning_content = response_content.get('reasoning_content', '')
|
||||
if node_type == 'form-node':
|
||||
is_interrupt_exec = True
|
||||
answer += content
|
||||
|
|
@ -171,21 +171,16 @@ class BaseApplicationNode(IApplicationNode):
|
|||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat,
|
||||
chat_user_id,
|
||||
chat_user_type,
|
||||
def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
|
||||
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
|
||||
**kwargs) -> NodeResult:
|
||||
from chat.serializers.chat import ChatSerializers
|
||||
if application_id == self.workflow_manage.get_body().get('application_id'):
|
||||
raise Exception(_("The sub application cannot use the current node"))
|
||||
from application.serializers.chat_message_serializers import ChatMessageSerializer
|
||||
# 生成嵌入应用的chat_id
|
||||
current_chat_id = string_to_uuid(chat_id + application_id)
|
||||
Chat.objects.get_or_create(id=current_chat_id, defaults={
|
||||
'application_id': application_id,
|
||||
'abstract': message[0:1024],
|
||||
'chat_user_id': chat_user_id,
|
||||
'chat_user_type': chat_user_type
|
||||
'client_id': client_id,
|
||||
})
|
||||
if app_document_list is None:
|
||||
app_document_list = []
|
||||
|
|
@ -202,26 +197,22 @@ class BaseApplicationNode(IApplicationNode):
|
|||
child_node_value = child_node.get('child_node')
|
||||
application_node_dict = self.context.get('application_node_dict')
|
||||
reset_application_node_dict(application_node_dict, runtime_node_id, node_data)
|
||||
response = ChatSerializers(data={
|
||||
"chat_id": current_chat_id,
|
||||
"chat_user_id": chat_user_id,
|
||||
'chat_user_type': chat_user_type,
|
||||
'application_id': application_id,
|
||||
'debug': False
|
||||
}).chat(instance=
|
||||
{'message': message,
|
||||
're_chat': re_chat,
|
||||
'stream': stream,
|
||||
'document_list': app_document_list,
|
||||
'image_list': app_image_list,
|
||||
'audio_list': app_audio_list,
|
||||
'runtime_node_id': runtime_node_id,
|
||||
'chat_record_id': record_id,
|
||||
'child_node': child_node_value,
|
||||
'node_data': node_data,
|
||||
'form_data': kwargs}
|
||||
)
|
||||
|
||||
response = ChatMessageSerializer(
|
||||
data={'chat_id': current_chat_id, 'message': message,
|
||||
're_chat': re_chat,
|
||||
'stream': stream,
|
||||
'application_id': application_id,
|
||||
'client_id': client_id,
|
||||
'client_type': client_type,
|
||||
'document_list': app_document_list,
|
||||
'image_list': app_image_list,
|
||||
'audio_list': app_audio_list,
|
||||
'runtime_node_id': runtime_node_id,
|
||||
'chat_record_id': record_id,
|
||||
'child_node': child_node_value,
|
||||
'node_data': node_data,
|
||||
'form_data': kwargs}).chat()
|
||||
if response.status_code == 200:
|
||||
if stream:
|
||||
content_generator = response.streaming_content
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
"""
|
||||
|
||||
from .contain_compare import *
|
||||
from .end_with import EndWithCompare
|
||||
from .equal_compare import *
|
||||
from .ge_compare import *
|
||||
from .gt_compare import *
|
||||
|
|
@ -24,10 +23,8 @@ from .len_le_compare import *
|
|||
from .len_lt_compare import *
|
||||
from .lt_compare import *
|
||||
from .not_contain_compare import *
|
||||
from .start_with import StartWithCompare
|
||||
|
||||
compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
|
||||
LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare(),
|
||||
IsNullCompare(),
|
||||
IsNotNullCompare(), NotContainCompare(), IsTrueCompare(), IsNotTrueCompare(), StartWithCompare(),
|
||||
EndWithCompare()]
|
||||
IsNotNullCompare(), NotContainCompare(), IsTrueCompare(), IsNotTrueCompare()]
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class ContainCompare(Compare):
|
||||
|
|
@ -20,7 +20,4 @@ class ContainCompare(Compare):
|
|||
def compare(self, source_value, compare, target_value):
|
||||
if isinstance(source_value, str):
|
||||
return str(target_value) in source_value
|
||||
elif isinstance(source_value, list):
|
||||
return any([str(item) == str(target_value) for item in source_value])
|
||||
else:
|
||||
return str(target_value) in str(source_value)
|
||||
return any([str(item) == str(target_value) for item in source_value])
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class EqualCompare(Compare):
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class GECompare(Compare):
|
||||
|
|
@ -21,8 +21,4 @@ class GECompare(Compare):
|
|||
try:
|
||||
return float(source_value) >= float(target_value)
|
||||
except Exception as e:
|
||||
try:
|
||||
return str(source_value) >= str(target_value)
|
||||
except Exception as _:
|
||||
pass
|
||||
return False
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class GTCompare(Compare):
|
||||
|
|
@ -21,8 +21,4 @@ class GTCompare(Compare):
|
|||
try:
|
||||
return float(source_value) > float(target_value)
|
||||
except Exception as e:
|
||||
try:
|
||||
return str(source_value) > str(target_value)
|
||||
except Exception as _:
|
||||
pass
|
||||
return False
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare import Compare
|
||||
|
||||
|
||||
class IsNotNullCompare(Compare):
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare import Compare
|
||||
|
||||
|
||||
class IsNotTrueCompare(Compare):
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare import Compare
|
||||
|
||||
|
||||
class IsNullCompare(Compare):
|
||||
|
|
@ -18,7 +18,4 @@ class IsNullCompare(Compare):
|
|||
return True
|
||||
|
||||
def compare(self, source_value, compare, target_value):
|
||||
try:
|
||||
return source_value is None or len(source_value) == 0
|
||||
except Exception as e:
|
||||
return False
|
||||
return source_value is None or len(source_value) == 0
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare import Compare
|
||||
|
||||
|
||||
class IsTrueCompare(Compare):
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LECompare(Compare):
|
||||
|
|
@ -21,8 +21,4 @@ class LECompare(Compare):
|
|||
try:
|
||||
return float(source_value) <= float(target_value)
|
||||
except Exception as e:
|
||||
try:
|
||||
return str(source_value) <= str(target_value)
|
||||
except Exception as _:
|
||||
pass
|
||||
return False
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenEqualCompare(Compare):
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenGECompare(Compare):
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenGTCompare(Compare):
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenLECompare(Compare):
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LenLTCompare(Compare):
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class LTCompare(Compare):
|
||||
|
|
@ -21,8 +21,4 @@ class LTCompare(Compare):
|
|||
try:
|
||||
return float(source_value) < float(target_value)
|
||||
except Exception as e:
|
||||
try:
|
||||
return str(source_value) < str(target_value)
|
||||
except Exception as _:
|
||||
pass
|
||||
return False
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import Compare
|
||||
from application.flow.step_node.condition_node.compare.compare import Compare
|
||||
|
||||
|
||||
class NotContainCompare(Compare):
|
||||
|
|
@ -20,7 +20,4 @@ class NotContainCompare(Compare):
|
|||
def compare(self, source_value, compare, target_value):
|
||||
if isinstance(source_value, str):
|
||||
return str(target_value) not in source_value
|
||||
elif isinstance(self, list):
|
||||
return not any([str(item) == str(target_value) for item in source_value])
|
||||
else:
|
||||
return str(target_value) not in str(source_value)
|
||||
return not any([str(item) == str(target_value) for item in source_value])
|
||||
|
|
@ -12,18 +12,19 @@ from django.utils.translation import gettext_lazy as _
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ConditionSerializer(serializers.Serializer):
|
||||
compare = serializers.CharField(required=True, label=_("Comparator"))
|
||||
value = serializers.CharField(required=True, label=_("value"))
|
||||
field = serializers.ListField(required=True, label=_("Fields"))
|
||||
compare = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Comparator")))
|
||||
value = serializers.CharField(required=True, error_messages=ErrMessage.char(_("value")))
|
||||
field = serializers.ListField(required=True, error_messages=ErrMessage.char(_("Fields")))
|
||||
|
||||
|
||||
class ConditionBranchSerializer(serializers.Serializer):
|
||||
id = serializers.CharField(required=True, label=_("Branch id"))
|
||||
type = serializers.CharField(required=True, label=_("Branch Type"))
|
||||
condition = serializers.CharField(required=True, label=_("Condition or|and"))
|
||||
id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch id")))
|
||||
type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch Type")))
|
||||
condition = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Condition or|and")))
|
||||
conditions = ConditionSerializer(many=True)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
from typing import List
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.compare import compare_handle_list
|
||||
from application.flow.step_node.condition_node.compare import compare_handle_list
|
||||
from application.flow.step_node.condition_node.i_condition_node import IConditionNode
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,17 +12,16 @@ from rest_framework import serializers
|
|||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class ReplyNodeParamsSerializer(serializers.Serializer):
|
||||
reply_type = serializers.CharField(required=True, label=_("Response Type"))
|
||||
fields = serializers.ListField(required=False, label=_("Reference Field"))
|
||||
reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Response Type")))
|
||||
fields = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Reference Field")))
|
||||
content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
label=_("Direct answer content"))
|
||||
is_result = serializers.BooleanField(required=False,
|
||||
label=_('Whether to return content'))
|
||||
error_messages=ErrMessage.char(_("Direct answer content")))
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
|
|
|||
|
|
@ -6,10 +6,11 @@ from django.utils.translation import gettext_lazy as _
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class DocumentExtractNodeSerializer(serializers.Serializer):
|
||||
document_list = serializers.ListField(required=False, label=_("document"))
|
||||
document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
|
||||
|
||||
|
||||
class IDocumentExtractNode(INode):
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ from django.db.models import QuerySet
|
|||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
|
||||
from knowledge.models import File, FileSourceType
|
||||
from knowledge.serializers.document import split_handles, parse_table_handle_list, FileBufferHandle
|
||||
from oss.serializers.file import FileSerializer
|
||||
from dataset.models import File
|
||||
from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle
|
||||
from dataset.serializers.file_serializers import FileSerializer
|
||||
|
||||
|
||||
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
|
||||
|
|
@ -37,11 +37,11 @@ def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
|
|||
|
||||
splitter = '\n`-----------------------------------`\n'
|
||||
|
||||
|
||||
class BaseDocumentExtractNode(IDocumentExtractNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['content'] = details.get('content')
|
||||
|
||||
|
||||
def execute(self, document, chat_id, **kwargs):
|
||||
get_buffer = FileBufferHandle().get_buffer
|
||||
|
||||
|
|
@ -61,18 +61,12 @@ class BaseDocumentExtractNode(IDocumentExtractNode):
|
|||
'application_id': str(application.id) if application.id else None,
|
||||
'file_id': str(image.id)
|
||||
}
|
||||
file_bytes = image.meta.pop('content')
|
||||
f = bytes_to_uploaded_file(file_bytes, image.file_name)
|
||||
FileSerializer(data={
|
||||
'file': f,
|
||||
'meta': meta,
|
||||
'source_id': meta['application_id'],
|
||||
'source_type': FileSourceType.APPLICATION.value
|
||||
}).upload()
|
||||
file = bytes_to_uploaded_file(image.image, image.image_name)
|
||||
FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
|
||||
for doc in document:
|
||||
file = QuerySet(File).filter(id=doc['file_id']).first()
|
||||
buffer = io.BytesIO(file.get_bytes())
|
||||
buffer = io.BytesIO(file.get_byte().tobytes())
|
||||
buffer.name = doc['name'] # this is the important line
|
||||
|
||||
for split_handle in (parse_table_handle_list + split_handles):
|
||||
|
|
|
|||
|
|
@ -11,14 +11,14 @@ from typing import Type
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class FormNodeParamsSerializer(serializers.Serializer):
|
||||
form_field_list = serializers.ListField(required=True, label=_("Form Configuration"))
|
||||
form_content_format = serializers.CharField(required=True, label=_('Form output content'))
|
||||
form_data = serializers.DictField(required=False, allow_null=True, label=_("Form Data"))
|
||||
form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("Form Configuration")))
|
||||
form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Form output content')))
|
||||
form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
|
||||
|
||||
|
||||
class IFormNode(INode):
|
||||
|
|
|
|||
|
|
@ -16,29 +16,6 @@ from application.flow.common import Answer
|
|||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.form_node.i_form_node import IFormNode
|
||||
|
||||
multi_select_list = [
|
||||
'MultiSelect',
|
||||
'MultiRow'
|
||||
]
|
||||
|
||||
|
||||
def get_default_option(option_list, _type, value_field):
|
||||
try:
|
||||
if option_list is not None and isinstance(option_list, list) and len(option_list) > 0:
|
||||
default_value_list = [o.get(value_field) for o in option_list if o.get('default')]
|
||||
if len(default_value_list) == 0:
|
||||
return [option_list[0].get(
|
||||
value_field)] if multi_select_list.__contains__(_type) else option_list[0].get(
|
||||
value_field)
|
||||
else:
|
||||
if multi_select_list.__contains__(_type):
|
||||
return default_value_list
|
||||
else:
|
||||
return default_value_list[0]
|
||||
except Exception as _:
|
||||
pass
|
||||
return []
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
if step_variable is not None:
|
||||
|
|
@ -51,13 +28,6 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
|||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
|
||||
def generate_prompt(workflow_manage, _value):
|
||||
try:
|
||||
return workflow_manage.generate_prompt(_value)
|
||||
except Exception as e:
|
||||
return _value
|
||||
|
||||
|
||||
class BaseFormNode(IFormNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
form_data = details.get('form_data', None)
|
||||
|
|
@ -74,37 +44,6 @@ class BaseFormNode(IFormNode):
|
|||
for key in form_data:
|
||||
self.context[key] = form_data[key]
|
||||
|
||||
def reset_field(self, field):
|
||||
reset_field = ['field', 'label', 'default_value']
|
||||
for f in reset_field:
|
||||
_value = field[f]
|
||||
if _value is None:
|
||||
continue
|
||||
if isinstance(_value, str):
|
||||
field[f] = generate_prompt(self.workflow_manage, _value)
|
||||
elif f == 'label':
|
||||
_label_value = _value.get('label')
|
||||
_value['label'] = generate_prompt(self.workflow_manage, _label_value)
|
||||
tooltip = _value.get('attrs').get('tooltip')
|
||||
if tooltip is not None:
|
||||
_value.get('attrs')['tooltip'] = generate_prompt(self.workflow_manage, tooltip)
|
||||
|
||||
if ['SingleSelect', 'MultiSelect', 'RadioCard', 'RadioRow', 'MultiRow'].__contains__(field.get('input_type')):
|
||||
if field.get('assignment_method') == 'ref_variables':
|
||||
option_list = self.workflow_manage.get_reference_field(field.get('option_list')[0],
|
||||
field.get('option_list')[1:])
|
||||
option_list = option_list if isinstance(option_list, list) else []
|
||||
field['option_list'] = option_list
|
||||
field['default_value'] = get_default_option(option_list, field.get('input_type'),
|
||||
field.get('value_field'))
|
||||
|
||||
if ['JsonInput'].__contains__(field.get('input_type')):
|
||||
if field.get('default_value_assignment_method') == 'ref_variables':
|
||||
field['default_value'] = self.workflow_manage.get_reference_field(field.get('default_value')[0],
|
||||
field.get('default_value')[1:])
|
||||
|
||||
return field
|
||||
|
||||
def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
|
||||
if form_data is not None:
|
||||
self.context['is_submit'] = True
|
||||
|
|
@ -113,7 +52,6 @@ class BaseFormNode(IFormNode):
|
|||
self.context[key] = form_data.get(key)
|
||||
else:
|
||||
self.context['is_submit'] = False
|
||||
form_field_list = [self.reset_field(field) for field in form_field_list]
|
||||
form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
|
||||
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
|
||||
"is_submit": self.context.get("is_submit", False)}
|
||||
|
|
@ -121,10 +59,7 @@ class BaseFormNode(IFormNode):
|
|||
context = self.workflow_manage.get_workflow_content()
|
||||
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
|
||||
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
|
||||
value = prompt_template.format(form=form, context=context, runtime_node_id=self.runtime_node_id,
|
||||
chat_record_id=self.flow_params_serializer.data.get("chat_record_id"),
|
||||
form_field_list=form_field_list)
|
||||
|
||||
value = prompt_template.format(form=form, context=context)
|
||||
return NodeResult(
|
||||
{'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {},
|
||||
_write_context=write_context)
|
||||
|
|
@ -140,9 +75,7 @@ class BaseFormNode(IFormNode):
|
|||
context = self.workflow_manage.get_workflow_content()
|
||||
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
|
||||
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
|
||||
value = prompt_template.format(form=form, context=context, runtime_node_id=self.runtime_node_id,
|
||||
chat_record_id=self.flow_params_serializer.data.get("chat_record_id"),
|
||||
form_field_list=form_field_list)
|
||||
value = prompt_template.format(form=form, context=context)
|
||||
return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None,
|
||||
self.runtime_node_id, '')]
|
||||
|
||||
|
|
@ -157,9 +90,7 @@ class BaseFormNode(IFormNode):
|
|||
context = self.workflow_manage.get_workflow_content()
|
||||
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
|
||||
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
|
||||
value = prompt_template.format(form=form, context=context, runtime_node_id=self.runtime_node_id,
|
||||
chat_record_id=self.flow_params_serializer.data.get("chat_record_id"),
|
||||
form_field_list=form_field_list)
|
||||
value = prompt_template.format(form=form, context=context)
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
|
|
|
|||
|
|
@ -8,38 +8,35 @@
|
|||
"""
|
||||
from typing import Type
|
||||
|
||||
from django.db import connection
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.field.common import ObjectField
|
||||
from tools.models.tool import Tool
|
||||
from common.util.field_message import ErrMessage
|
||||
from function_lib.models.function import FunctionLib
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class InputField(serializers.Serializer):
|
||||
name = serializers.CharField(required=True, label=_('Variable Name'))
|
||||
value = ObjectField(required=True, label=_("Variable Value"), model_type_list=[str, list])
|
||||
name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
|
||||
value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
|
||||
|
||||
|
||||
class FunctionLibNodeParamsSerializer(serializers.Serializer):
|
||||
tool_lib_id = serializers.UUIDField(required=True, label=_('Library ID'))
|
||||
function_lib_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Library ID')))
|
||||
input_field_list = InputField(required=True, many=True)
|
||||
is_result = serializers.BooleanField(required=False,
|
||||
label=_('Whether to return content'))
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
f_lib = QuerySet(Tool).filter(id=self.data.get('tool_lib_id')).first()
|
||||
# 归还链接到连接池
|
||||
connection.close()
|
||||
f_lib = QuerySet(FunctionLib).filter(id=self.data.get('function_lib_id')).first()
|
||||
if f_lib is None:
|
||||
raise Exception(_('The function has been deleted'))
|
||||
|
||||
|
||||
class IToolLibNode(INode):
|
||||
type = 'tool-lib-node'
|
||||
class IFunctionLibNode(INode):
|
||||
type = 'function-lib-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return FunctionLibNodeParamsSerializer
|
||||
|
|
@ -47,5 +44,5 @@ class IToolLibNode(INode):
|
|||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, tool_lib_id, input_field_list, **kwargs) -> NodeResult:
|
||||
def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -6,4 +6,4 @@
|
|||
@date:2024/8/8 17:48
|
||||
@desc:
|
||||
"""
|
||||
from .base_tool_lib_node import BaseToolLibNodeNode
|
||||
from .base_function_lib_node import BaseFunctionLibNodeNode
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: base_function_lib_node.py
|
||||
@date:2024/8/8 17:49
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.function_lib_node.i_function_lib_node import IFunctionLibNode
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.function_code import FunctionExecutor
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from function_lib.models.function import FunctionLib
|
||||
from smartdoc.const import CONFIG
|
||||
|
||||
function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
if step_variable is not None:
|
||||
for key in step_variable:
|
||||
node.context[key] = step_variable[key]
|
||||
if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
|
||||
result = str(step_variable['result']) + '\n'
|
||||
yield result
|
||||
node.answer_text = result
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
|
||||
def get_field_value(debug_field_list, name, is_required):
|
||||
result = [field for field in debug_field_list if field.get('name') == name]
|
||||
if len(result) > 0:
|
||||
return result[-1]['value']
|
||||
if is_required:
|
||||
raise AppApiException(500, _('Field: {name} No value set').format(name=name))
|
||||
return None
|
||||
|
||||
|
||||
def valid_reference_value(_type, value, name):
|
||||
if _type == 'int':
|
||||
instance_type = int | float
|
||||
elif _type == 'float':
|
||||
instance_type = float | int
|
||||
elif _type == 'dict':
|
||||
instance_type = dict
|
||||
elif _type == 'array':
|
||||
instance_type = list
|
||||
elif _type == 'string':
|
||||
instance_type = str
|
||||
else:
|
||||
raise Exception(_('Field: {name} Type: {_type} Value: {value} Unsupported types').format(name=name,
|
||||
_type=_type))
|
||||
if not isinstance(value, instance_type):
|
||||
raise Exception(
|
||||
_('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
|
||||
value=value))
|
||||
|
||||
|
||||
def convert_value(name: str, value, _type, is_required, source, node):
|
||||
if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
|
||||
return None
|
||||
if not is_required and source == 'reference' and (value is None or len(value) == 0):
|
||||
return None
|
||||
if source == 'reference':
|
||||
value = node.workflow_manage.get_reference_field(
|
||||
value[0],
|
||||
value[1:])
|
||||
valid_reference_value(_type, value, name)
|
||||
if _type == 'int':
|
||||
return int(value)
|
||||
if _type == 'float':
|
||||
return float(value)
|
||||
return value
|
||||
try:
|
||||
if _type == 'int':
|
||||
return int(value)
|
||||
if _type == 'float':
|
||||
return float(value)
|
||||
if _type == 'dict':
|
||||
v = json.loads(value)
|
||||
if isinstance(v, dict):
|
||||
return v
|
||||
raise Exception(_('type error'))
|
||||
if _type == 'array':
|
||||
v = json.loads(value)
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
raise Exception(_('type error'))
|
||||
return value
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
_('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
|
||||
value=value))
|
||||
|
||||
|
||||
def valid_function(function_lib, user_id):
|
||||
if function_lib is None:
|
||||
raise Exception(_('Function does not exist'))
|
||||
if function_lib.permission_type == 'PRIVATE' and str(function_lib.user_id) != str(user_id):
|
||||
raise Exception(_('No permission to use this function {name}').format(name=function_lib.name))
|
||||
if not function_lib.is_active:
|
||||
raise Exception(_('Function {name} is unavailable').format(name=function_lib.name))
|
||||
|
||||
|
||||
class BaseFunctionLibNodeNode(IFunctionLibNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['result'] = details.get('result')
|
||||
if self.node_params.get('is_result'):
|
||||
self.answer_text = str(details.get('result'))
|
||||
|
||||
def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
|
||||
function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
|
||||
valid_function(function_lib, self.flow_params_serializer.data.get('user_id'))
|
||||
params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'),
|
||||
field.get('is_required'),
|
||||
field.get('source'), self)
|
||||
for field in
|
||||
[{'value': get_field_value(input_field_list, field.get('name'), field.get('is_required'),
|
||||
), **field}
|
||||
for field in
|
||||
function_lib.input_field_list]}
|
||||
|
||||
self.context['params'] = params
|
||||
# 合并初始化参数
|
||||
if function_lib.init_params is not None:
|
||||
all_params = json.loads(rsa_long_decrypt(function_lib.init_params)) | params
|
||||
else:
|
||||
all_params = params
|
||||
result = function_executor.exec_code(function_lib.code, all_params)
|
||||
return NodeResult({'result': result}, {}, _write_context=write_context)
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
"result": self.context.get('result'),
|
||||
"params": self.context.get('params'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -15,23 +15,23 @@ from rest_framework import serializers
|
|||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.field.common import ObjectField
|
||||
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.utils.formatting import lazy_format
|
||||
|
||||
|
||||
class InputField(serializers.Serializer):
|
||||
name = serializers.CharField(required=True, label=_('Variable Name'))
|
||||
is_required = serializers.BooleanField(required=True, label=_("Is this field required"))
|
||||
type = serializers.CharField(required=True, label=_("type"), validators=[
|
||||
name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
|
||||
is_required = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean(_("Is this field required")))
|
||||
type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("type")), validators=[
|
||||
validators.RegexValidator(regex=re.compile("^string|int|dict|array|float$"),
|
||||
message=_("The field only supports string|int|dict|array|float"), code=500)
|
||||
])
|
||||
source = serializers.CharField(required=True, label=_("source"), validators=[
|
||||
source = serializers.CharField(required=True, error_messages=ErrMessage.char(_("source")), validators=[
|
||||
validators.RegexValidator(regex=re.compile("^custom|reference$"),
|
||||
message=_("The field only supports custom|reference"), code=500)
|
||||
])
|
||||
value = ObjectField(required=True, label=_("Variable Value"), model_type_list=[str, list])
|
||||
value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
|
@ -43,16 +43,15 @@ class InputField(serializers.Serializer):
|
|||
|
||||
class FunctionNodeParamsSerializer(serializers.Serializer):
|
||||
input_field_list = InputField(required=True, many=True)
|
||||
code = serializers.CharField(required=True, label=_("function"))
|
||||
is_result = serializers.BooleanField(required=False,
|
||||
label=_('Whether to return content'))
|
||||
code = serializers.CharField(required=True, error_messages=ErrMessage.char(_("function")))
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
||||
|
||||
class IToolNode(INode):
|
||||
type = 'tool-node'
|
||||
class IFunctionNode(INode):
|
||||
type = 'function-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return FunctionNodeParamsSerializer
|
||||
|
|
@ -6,4 +6,4 @@
|
|||
@date:2024/8/13 11:19
|
||||
@desc:
|
||||
"""
|
||||
from .base_tool_node import BaseToolNodeNode
|
||||
from .base_function_node import BaseFunctionNodeNode
|
||||
|
|
@ -8,16 +8,16 @@
|
|||
"""
|
||||
import json
|
||||
import time
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.tool_node.i_tool_node import IToolNode
|
||||
from common.utils.tool_code import ToolExecutor
|
||||
from maxkb.const import CONFIG
|
||||
from application.flow.step_node.function_node.i_function_node import IFunctionNode
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.function_code import FunctionExecutor
|
||||
from smartdoc.const import CONFIG
|
||||
|
||||
function_executor = ToolExecutor(CONFIG.get('SANDBOX'))
|
||||
function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
|
||||
|
||||
|
||||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
|
|
@ -32,54 +32,36 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
|||
|
||||
|
||||
def valid_reference_value(_type, value, name):
|
||||
try:
|
||||
if _type == 'int':
|
||||
instance_type = int | float
|
||||
elif _type == 'float':
|
||||
instance_type = float | int
|
||||
elif _type == 'dict':
|
||||
value = json.loads(value) if isinstance(value, str) else value
|
||||
instance_type = dict
|
||||
elif _type == 'array':
|
||||
value = json.loads(value) if isinstance(value, str) else value
|
||||
instance_type = list
|
||||
elif _type == 'string':
|
||||
instance_type = str
|
||||
else:
|
||||
raise Exception(_(
|
||||
'Field: {name} Type: {_type} Value: {value} Unsupported types'
|
||||
).format(name=name, _type=_type))
|
||||
except:
|
||||
return value
|
||||
if _type == 'int':
|
||||
instance_type = int | float
|
||||
elif _type == 'float':
|
||||
instance_type = float | int
|
||||
elif _type == 'dict':
|
||||
instance_type = dict
|
||||
elif _type == 'array':
|
||||
instance_type = list
|
||||
elif _type == 'string':
|
||||
instance_type = str
|
||||
else:
|
||||
raise Exception(500, f'字段:{name}类型:{_type} 不支持的类型')
|
||||
if not isinstance(value, instance_type):
|
||||
raise Exception(_(
|
||||
'Field: {name} Type: {_type} Value: {value} Type error'
|
||||
).format(name=name, _type=_type, value=value))
|
||||
return value
|
||||
raise Exception(f'字段:{name}类型:{_type}值:{value}类型错误')
|
||||
|
||||
|
||||
def convert_value(name: str, value, _type, is_required, source, node):
|
||||
if not is_required and (value is None or ((isinstance(value, str) or isinstance(value, list)) and len(value) == 0)):
|
||||
if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
|
||||
return None
|
||||
if source == 'reference':
|
||||
value = node.workflow_manage.get_reference_field(
|
||||
value[0],
|
||||
value[1:])
|
||||
if value is None:
|
||||
if not is_required:
|
||||
return None
|
||||
else:
|
||||
raise Exception(_(
|
||||
'Field: {name} Type: {_type} is required'
|
||||
).format(name=name, _type=_type))
|
||||
value = valid_reference_value(_type, value, name)
|
||||
valid_reference_value(_type, value, name)
|
||||
if _type == 'int':
|
||||
return int(value)
|
||||
if _type == 'float':
|
||||
return float(value)
|
||||
return value
|
||||
try:
|
||||
value = node.workflow_manage.generate_prompt(value)
|
||||
if _type == 'int':
|
||||
return int(value)
|
||||
if _type == 'float':
|
||||
|
|
@ -88,20 +70,18 @@ def convert_value(name: str, value, _type, is_required, source, node):
|
|||
v = json.loads(value)
|
||||
if isinstance(v, dict):
|
||||
return v
|
||||
raise Exception(_('type error'))
|
||||
raise Exception("类型错误")
|
||||
if _type == 'array':
|
||||
v = json.loads(value)
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
raise Exception(_('type error'))
|
||||
raise Exception("类型错误")
|
||||
return value
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
_('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
|
||||
value=value))
|
||||
raise Exception(f'字段:{name}类型:{_type}值:{value}类型错误')
|
||||
|
||||
|
||||
class BaseToolNodeNode(IToolNode):
|
||||
class BaseFunctionNodeNode(IFunctionNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['result'] = details.get('result')
|
||||
if self.node_params.get('is_result', False):
|
||||
|
|
@ -2,31 +2,31 @@
|
|||
|
||||
from typing import Type
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class ImageGenerateNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, label=_("Model id"))
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
|
||||
|
||||
prompt = serializers.CharField(required=True, label=_("Prompt word (positive)"))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word (positive)")))
|
||||
|
||||
negative_prompt = serializers.CharField(required=False, label=_("Prompt word (negative)"),
|
||||
negative_prompt = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Prompt word (negative)")),
|
||||
allow_null=True, allow_blank=True, )
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=False, default=0,
|
||||
label=_("Number of multi-round conversations"))
|
||||
error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
|
||||
|
||||
dialogue_type = serializers.CharField(required=False, default='NODE',
|
||||
label=_("Conversation storage type"))
|
||||
error_messages=ErrMessage.char(_("Conversation storage type")))
|
||||
|
||||
is_result = serializers.BooleanField(required=False,
|
||||
label=_('Whether to return content'))
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
model_params_setting = serializers.JSONField(required=False, default=dict,
|
||||
label=_("Model parameter settings"))
|
||||
error_messages=ErrMessage.json(_("Model parameter settings")))
|
||||
|
||||
|
||||
class IImageGenerateNode(INode):
|
||||
|
|
|
|||
|
|
@ -7,10 +7,9 @@ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
|||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
|
||||
from common.utils.common import bytes_to_uploaded_file
|
||||
from knowledge.models import FileSourceType
|
||||
from oss.serializers.file import FileSerializer
|
||||
from models_provider.tools import get_model_instance_by_model_workspace_id
|
||||
from common.util.common import bytes_to_uploaded_file
|
||||
from dataset.serializers.file_serializers import FileSerializer
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
class BaseImageGenerateNode(IImageGenerateNode):
|
||||
|
|
@ -24,10 +23,10 @@ class BaseImageGenerateNode(IImageGenerateNode):
|
|||
model_params_setting,
|
||||
chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
print(model_params_setting)
|
||||
application = self.workflow_manage.work_flow_post_handler.chat_info.application
|
||||
workspace_id = self.workflow_manage.get_body().get('workspace_id')
|
||||
tti_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
|
||||
**model_params_setting)
|
||||
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
|
||||
**model_params_setting)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
|
|
@ -35,26 +34,19 @@ class BaseImageGenerateNode(IImageGenerateNode):
|
|||
message_list = self.generate_message_list(question, history_message)
|
||||
self.context['message_list'] = message_list
|
||||
self.context['dialogue_type'] = dialogue_type
|
||||
self.context['negative_prompt'] = self.generate_prompt_question(negative_prompt)
|
||||
print(message_list)
|
||||
image_urls = tti_model.generate_image(question, negative_prompt)
|
||||
# 保存图片
|
||||
file_urls = []
|
||||
for image_url in image_urls:
|
||||
file_name = 'generated_image.png'
|
||||
if isinstance(image_url, str) and image_url.startswith('http'):
|
||||
image_url = requests.get(image_url).content
|
||||
file = bytes_to_uploaded_file(image_url, file_name)
|
||||
file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
|
||||
meta = {
|
||||
'debug': False if application.id else True,
|
||||
'chat_id': chat_id,
|
||||
'application_id': str(application.id) if application.id else None,
|
||||
}
|
||||
file_url = FileSerializer(data={
|
||||
'file': file,
|
||||
'meta': meta,
|
||||
'source_id': meta['application_id'],
|
||||
'source_type': FileSourceType.APPLICATION.value
|
||||
}).upload()
|
||||
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
file_urls.append(file_url)
|
||||
self.context['image_list'] = [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls]
|
||||
answer = ' '.join([f"" for path in file_urls])
|
||||
|
|
@ -126,6 +118,5 @@ class BaseImageGenerateNode(IImageGenerateNode):
|
|||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'image_list': self.context.get('image_list'),
|
||||
'dialogue_type': self.context.get('dialogue_type'),
|
||||
'negative_prompt': self.context.get('negative_prompt'),
|
||||
'dialogue_type': self.context.get('dialogue_type')
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .impl import *
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
|
||||
|
||||
class ImageToVideoNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, label=_("Model id"))
|
||||
|
||||
prompt = serializers.CharField(required=True, label=_("Prompt word (positive)"))
|
||||
|
||||
negative_prompt = serializers.CharField(required=False, label=_("Prompt word (negative)"),
|
||||
allow_null=True, allow_blank=True, )
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=False, default=0,
|
||||
label=_("Number of multi-round conversations"))
|
||||
|
||||
dialogue_type = serializers.CharField(required=False, default='NODE',
|
||||
label=_("Conversation storage type"))
|
||||
|
||||
is_result = serializers.BooleanField(required=False,
|
||||
label=_('Whether to return content'))
|
||||
|
||||
model_params_setting = serializers.JSONField(required=False, default=dict,
|
||||
label=_("Model parameter settings"))
|
||||
|
||||
first_frame_url = serializers.ListField(required=True, label=_("First frame url"))
|
||||
last_frame_url = serializers.ListField(required=False, label=_("Last frame url"))
|
||||
|
||||
|
||||
class IImageToVideoNode(INode):
|
||||
type = 'image-to-video-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ImageToVideoNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
first_frame_url = self.workflow_manage.get_reference_field(
|
||||
self.node_params_serializer.data.get('first_frame_url')[0],
|
||||
self.node_params_serializer.data.get('first_frame_url')[1:])
|
||||
if first_frame_url is []:
|
||||
raise ValueError(
|
||||
_("First frame url cannot be empty"))
|
||||
last_frame_url = None
|
||||
if self.node_params_serializer.data.get('last_frame_url') is not None and self.node_params_serializer.data.get(
|
||||
'last_frame_url') != []:
|
||||
last_frame_url = self.workflow_manage.get_reference_field(
|
||||
self.node_params_serializer.data.get('last_frame_url')[0],
|
||||
self.node_params_serializer.data.get('last_frame_url')[1:])
|
||||
node_params_data = {k: v for k, v in self.node_params_serializer.data.items()
|
||||
if k not in ['first_frame_url', 'last_frame_url']}
|
||||
return self.execute(first_frame_url=first_frame_url, last_frame_url=last_frame_url,
|
||||
**node_params_data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
|
||||
model_params_setting,
|
||||
chat_record_id,
|
||||
first_frame_url, last_frame_url,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .base_image_to_video_node import BaseImageToVideoNode
|
||||
|
|
@ -1,158 +0,0 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from django.db.models import QuerySet
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.image_to_video_step_node.i_image_to_video_node import IImageToVideoNode
|
||||
from common.utils.common import bytes_to_uploaded_file
|
||||
from knowledge.models import FileSourceType, File
|
||||
from oss.serializers.file import FileSerializer, mime_types
|
||||
from models_provider.tools import get_model_instance_by_model_workspace_id
|
||||
from django.utils.translation import gettext
|
||||
|
||||
|
||||
class BaseImageToVideoNode(IImageToVideoNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['answer'] = details.get('answer')
|
||||
self.context['question'] = details.get('question')
|
||||
if self.node_params.get('is_result', False):
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
|
||||
model_params_setting,
|
||||
chat_record_id,
|
||||
first_frame_url, last_frame_url=None,
|
||||
**kwargs) -> NodeResult:
|
||||
application = self.workflow_manage.work_flow_post_handler.chat_info.application
|
||||
workspace_id = self.workflow_manage.get_body().get('workspace_id')
|
||||
ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
|
||||
**model_params_setting)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
self.context['question'] = question
|
||||
message_list = self.generate_message_list(question, history_message)
|
||||
self.context['message_list'] = message_list
|
||||
self.context['dialogue_type'] = dialogue_type
|
||||
self.context['negative_prompt'] = self.generate_prompt_question(negative_prompt)
|
||||
self.context['first_frame_url'] = first_frame_url
|
||||
self.context['last_frame_url'] = last_frame_url
|
||||
# 处理首尾帧图片 这块可以是url 也可以是file_id 如果是url 可以直接传递给模型 如果是file_id 需要传base64
|
||||
# 判断是不是 url
|
||||
first_frame_url = self.get_file_base64(first_frame_url)
|
||||
last_frame_url = self.get_file_base64(last_frame_url)
|
||||
video_urls = ttv_model.generate_video(question, negative_prompt, first_frame_url, last_frame_url)
|
||||
# 保存图片
|
||||
if video_urls is None or video_urls == '':
|
||||
return NodeResult({'answer': gettext('Failed to generate video')}, {})
|
||||
file_name = 'generated_video.mp4'
|
||||
if isinstance(video_urls, str) and video_urls.startswith('http'):
|
||||
video_urls = requests.get(video_urls).content
|
||||
file = bytes_to_uploaded_file(video_urls, file_name)
|
||||
meta = {
|
||||
'debug': False if application.id else True,
|
||||
'chat_id': chat_id,
|
||||
'application_id': str(application.id) if application.id else None,
|
||||
}
|
||||
file_url = FileSerializer(data={
|
||||
'file': file,
|
||||
'meta': meta,
|
||||
'source_id': meta['application_id'],
|
||||
'source_type': FileSourceType.APPLICATION.value
|
||||
}).upload()
|
||||
video_label = f'<video src="{file_url}" controls style="max-width: 100%; width: 100%; height: auto; max-height: 60vh;"></video>'
|
||||
video_list = [{'file_id': file_url.split('/')[-1], 'file_name': file_name, 'url': file_url}]
|
||||
return NodeResult({'answer': video_label, 'chat_model': ttv_model, 'message_list': message_list,
|
||||
'video': video_list,
|
||||
'history_message': history_message, 'question': question}, {})
|
||||
|
||||
def get_file_base64(self, image_url):
|
||||
try:
|
||||
if isinstance(image_url, list):
|
||||
image_url = image_url[0].get('file_id')
|
||||
if isinstance(image_url, str) and not image_url.startswith('http'):
|
||||
file = QuerySet(File).filter(id=image_url).first()
|
||||
file_bytes = file.get_bytes()
|
||||
# 如果我不知道content_type 可以用 magic 库去检测
|
||||
file_type = file.file_name.split(".")[-1].lower()
|
||||
content_type = mime_types.get(file_type, 'application/octet-stream')
|
||||
encoded_bytes = base64.b64encode(file_bytes)
|
||||
return f'data:{content_type};base64,{encoded_bytes.decode()}'
|
||||
return image_url
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
gettext("Failed to obtain the image"))
|
||||
|
||||
def generate_history_ai_message(self, chat_record):
|
||||
for val in chat_record.details.values():
|
||||
if self.node.id == val['node_id'] and 'image_list' in val:
|
||||
if val['dialogue_type'] == 'WORKFLOW':
|
||||
return chat_record.get_ai_message()
|
||||
image_list = val['image_list']
|
||||
return AIMessage(content=[
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'{file_url}'}} for file_url in image_list]
|
||||
])
|
||||
return chat_record.get_ai_message()
|
||||
|
||||
def get_history_message(self, history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[self.generate_history_human_message(history_chat_record[index]),
|
||||
self.generate_history_ai_message(history_chat_record[index])]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
return history_message
|
||||
|
||||
def generate_history_human_message(self, chat_record):
|
||||
|
||||
for data in chat_record.details.values():
|
||||
if self.node.id == data['node_id'] and 'image_list' in data:
|
||||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
return HumanMessage(content=data['question'])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return self.workflow_manage.generate_prompt(prompt)
|
||||
|
||||
def generate_message_list(self, question: str, history_message):
|
||||
return [
|
||||
*history_message,
|
||||
question
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(self.context.get('history_message') if self.context.get(
|
||||
'history_message') is not None else [])],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context.get('message_tokens'),
|
||||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'first_frame_url': self.context.get('first_frame_url'),
|
||||
'last_frame_url': self.context.get('last_frame_url'),
|
||||
'dialogue_type': self.context.get('dialogue_type'),
|
||||
'negative_prompt': self.context.get('negative_prompt'),
|
||||
}
|
||||
|
|
@ -5,27 +5,26 @@ from typing import Type
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
|
||||
from common.util.field_message import ErrMessage
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class ImageUnderstandNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, label=_("Model id"))
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
|
||||
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
label=_("Role Setting"))
|
||||
prompt = serializers.CharField(required=True, label=_("Prompt word"))
|
||||
error_messages=ErrMessage.char(_("Role Setting")))
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, label=_("Number of multi-round conversations"))
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
|
||||
|
||||
dialogue_type = serializers.CharField(required=True, label=_("Conversation storage type"))
|
||||
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Conversation storage type")))
|
||||
|
||||
is_result = serializers.BooleanField(required=False,
|
||||
label=_('Whether to return content'))
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
|
||||
|
||||
image_list = serializers.ListField(required=False, label=_("picture"))
|
||||
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))
|
||||
|
||||
model_params_setting = serializers.JSONField(required=False, default=dict,
|
||||
label=_("Model parameter settings"))
|
||||
error_messages=ErrMessage.json(_("Model parameter settings")))
|
||||
|
||||
|
||||
class IImageUnderstandNode(INode):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
from functools import reduce
|
||||
from imghdr import what
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
|
@ -10,8 +10,9 @@ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AI
|
|||
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
|
||||
from knowledge.models import File
|
||||
from models_provider.tools import get_model_instance_by_model_workspace_id
|
||||
from dataset.models import File
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
from imghdr import what
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||
|
|
@ -59,9 +60,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
|
|||
|
||||
def file_id_to_base64(file_id: str):
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
file_bytes = file.get_bytes()
|
||||
file_bytes = file.get_byte()
|
||||
base64_image = base64.b64encode(file_bytes).decode("utf-8")
|
||||
return [base64_image, what(None, file_bytes)]
|
||||
return [base64_image, what(None, file_bytes.tobytes())]
|
||||
|
||||
|
||||
class BaseImageUnderstandNode(IImageUnderstandNode):
|
||||
|
|
@ -77,9 +78,10 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
image,
|
||||
**kwargs) -> NodeResult:
|
||||
# 处理不正确的参数
|
||||
workspace_id = self.workflow_manage.get_body().get('workspace_id')
|
||||
image_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
|
||||
**model_params_setting)
|
||||
if image is None or not isinstance(image, list):
|
||||
image = []
|
||||
print(model_params_setting)
|
||||
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
|
||||
# 执行详情中的历史消息不需要图片内容
|
||||
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
|
|
@ -89,7 +91,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
message_list = self.generate_message_list(image_model, system, prompt,
|
||||
self.get_history_message(history_chat_record, dialogue_number), image)
|
||||
self.context['message_list'] = message_list
|
||||
self.generate_context_image(image)
|
||||
self.context['image_list'] = image
|
||||
self.context['dialogue_type'] = dialogue_type
|
||||
if stream:
|
||||
r = image_model.stream(message_list)
|
||||
|
|
@ -102,12 +104,6 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context)
|
||||
|
||||
def generate_context_image(self, image):
|
||||
if isinstance(image, str) and image.startswith('http'):
|
||||
self.context['image_list'] = [{'url': image}]
|
||||
elif image is not None and len(image) > 0:
|
||||
self.context['image_list'] = image
|
||||
|
||||
def get_history_message_for_details(self, history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
|
|
@ -134,7 +130,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
file_id_list = [image.get('file_id') for image in image_list]
|
||||
return HumanMessage(content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}} for file_id in file_id_list]
|
||||
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
|
@ -159,8 +155,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
return HumanMessage(
|
||||
content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'image_url',
|
||||
'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
|
||||
base64_image in image_base64_list]
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
|
@ -168,32 +163,24 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
def generate_prompt_question(self, prompt):
|
||||
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
|
||||
|
||||
def _process_images(self, image):
|
||||
"""
|
||||
处理图像数据,转换为模型可识别的格式
|
||||
"""
|
||||
images = []
|
||||
if isinstance(image, str) and image.startswith('http'):
|
||||
images.append({'type': 'image_url', 'image_url': {'url': image}})
|
||||
elif image is not None and len(image) > 0:
|
||||
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
|
||||
if image is not None and len(image) > 0:
|
||||
# 处理多张图片
|
||||
images = []
|
||||
for img in image:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
image_bytes = file.get_bytes()
|
||||
image_bytes = file.get_byte()
|
||||
base64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_format = what(None, image_bytes)
|
||||
images.append(
|
||||
{'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
|
||||
return images
|
||||
|
||||
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
|
||||
prompt_text = self.workflow_manage.generate_prompt(prompt)
|
||||
images = self._process_images(image)
|
||||
|
||||
if images:
|
||||
messages = [HumanMessage(content=[{'type': 'text', 'text': prompt_text}, *images])]
|
||||
image_format = what(None, image_bytes.tobytes())
|
||||
images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
|
||||
messages = [HumanMessage(
|
||||
content=[
|
||||
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
|
||||
*images
|
||||
])]
|
||||
else:
|
||||
messages = [HumanMessage(prompt_text)]
|
||||
messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
|
||||
if system is not None and len(system) > 0:
|
||||
return [
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
# coding=utf-8
|
||||
|
||||
|
||||
|
||||
|
||||
from .impl import *
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
|
||||
|
||||
class IntentBranchSerializer(serializers.Serializer):
|
||||
|
||||
id = serializers.CharField(required=True, label=_("Branch id"))
|
||||
content = serializers.CharField(required=True, label=_("content"))
|
||||
isOther = serializers.BooleanField(required=True, label=_("Branch Type"))
|
||||
|
||||
|
||||
class IntentNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, label=_("Model id"))
|
||||
content_list = serializers.ListField(required=True, label=_("Text content"))
|
||||
dialogue_number = serializers.IntegerField(required=True, label=
|
||||
_("Number of multi-round conversations"))
|
||||
model_params_setting = serializers.DictField(required=False,
|
||||
label=_("Model parameter settings"))
|
||||
branch = IntentBranchSerializer(many=True)
|
||||
|
||||
class IIntentNode(INode):
|
||||
type = 'intent-node'
|
||||
def save_context(self, details, workflow_manage):
|
||||
pass
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return IntentNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
question = self.workflow_manage.get_reference_field(
|
||||
self.node_params_serializer.data.get('content_list')[0],
|
||||
self.node_params_serializer.data.get('content_list')[1:],
|
||||
)
|
||||
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data, user_input=str(question))
|
||||
|
||||
|
||||
def execute(self, model_id, dialogue_number, history_chat_record, user_input, branch,
|
||||
model_params_setting=None, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
|
||||
from .base_intent_node import BaseIntentNode
|
||||
|
|
@ -1,261 +0,0 @@
|
|||
# coding=utf-8
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
from functools import reduce
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from application.flow.step_node.intent_node.i_intent_node import IIntentNode
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model_instance_by_model_workspace_id, get_model_credential
|
||||
from .prompt_template import PROMPT_TEMPLATE
|
||||
|
||||
def get_default_model_params_setting(model_id):
|
||||
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
credential = get_model_credential(model.provider, model.model_type, model.model_name)
|
||||
model_params_setting = credential.get_model_params_setting_form(
|
||||
model.model_name).get_default_form_data()
|
||||
return model_params_setting
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
|
||||
|
||||
chat_model = node_variable.get('chat_model')
|
||||
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
|
||||
answer_tokens = chat_model.get_num_tokens(answer)
|
||||
|
||||
node.context['message_tokens'] = message_tokens
|
||||
node.context['answer_tokens'] = answer_tokens
|
||||
node.context['answer'] = answer
|
||||
node.context['history_message'] = node_variable['history_message']
|
||||
node.context['user_input'] = node_variable['user_input']
|
||||
node.context['branch_id'] = node_variable.get('branch_id')
|
||||
node.context['reason'] = node_variable.get('reason')
|
||||
node.context['category'] = node_variable.get('category')
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
|
||||
response = node_variable.get('result')
|
||||
answer = response.content
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer)
|
||||
|
||||
|
||||
class BaseIntentNode(IIntentNode):
|
||||
|
||||
|
||||
def save_context(self, details, workflow_manage):
|
||||
|
||||
self.context['branch_id'] = details.get('branch_id')
|
||||
self.context['category'] = details.get('category')
|
||||
|
||||
|
||||
def execute(self, model_id, dialogue_number, history_chat_record, user_input, branch,
|
||||
model_params_setting=None, **kwargs) -> NodeResult:
|
||||
|
||||
# 设置默认模型参数
|
||||
if model_params_setting is None:
|
||||
model_params_setting = get_default_model_params_setting(model_id)
|
||||
|
||||
# 获取模型实例
|
||||
workspace_id = self.workflow_manage.get_body().get('workspace_id')
|
||||
chat_model = get_model_instance_by_model_workspace_id(
|
||||
model_id, workspace_id, **model_params_setting
|
||||
)
|
||||
|
||||
# 获取历史对话
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
|
||||
# 保存问题到上下文
|
||||
self.context['user_input'] = user_input
|
||||
|
||||
# 构建分类提示词
|
||||
prompt = self.build_classification_prompt(user_input, branch)
|
||||
|
||||
|
||||
# 生成消息列表
|
||||
system = self.build_system_prompt()
|
||||
message_list = self.generate_message_list(system, prompt, history_message)
|
||||
self.context['message_list'] = message_list
|
||||
|
||||
# 调用模型进行分类
|
||||
try:
|
||||
r = chat_model.invoke(message_list)
|
||||
classification_result = r.content.strip()
|
||||
# 解析分类结果获取分支信息
|
||||
matched_branch = self.parse_classification_result(classification_result, branch)
|
||||
|
||||
# 返回结果
|
||||
return NodeResult({
|
||||
'result': r,
|
||||
'chat_model': chat_model,
|
||||
'message_list': message_list,
|
||||
'history_message': history_message,
|
||||
'user_input': user_input,
|
||||
'branch_id': matched_branch['id'],
|
||||
'reason': self.parse_result_reason(r.content),
|
||||
'category': matched_branch.get('content', matched_branch['id'])
|
||||
}, {}, _write_context=write_context)
|
||||
|
||||
except Exception as e:
|
||||
# 错误处理:返回"其他"分支
|
||||
other_branch = self.find_other_branch(branch)
|
||||
if other_branch:
|
||||
return NodeResult({
|
||||
'branch_id': other_branch['id'],
|
||||
'category': other_branch.get('content', other_branch['id']),
|
||||
'error': str(e)
|
||||
}, {})
|
||||
else:
|
||||
raise Exception(f"error: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number):
|
||||
"""获取历史消息"""
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
|
||||
for message in history_message:
|
||||
if isinstance(message.content, str):
|
||||
message.content = re.sub('<form_rander>[\d\D]*?<\/form_rander>', '', message.content)
|
||||
return history_message
|
||||
|
||||
|
||||
def build_system_prompt(self) -> str:
|
||||
"""构建系统提示词"""
|
||||
return "你是一个专业的意图识别助手,请根据用户输入和意图选项,准确识别用户的真实意图。"
|
||||
|
||||
def build_classification_prompt(self, user_input: str, branch: List[Dict]) -> str:
|
||||
"""构建分类提示词"""
|
||||
|
||||
classification_list = []
|
||||
|
||||
other_branch = self.find_other_branch(branch)
|
||||
# 添加其他分支
|
||||
if other_branch:
|
||||
classification_list.append({
|
||||
"classificationId": 0,
|
||||
"content": other_branch.get('content')
|
||||
})
|
||||
# 添加正常分支
|
||||
classification_id = 1
|
||||
for b in branch:
|
||||
if not b.get('isOther'):
|
||||
classification_list.append({
|
||||
"classificationId": classification_id,
|
||||
"content": b['content']
|
||||
})
|
||||
classification_id += 1
|
||||
|
||||
return PROMPT_TEMPLATE.format(
|
||||
classification_list=classification_list,
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
|
||||
def generate_message_list(self, system: str, prompt: str, history_message):
|
||||
"""生成消息列表"""
|
||||
if system is None or len(system) == 0:
|
||||
return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
else:
|
||||
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
|
||||
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
|
||||
|
||||
def parse_classification_result(self, result: str, branch: List[Dict]) -> Dict[str, Any]:
|
||||
"""解析分类结果"""
|
||||
|
||||
other_branch = self.find_other_branch(branch)
|
||||
normal_intents = [
|
||||
b
|
||||
for b in branch
|
||||
if not b.get('isOther')
|
||||
]
|
||||
|
||||
def get_branch_by_id(category_id: int):
|
||||
if category_id == 0:
|
||||
return other_branch
|
||||
elif 1 <= category_id <= len(normal_intents):
|
||||
return normal_intents[category_id - 1]
|
||||
return None
|
||||
|
||||
try:
|
||||
result_json = json.loads(result)
|
||||
classification_id = result_json.get('classificationId')
|
||||
# 如果是 0 ,返回其他分支
|
||||
matched_branch = get_branch_by_id(classification_id)
|
||||
if matched_branch:
|
||||
return matched_branch
|
||||
|
||||
except Exception as e:
|
||||
# json 解析失败,re 提取
|
||||
numbers = re.findall(r'"classificationId":\s*(\d+)', result)
|
||||
if numbers:
|
||||
classification_id = int(numbers[0])
|
||||
|
||||
matched_branch = get_branch_by_id(classification_id)
|
||||
if matched_branch:
|
||||
return matched_branch
|
||||
|
||||
# 如果都解析失败,返回“other”
|
||||
return other_branch or (normal_intents[0] if normal_intents else {'id': 'unknown', 'content': 'unknown'})
|
||||
|
||||
def parse_result_reason(self, result: str):
|
||||
"""解析分类的原因"""
|
||||
try:
|
||||
result_json = json.loads(result)
|
||||
return result_json.get('reason', '')
|
||||
except Exception as e:
|
||||
reason_patterns = [
|
||||
r'"reason":\s*"([^"]*)"', # 标准格式
|
||||
r'"reason":\s*"([^"]*)', # 缺少结束引号
|
||||
r'"reason":\s*([^,}\n]*)', # 没有引号包围的内容
|
||||
]
|
||||
for pattern in reason_patterns:
|
||||
match = re.search(pattern, result, re.DOTALL)
|
||||
if match:
|
||||
reason = match.group(1).strip()
|
||||
# 清理可能的尾部字符
|
||||
reason = re.sub(r'["\s]*$', '', reason)
|
||||
return reason
|
||||
|
||||
return ''
|
||||
|
||||
def find_other_branch(self, branch: List[Dict]) -> Dict[str, Any] | None:
|
||||
"""查找其他分支"""
|
||||
for b in branch:
|
||||
if b.get('isOther'):
|
||||
return b
|
||||
return None
|
||||
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
"""获取节点执行详情"""
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
'index': index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'system': self.context.get('system'),
|
||||
'history_message': [
|
||||
{'content': message.content, 'role': message.type}
|
||||
for message in (self.context.get('history_message') or [])
|
||||
],
|
||||
'user_input': self.context.get('user_input'),
|
||||
'answer': self.context.get('answer'),
|
||||
'branch_id': self.context.get('branch_id'),
|
||||
'category': self.context.get('category'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context.get('message_tokens'),
|
||||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
|
||||
|
||||
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
# Role
|
||||
You are an intention classification expert, good at being able to judge which classification the user's input belongs to.
|
||||
|
||||
## Skills
|
||||
Skill 1: Clearly determine which of the following intention classifications the user's input belongs to.
|
||||
Intention classification list:
|
||||
{classification_list}
|
||||
|
||||
Note:
|
||||
- Please determine the match only between the user's input content and the Intention classification list content, without judging or categorizing the match with the classification ID.
|
||||
- **When classifying, you must give higher weight to the context and intent continuity shown in the historical conversation. Do not rely solely on the literal meaning of the current input; instead, prioritize the most consistent classification with the previous dialogue flow.**
|
||||
|
||||
## User Input
|
||||
{user_input}
|
||||
|
||||
## Reply requirements
|
||||
- The answer must be returned in JSON format.
|
||||
- Strictly ensure that the output is in a valid JSON format.
|
||||
- Do not add prefix ```json or suffix ```
|
||||
- The answer needs to include the following fields such as:
|
||||
{{
|
||||
"classificationId": 0,
|
||||
"reason": ""
|
||||
}}
|
||||
|
||||
## Limit
|
||||
- Please do not reply in text."""
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py.py
|
||||
@date:2025/9/15 12:08
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: i_loop_break_node.py
|
||||
@date:2025/9/15 12:14
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode
|
||||
from application.flow.i_step_node import NodeResult
|
||||
|
||||
|
||||
class ConditionSerializer(serializers.Serializer):
|
||||
compare = serializers.CharField(required=True, label=_("Comparator"))
|
||||
value = serializers.CharField(required=True, label=_("value"))
|
||||
field = serializers.ListField(required=True, label=_("Fields"))
|
||||
|
||||
|
||||
class LoopBreakNodeSerializer(serializers.Serializer):
|
||||
condition = serializers.CharField(required=True, label=_("Condition or|and"))
|
||||
condition_list = ConditionSerializer(many=True)
|
||||
|
||||
|
||||
class ILoopBreakNode(INode):
|
||||
type = 'loop-break-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return LoopBreakNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data)
|
||||
|
||||
def execute(self, condition, condition_list, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py.py
|
||||
@date:2025/9/15 12:16
|
||||
@desc:
|
||||
"""
|
||||
from .base_loop_break_node import BaseLoopBreakNode
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: base_loop_break_node.py
|
||||
@date:2025/9/15 12:17
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from typing import List, Dict
|
||||
|
||||
from application.flow.compare import compare_handle_list
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.loop_break_node.i_loop_break_node import ILoopBreakNode
|
||||
|
||||
|
||||
def _write_context(step_variable: Dict, global_variable: Dict, node, workflow):
|
||||
if step_variable.get("is_break"):
|
||||
yield "BREAK"
|
||||
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
|
||||
|
||||
class BaseLoopBreakNode(ILoopBreakNode):
|
||||
def execute(self, condition, condition_list, **kwargs) -> NodeResult:
|
||||
r = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
|
||||
condition_list]
|
||||
is_break = all(r) if condition == 'and' else any(r)
|
||||
if is_break:
|
||||
self.node_params['is_result'] = True
|
||||
self.context['is_break'] = is_break
|
||||
return NodeResult({'is_break': is_break}, {},
|
||||
_write_context=_write_context,
|
||||
_is_interrupt=lambda n, v, w: is_break)
|
||||
|
||||
def assertion(self, field_list: List[str], compare: str, value):
|
||||
try:
|
||||
value = self.workflow_manage.generate_prompt(value)
|
||||
except Exception as e:
|
||||
pass
|
||||
field_value = None
|
||||
try:
|
||||
field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
|
||||
except Exception as e:
|
||||
pass
|
||||
for compare_handler in compare_handle_list:
|
||||
if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value):
|
||||
return compare_handler.compare(field_value, compare, value)
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'is_break': self.context.get('is_break'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py.py
|
||||
@date:2025/9/15 12:08
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: i_loop_continue_node.py
|
||||
@date:2025/9/15 12:13
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class ConditionSerializer(serializers.Serializer):
|
||||
compare = serializers.CharField(required=True, label=_("Comparator"))
|
||||
value = serializers.CharField(required=True, label=_("value"))
|
||||
field = serializers.ListField(required=True, label=_("Fields"))
|
||||
|
||||
|
||||
class LoopContinueNodeSerializer(serializers.Serializer):
|
||||
condition = serializers.CharField(required=True, label=_("Condition or|and"))
|
||||
condition_list = ConditionSerializer(many=True)
|
||||
|
||||
|
||||
class ILoopContinueNode(INode):
|
||||
type = 'loop-continue-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return LoopContinueNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data)
|
||||
|
||||
def execute(self, condition, condition_list, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py.py
|
||||
@date:2025/9/15 12:13
|
||||
@desc:
|
||||
"""
|
||||
from .base_loop_continue_node import BaseLoopContinueNode
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: base_loop_continue_node.py
|
||||
@date:2025/9/15 12:13
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from application.flow.compare import compare_handle_list
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.loop_continue_node.i_loop_continue_node import ILoopContinueNode
|
||||
|
||||
|
||||
class BaseLoopContinueNode(ILoopContinueNode):
|
||||
def execute(self, condition, condition_list, **kwargs) -> NodeResult:
|
||||
condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
|
||||
condition_list]
|
||||
is_continue = all(condition_list) if condition == 'and' else any(condition_list)
|
||||
self.context['is_continue'] = is_continue
|
||||
if is_continue:
|
||||
return NodeResult({'is_continue': is_continue, 'branch_id': 'continue'}, {})
|
||||
return NodeResult({'is_continue': is_continue}, {})
|
||||
|
||||
def assertion(self, field_list: List[str], compare: str, value):
|
||||
try:
|
||||
value = self.workflow_manage.generate_prompt(value)
|
||||
except Exception as e:
|
||||
pass
|
||||
field_value = None
|
||||
try:
|
||||
field_value = self.workflow_manage.get_reference_field(field_list[0], field_list[1:])
|
||||
except Exception as e:
|
||||
pass
|
||||
for compare_handler in compare_handle_list:
|
||||
if compare_handler.support(field_list[0], field_list[1:], field_value, compare, value):
|
||||
return compare_handler.compare(field_value, compare, value)
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
"is_continue": self.context.get('is_continue'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: i_loop_node.py
|
||||
@date:2025/3/11 18:19
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
||||
|
||||
class ILoopNodeSerializer(serializers.Serializer):
|
||||
loop_type = serializers.CharField(required=True, label=_("loop_type"))
|
||||
array = serializers.ListField(required=False, allow_null=True,
|
||||
label=_("array"))
|
||||
number = serializers.IntegerField(required=False, allow_null=True,
|
||||
label=_("number"))
|
||||
loop_body = serializers.DictField(required=True, label="循环体")
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
loop_type = self.data.get('loop_type')
|
||||
if loop_type == 'ARRAY':
|
||||
array = self.data.get('array')
|
||||
if array is None or len(array) == 0:
|
||||
message = _('{field}, this field is required.', field='array')
|
||||
raise AppApiException(500, message)
|
||||
elif loop_type == 'NUMBER':
|
||||
number = self.data.get('number')
|
||||
if number is None:
|
||||
message = _('{field}, this field is required.', field='number')
|
||||
raise AppApiException(500, message)
|
||||
|
||||
|
||||
class ILoopNode(INode):
|
||||
type = 'loop-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ILoopNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
array = self.node_params_serializer.data.get('array')
|
||||
if self.node_params_serializer.data.get('loop_type') == 'ARRAY':
|
||||
array = self.workflow_manage.get_reference_field(
|
||||
array[0],
|
||||
array[1:])
|
||||
return self.execute(**{**self.node_params_serializer.data, "array": array}, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2025/3/11 18:24
|
||||
@desc:
|
||||
"""
|
||||
from .base_loop_node import BaseLoopNode
|
||||
|
|
@ -1,298 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: base_loop_node.py
|
||||
@date:2025/3/11 18:24
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from application.flow.common import Answer
|
||||
from application.flow.i_step_node import NodeResult, WorkFlowPostHandler, INode
|
||||
from application.flow.step_node.loop_node.i_loop_node import ILoopNode
|
||||
from application.flow.tools import Reasoning
|
||||
from application.models import ChatRecord
|
||||
from common.handle.impl.response.loop_to_response import LoopToResponse
|
||||
from maxkb.const import CONFIG
|
||||
|
||||
max_loop_count = int(CONFIG.get("WORKFLOW_LOOP_NODE_MAX_LOOP_COUNT", 500))
|
||||
|
||||
|
||||
def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict):
|
||||
return node.context.get('is_interrupt_exec', False)
|
||||
|
||||
|
||||
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
|
||||
reasoning_content: str):
|
||||
node.context['answer'] = answer
|
||||
node.context['run_time'] = time.time() - node.context['start_time']
|
||||
node.context['reasoning_content'] = reasoning_content
|
||||
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
|
||||
node.answer_text = answer
|
||||
|
||||
|
||||
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据 (流式)
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
|
||||
response = node_variable.get('result')
|
||||
workflow_manage = node_variable.get('workflow_manage')
|
||||
answer = ''
|
||||
reasoning_content = ''
|
||||
for chunk in response:
|
||||
content_chunk = chunk.get('content', '')
|
||||
reasoning_content_chunk = chunk.get('reasoning_content', '')
|
||||
reasoning_content += reasoning_content_chunk
|
||||
answer += content_chunk
|
||||
yield {'content': content_chunk,
|
||||
'reasoning_content': reasoning_content_chunk}
|
||||
runtime_details = workflow_manage.get_runtime_details()
|
||||
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
|
||||
|
||||
|
||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
"""
|
||||
写入上下文数据
|
||||
@param node_variable: 节点数据
|
||||
@param workflow_variable: 全局数据
|
||||
@param node: 节点实例对象
|
||||
@param workflow: 工作流管理器
|
||||
"""
|
||||
response = node_variable.get('result')
|
||||
model_setting = node.context.get('model_setting',
|
||||
{'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
|
||||
'reasoning_content_start': '<think>'})
|
||||
reasoning = Reasoning(model_setting.get('reasoning_content_start'), model_setting.get('reasoning_content_end'))
|
||||
reasoning_result = reasoning.get_reasoning_content(response)
|
||||
reasoning_result_end = reasoning.get_end_reasoning_content()
|
||||
content = reasoning_result.get('content') + reasoning_result_end.get('content')
|
||||
if 'reasoning_content' in response.response_metadata:
|
||||
reasoning_content = response.response_metadata.get('reasoning_content', '')
|
||||
else:
|
||||
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
|
||||
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
|
||||
|
||||
|
||||
def get_answer_list(instance, child_node_node_dict, runtime_node_id):
|
||||
answer_list = instance.get_record_answer_list()
|
||||
for a in answer_list:
|
||||
_v = child_node_node_dict.get(a.get('runtime_node_id'))
|
||||
if _v:
|
||||
a['runtime_node_id'] = runtime_node_id
|
||||
a['child_node'] = _v
|
||||
return answer_list
|
||||
|
||||
|
||||
def insert_or_replace(arr, index, value):
|
||||
if index < len(arr):
|
||||
arr[index] = value # 替换
|
||||
else:
|
||||
# 在末尾插入足够多的None,然后替换最后一个
|
||||
arr.extend([None] * (index - len(arr) + 1))
|
||||
arr[index] = value
|
||||
return arr
|
||||
|
||||
|
||||
def generate_loop_number(number: int):
|
||||
def i(current_index: int):
|
||||
return iter([(index, index) for index in range(current_index, number)])
|
||||
|
||||
return i
|
||||
|
||||
|
||||
def generate_loop_array(array):
|
||||
def i(current_index: int):
|
||||
return iter([(array[index], index) for index in range(current_index, len(array))])
|
||||
|
||||
return i
|
||||
|
||||
|
||||
def generate_while_loop(current_index: int):
|
||||
index = current_index
|
||||
while True:
|
||||
yield index, index
|
||||
index += 1
|
||||
|
||||
|
||||
def loop(workflow_manage_new_instance, node: INode, generate_loop):
|
||||
loop_global_data = {}
|
||||
break_outer = False
|
||||
is_interrupt_exec = False
|
||||
loop_node_data = node.context.get('loop_node_data') or []
|
||||
loop_answer_data = node.context.get("loop_answer_data") or []
|
||||
start_index = node.context.get("current_index") or 0
|
||||
current_index = start_index
|
||||
node_params = node.node_params
|
||||
start_node_id = node_params.get('child_node', {}).get('runtime_node_id')
|
||||
loop_type = node_params.get('loop_type')
|
||||
start_node_data = None
|
||||
chat_record = None
|
||||
child_node = None
|
||||
if start_node_id:
|
||||
chat_record_id = node_params.get('child_node', {}).get('chat_record_id')
|
||||
child_node = node_params.get('child_node', {}).get('child_node')
|
||||
start_node_data = node_params.get('node_data')
|
||||
chat_record = ChatRecord(id=chat_record_id, answer_text_list=[], answer_text='',
|
||||
details=loop_node_data[current_index])
|
||||
|
||||
for item, index in generate_loop(current_index):
|
||||
if 0 < max_loop_count <= index - start_index and loop_type == 'LOOP':
|
||||
raise Exception(_('Exceeding the maximum number of cycles'))
|
||||
"""
|
||||
指定次数循环
|
||||
@return:
|
||||
"""
|
||||
instance = workflow_manage_new_instance({'index': index, 'item': item}, loop_global_data, start_node_id,
|
||||
start_node_data, chat_record, child_node)
|
||||
response = instance.stream()
|
||||
answer = ''
|
||||
current_index = index
|
||||
reasoning_content = ''
|
||||
child_node_node_dict = {}
|
||||
for chunk in response:
|
||||
if chunk.get('node_type') == 'loop-break-node' and chunk.get('content', '') == 'BREAK':
|
||||
break_outer = True
|
||||
continue
|
||||
child_node = chunk.get('child_node')
|
||||
runtime_node_id = chunk.get('runtime_node_id', '')
|
||||
chat_record_id = chunk.get('chat_record_id', '')
|
||||
child_node_node_dict[runtime_node_id] = {
|
||||
'runtime_node_id': runtime_node_id,
|
||||
'chat_record_id': chat_record_id,
|
||||
'child_node': child_node}
|
||||
content_chunk = (chunk.get('content', '') or '')
|
||||
reasoning_content_chunk = (chunk.get('reasoning_content', '') or '')
|
||||
reasoning_content += reasoning_content_chunk
|
||||
answer += content_chunk
|
||||
yield chunk
|
||||
if chunk.get('node_status', "SUCCESS") == 'ERROR':
|
||||
insert_or_replace(loop_node_data, index, instance.get_runtime_details())
|
||||
insert_or_replace(loop_answer_data, index,
|
||||
get_answer_list(instance, child_node_node_dict, node.runtime_node_id))
|
||||
node.context['is_interrupt_exec'] = is_interrupt_exec
|
||||
node.context['loop_node_data'] = loop_node_data
|
||||
node.context['loop_answer_data'] = loop_answer_data
|
||||
node.context["index"] = current_index
|
||||
node.context["item"] = current_index
|
||||
node.status = 500
|
||||
node.err_message = chunk.get('content')
|
||||
return
|
||||
node_type = chunk.get('node_type')
|
||||
if node_type == 'form-node':
|
||||
break_outer = True
|
||||
is_interrupt_exec = True
|
||||
start_node_id = None
|
||||
start_node_data = None
|
||||
chat_record = None
|
||||
child_node = None
|
||||
insert_or_replace(loop_node_data, index, instance.get_runtime_details())
|
||||
insert_or_replace(loop_answer_data, index,
|
||||
get_answer_list(instance, child_node_node_dict, node.runtime_node_id))
|
||||
if break_outer:
|
||||
break
|
||||
node.context['is_interrupt_exec'] = is_interrupt_exec
|
||||
node.context['loop_node_data'] = loop_node_data
|
||||
node.context['loop_answer_data'] = loop_answer_data
|
||||
node.context["index"] = current_index
|
||||
node.context["item"] = current_index
|
||||
|
||||
|
||||
def get_write_context(loop_type, array, number, loop_body, stream):
|
||||
def inner_write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||
if loop_type == 'ARRAY':
|
||||
return loop(node_variable['workflow_manage_new_instance'], node, generate_loop_array(array))
|
||||
if loop_type == 'LOOP':
|
||||
return loop(node_variable['workflow_manage_new_instance'], node, generate_while_loop)
|
||||
return loop(node_variable['workflow_manage_new_instance'], node, generate_loop_number(number))
|
||||
|
||||
return inner_write_context
|
||||
|
||||
|
||||
class LoopWorkFlowPostHandler(WorkFlowPostHandler):
|
||||
def handler(self, workflow):
|
||||
pass
|
||||
|
||||
|
||||
class BaseLoopNode(ILoopNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['loop_context_data'] = details.get('loop_context_data')
|
||||
self.context['loop_answer_data'] = details.get('loop_answer_data')
|
||||
self.context['loop_node_data'] = details.get('loop_node_data')
|
||||
self.context['result'] = details.get('result')
|
||||
self.context['params'] = details.get('params')
|
||||
self.context['run_time'] = details.get('run_time')
|
||||
self.context['index'] = details.get('current_index')
|
||||
self.context['item'] = details.get('current_item')
|
||||
for key, value in (details.get('loop_context_data') or {}).items():
|
||||
self.context[key] = value
|
||||
self.answer_text = ""
|
||||
|
||||
def get_answer_list(self) -> List[Answer] | None:
|
||||
result = []
|
||||
for answer_list in (self.context.get("loop_answer_data") or []):
|
||||
for a in answer_list:
|
||||
if isinstance(a, dict):
|
||||
result.append(Answer(**a))
|
||||
|
||||
return result
|
||||
|
||||
def get_loop_context(self):
|
||||
return self.context
|
||||
|
||||
def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult:
|
||||
from application.flow.loop_workflow_manage import LoopWorkflowManage, Workflow
|
||||
def workflow_manage_new_instance(loop_data, global_data, start_node_id=None,
|
||||
start_node_data=None, chat_record=None, child_node=None):
|
||||
workflow_manage = LoopWorkflowManage(Workflow.new_instance(loop_body), self.workflow_manage.params,
|
||||
LoopWorkFlowPostHandler(
|
||||
self.workflow_manage.work_flow_post_handler.chat_info),
|
||||
self.workflow_manage,
|
||||
loop_data,
|
||||
self.get_loop_context,
|
||||
base_to_response=LoopToResponse(),
|
||||
start_node_id=start_node_id,
|
||||
start_node_data=start_node_data,
|
||||
chat_record=chat_record,
|
||||
child_node=child_node
|
||||
)
|
||||
|
||||
return workflow_manage
|
||||
|
||||
return NodeResult({'workflow_manage_new_instance': workflow_manage_new_instance}, {},
|
||||
_write_context=get_write_context(loop_type, array, number, loop_body, stream),
|
||||
_is_interrupt=_is_interrupt_exec)
|
||||
|
||||
def get_loop_context_data(self):
|
||||
fields = self.node.properties.get('config', []).get('fields', []) or []
|
||||
return {f.get('value'): self.context.get(f.get('value')) for f in fields if
|
||||
self.context.get(f.get('value')) is not None}
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
"result": self.context.get('result'),
|
||||
'array': self.node_params_serializer.data.get('array'),
|
||||
'number': self.node_params_serializer.data.get('number'),
|
||||
"params": self.context.get('params'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'current_index': self.context.get("index"),
|
||||
"current_item": self.context.get("item"),
|
||||
'loop_type': self.node_params_serializer.data.get('loop_type'),
|
||||
'status': self.status,
|
||||
'loop_context_data': self.get_loop_context_data(),
|
||||
'loop_node_data': self.context.get("loop_node_data"),
|
||||
'loop_answer_data': self.context.get("loop_answer_data"),
|
||||
'err_message': self.err_message
|
||||
}
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_start_node.py
|
||||
@date:2024/6/3 16:54
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
|
||||
|
||||
class ILoopStarNode(INode):
|
||||
type = 'loop-start-node'
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, **kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/6/11 15:36
|
||||
@desc:
|
||||
"""
|
||||
from .base_start_node import BaseLoopStartStepNode
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_start_node.py
|
||||
@date:2024/6/3 17:17
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.loop_start_node.i_loop_start_node import ILoopStarNode
|
||||
|
||||
|
||||
class BaseLoopStartStepNode(ILoopStarNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['index'] = details.get('current_index')
|
||||
self.context['item'] = details.get('current_item')
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
pass
|
||||
|
||||
def execute(self, **kwargs) -> NodeResult:
|
||||
"""
|
||||
开始节点 初始化全局变量
|
||||
"""
|
||||
loop_params = self.workflow_manage.loop_params
|
||||
node_variable = {
|
||||
'index': loop_params.get("index"),
|
||||
'item': loop_params.get("item")
|
||||
}
|
||||
self.workflow_manage.chat_context = self.workflow_manage.get_chat_info().get_chat_variable()
|
||||
return NodeResult(node_variable, {})
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
global_fields = []
|
||||
for field in self.node.properties.get('config')['globalFields']:
|
||||
key = field['value']
|
||||
global_fields.append({
|
||||
'label': field['label'],
|
||||
'key': key,
|
||||
'value': self.workflow_manage.context[key] if key in self.workflow_manage.context else ''
|
||||
})
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
"current_index": self.context.get('index'),
|
||||
"current_item": self.context.get('item'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue