Mirror of https://github.com/1Panel-dev/MaxKB.git, synced 2025-12-26 10:02:46 +00:00.

Compare commits (156 commits)
Commit SHA1s:

847755b1c2  b57455d0ee  2a257edff9  d47699331c  90c64d77dd  e1ada3ffe2
b62c79fda6  3d9e7dd4b1  8ff15865a7  48899d55d1  1cc4107bfe  b13cd03706
69f024492b  a9c46cd7e0  a9e9f5b085  e12b1fe14e  7ce66a7bf3  decd3395db
9d7a383348  8f7d91798b  81a3af2c8b  2ec0d22b14  27a77dc657  187e9c1e4e
e5bab10824  347f4a0b03  289ebf42a6  71fdce08d7  adc5af9cef  ce2ab322f6
a7e31b94c7  8498687794  190ca3e198  c1ddec1a61  1ba8077e95  9a42bd2302
35b662a52d  a4faf52261  a30316d87a  a4d10cbe3b  ceccf9f1fa  3964db20dc
57ada0708f  a1a92a833a  5e0d8048f9  8903b35aec  fa4f7e99fd  b0630b3ddd
b2bf69740c  8cf66b9eca  1db8577ca6  949e4dea9e  c12988bc8a  2728453f6c
ccf43bbcd9  8d8de53e38  7faf556771  76ec8ad6f6  0cf05a76a0  357edbfbe0
0609a9afc8  96e59a018f  79b2de8893  0c7cca035e  00a3e5ddc3  a6533c0db7
704077d066  59ee0c1270  b37cc3ba1c  dfe6d0a91b  5813eedd4f  363150380d
47f9c04664  17cd88edda  d33b620dc8  b95cb20704  4ff1944b60  1c6b0f8a86
d85801fe58  e79e7d505d  33b1cd65b0  8d503c8bf8  df7f922013  a0541203e4
fb64731cd8  d960a18711  ee35cc21e9  e4a60daa17  7bcb770ee5  b5b09dc8b4
b1f6092620  c8441cfd73  7e4b147576  9cd082089a  d4541e23f9  5e7e91cced
d9787bb548  131b5b3bbe  e58f95832b  f24337d5f3  d6f1d25b59  0ec198fa43
77e96624ee  5e02809db2  6fe001fcf8  f646102262  b97f4e16ba  0c14306889
f1d043f67b  c39a6e81d7  9c56c7e198  6484fef8ea  8a194481ac  072b817792
d32f7d36a6  54c9d4e725  d2637c3de2  2550324003  bf52dd8174  2ecec57d2f
b5fda0e020  45a60cd9a7  2b82675853  1d3bf1ca73  04cb9c96fe  45d8ac2eee
eb60331c88  8fc9b0a22d  ec5fd9e343  1e49939c38  5a7b23aa00  791505b7b8
da10649adb  a025e3960d  98ed348de9  b9dcc36b31  7a3d3844ae  6ea2cf149a
c30677d8b0  a1a2fb5628  f9cb0e24d6  2dc42183cb  c781c11d26  e178cfe5c0
3b24373cd0  125ed8aa7a  0b60a03e5d  bb3f17ebfe  1695710cbe  ebe8506c67
3c3bd9884f  f188383fea  bbab359813  8381ca5287  f5282bf1e7  c0ffc0aaf5
@@ -1,4 +1,2 @@
 .git*
 .idea*
-*.md
-.venv/
@@ -6,4 +6,12 @@ updates:
       interval: "weekly"
       timezone: "Asia/Shanghai"
       day: "friday"
-    target-branch: "v3"
+    target-branch: "v2"
+    groups:
+      python-dependencies:
+        patterns:
+          - "*"
+    # ignore:
+    #   - dependency-name: "pymupdf"
+    #     versions: ["*"]
+
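The `groups` block on the `+` side above uses Dependabot's grouped-updates feature: every dependency whose name matches one of the `patterns` globs is batched into a single combined pull request instead of one PR per package. A minimal standalone sketch; the `package-ecosystem` and `directory` values are illustrative, since the hunk does not show them:

```yaml
version: 2
updates:
  - package-ecosystem: "pip"   # illustrative; not shown in the hunk above
    directory: "/"
    schedule:
      interval: "weekly"
    target-branch: "v2"
    groups:
      python-dependencies:
        patterns:
          - "*"                # collect every matching update into one PR
```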
@@ -14,19 +14,33 @@ on:
         - linux/amd64,linux/arm64
 jobs:
   build-and-push-python-pg-to-ghcr:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
+      - name: Check Disk Space
+        run: df -h
+      - name: Free Disk Space (Ubuntu)
+        uses: jlumbroso/free-disk-space@main
+        with:
+          tool-cache: true
+          android: true
+          dotnet: true
+          haskell: true
+          large-packages: true
+          docker-images: true
+          swap-storage: true
+      - name: Check Disk Space
+        run: df -h
       - name: Checkout
         uses: actions/checkout@v4
         with:
-          ref: ${{ github.ref_name }}
+          ref: main
       - name: Prepare
         id: prepare
         run: |
-          DOCKER_IMAGE=ghcr.io/1panel-dev/maxkb-base
+          DOCKER_IMAGE=ghcr.io/1panel-dev/maxkb-python-pg
           DOCKER_PLATFORMS=${{ github.event.inputs.architecture }}
-          TAG_NAME=python3.11-pg17.6
-          DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}"
+          TAG_NAME=python3.11-pg15.8
+          DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest"
           echo ::set-output name=docker_image::${DOCKER_IMAGE}
           echo ::set-output name=version::${TAG_NAME}
           echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --no-cache \

@@ -37,7 +51,8 @@ jobs:
       - name: Set up QEMU
         uses: docker/setup-qemu-action@v3
         with:
-          cache-image: false
+          # Until https://github.com/tonistiigi/binfmt/issues/215
+          image: tonistiigi/binfmt:qemu-v7.0.0-28
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
       - name: Login to GitHub Container Registry

@@ -48,4 +63,4 @@ jobs:
           password: ${{ secrets.GH_TOKEN }}
       - name: Docker Buildx (build-and-push)
         run: |
-          docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile-base
+          docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile-python-pg
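In the QEMU hunks above and below, the `+` side pins the binfmt installer to a fixed image tag, citing the linked binfmt issue, while the `-` side sets the action's `cache-image` input instead. A minimal sketch of the pinned form, using `docker/setup-qemu-action`'s `image` input:

```yaml
- name: Set up QEMU
  uses: docker/setup-qemu-action@v3
  with:
    # Pin the QEMU binfmt installer to a known-good tag rather than "latest".
    image: tonistiigi/binfmt:qemu-v7.0.0-28
```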
@@ -5,7 +5,7 @@ on:
     inputs:
       dockerImageTag:
         description: 'Docker Image Tag'
-        default: 'v2.0.3'
+        default: 'v1.0.1'
         required: true
       architecture:
         description: 'Architecture'

@@ -19,12 +19,26 @@ on:
 
 jobs:
   build-and-push-vector-model-to-ghcr:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
+      - name: Check Disk Space
+        run: df -h
+      - name: Free Disk Space (Ubuntu)
+        uses: jlumbroso/free-disk-space@main
+        with:
+          tool-cache: true
+          android: true
+          dotnet: true
+          haskell: true
+          large-packages: true
+          docker-images: true
+          swap-storage: true
+      - name: Check Disk Space
+        run: df -h
       - name: Checkout
         uses: actions/checkout@v4
         with:
-          ref: ${{ github.ref_name }}
+          ref: main
       - name: Prepare
         id: prepare
         run: |

@@ -42,7 +56,8 @@ jobs:
       - name: Set up QEMU
         uses: docker/setup-qemu-action@v3
         with:
-          cache-image: false
+          # Until https://github.com/tonistiigi/binfmt/issues/215
+          image: tonistiigi/binfmt:qemu-v7.0.0-28
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
       - name: Login to GitHub Container Registry
@@ -1,13 +1,13 @@
 name: build-and-push
 
-run-name: 构建镜像并推送仓库 ${{ github.event.inputs.dockerImageTag }} (${{ github.event.inputs.registry }}) (${{ github.event.inputs.architecture }})
+run-name: 构建镜像并推送仓库 ${{ github.event.inputs.dockerImageTag }} (${{ github.event.inputs.registry }})
 
 on:
   workflow_dispatch:
     inputs:
       dockerImageTag:
         description: 'Image Tag'
-        default: 'v2.4.0-dev'
+        default: 'v1.10.7-dev'
         required: true
       dockerImageTagWithLatest:
         description: '是否发布latest tag(正式发版时选择,测试版本切勿选择)'

@@ -38,10 +38,20 @@ jobs:
     if: ${{ contains(github.event.inputs.registry, 'fit2cloud') }}
     runs-on: ubuntu-latest
     steps:
-      - name: Clear Work Dir
-        run: |
-          ls -la
-          rm -rf -- ./* ./.??*
+      - name: Check Disk Space
+        run: df -h
+      - name: Free Disk Space (Ubuntu)
+        uses: jlumbroso/free-disk-space@main
+        with:
+          tool-cache: true
+          android: true
+          dotnet: true
+          haskell: true
+          large-packages: true
+          docker-images: true
+          swap-storage: true
+      - name: Check Disk Space
+        run: df -h
       - name: Checkout
         uses: actions/checkout@v4
         with:

@@ -54,17 +64,15 @@ jobs:
           TAG_NAME=${{ github.event.inputs.dockerImageTag }}
           TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }}
           if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then
-            DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*} --tag ${DOCKER_IMAGE}:latest"
+            DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*}"
           else
             DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}"
           fi
-          echo "buildx_args=--platform ${DOCKER_PLATFORMS} --memory-swap -1 \
+          echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \
           --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \
-          ${DOCKER_IMAGE_TAGS} ." >> $GITHUB_OUTPUT
+          ${DOCKER_IMAGE_TAGS} .
       - name: Set up QEMU
         uses: docker/setup-qemu-action@v3
-        with:
-          cache-image: false
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
       - name: Login to GitHub Container Registry

@@ -79,12 +87,6 @@ jobs:
           registry: ${{ secrets.FIT2CLOUD_REGISTRY_HOST }}
           username: ${{ secrets.FIT2CLOUD_REGISTRY_USERNAME }}
           password: ${{ secrets.FIT2CLOUD_REGISTRY_PASSWORD }}
-      - name: Build Web
-        run: |
-          docker buildx build --no-cache --target web-build --output type=local,dest=./web-build-output . -f installer/Dockerfile
-          rm -rf ./ui
-          cp -r ./web-build-output/ui ./
-          rm -rf ./web-build-output
       - name: Docker Buildx (build-and-push)
         run: |
           sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches && free -m

@@ -94,10 +96,20 @@ jobs:
     if: ${{ contains(github.event.inputs.registry, 'dockerhub') }}
     runs-on: ubuntu-latest
     steps:
-      - name: Clear Work Dir
-        run: |
-          ls -la
-          rm -rf -- ./* ./.??*
+      - name: Check Disk Space
+        run: df -h
+      - name: Free Disk Space (Ubuntu)
+        uses: jlumbroso/free-disk-space@main
+        with:
+          tool-cache: true
+          android: true
+          dotnet: true
+          haskell: true
+          large-packages: true
+          docker-images: true
+          swap-storage: true
+      - name: Check Disk Space
+        run: df -h
       - name: Checkout
         uses: actions/checkout@v4
         with:

@@ -110,17 +122,15 @@ jobs:
           TAG_NAME=${{ github.event.inputs.dockerImageTag }}
           TAG_NAME_WITH_LATEST=${{ github.event.inputs.dockerImageTagWithLatest }}
           if [[ ${TAG_NAME_WITH_LATEST} == 'true' ]]; then
-            DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*} --tag ${DOCKER_IMAGE}:latest"
+            DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:${TAG_NAME%%.*}"
           else
             DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}"
           fi
-          echo "buildx_args=--platform ${DOCKER_PLATFORMS} --memory-swap -1 \
+          echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --memory-swap -1 \
           --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=`git rev-parse --short HEAD` --no-cache \
-          ${DOCKER_IMAGE_TAGS} ." >> $GITHUB_OUTPUT
+          ${DOCKER_IMAGE_TAGS} .
       - name: Set up QEMU
         uses: docker/setup-qemu-action@v3
-        with:
-          cache-image: false
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
       - name: Login to GitHub Container Registry

@@ -134,12 +144,6 @@ jobs:
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
-      - name: Build Web
-        run: |
-          docker buildx build --no-cache --target web-build --output type=local,dest=./web-build-output . -f installer/Dockerfile
-          rm -rf ./ui
-          cp -r ./web-build-output/ui ./
-          rm -rf ./web-build-output
       - name: Docker Buildx (build-and-push)
         run: |
           sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches && free -m
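In the two `Prepare` hunks above, the `-` side writes the step output by appending a `key=value` line to the file named by `$GITHUB_OUTPUT`, while the `+` side still uses the `::set-output` workflow command, which GitHub deprecated in late 2022. A minimal sketch of the two styles (step id and value are illustrative):

```yaml
steps:
  - name: Prepare
    id: prepare
    run: |
      # Current style: append key=value to the file GITHUB_OUTPUT points at.
      echo "buildx_args=--platform linux/amd64" >> "$GITHUB_OUTPUT"
      # Deprecated style seen on the + side (emits warnings on current runners):
      # echo "::set-output name=buildx_args::--platform linux/amd64"
  - name: Use the output
    run: echo "${{ steps.prepare.outputs.buildx_args }}"
```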
@@ -1,7 +1,8 @@
 name: Typos Check
 on:
-  workflow_dispatch:
   push:
+    branches:
+      - main
   pull_request:
     types: [opened, synchronize, reopened]
 

@@ -11,21 +12,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Actions Repository
-        uses: actions/checkout@v4
-        with:
-          ref: ${{ github.ref_name }}
-      - name: Create config file
-        run: |
-          cat <<EOF > typo-check-config.toml
-          [files]
-          extend-exclude = [
-              "**/*_svg",
-              "**/migrations/**",
-              "**/loopEdge.ts",
-              "**/edge.ts",
-          ]
-          EOF
+        uses: actions/checkout@v2
       - name: Check spelling
         uses: crate-ci/typos@master
-        with:
-          config: ./typo-check-config.toml
@@ -137,9 +137,9 @@ celerybeat.pid
 # Environments
 .env
 .venv
-# env/
+env/
 venv/
-# ENV/
+ENV/
 env.bak/
 venv.bak/
 

@@ -183,10 +183,5 @@ apps/xpack
 data
 .dev
 poetry.lock
-uv.lock
-apps/models_provider/impl/*/icon/
-apps/models_provider/impl/tencent_model_provider/credential/stt.py
-apps/models_provider/impl/tencent_model_provider/model/stt.py
-tmp/
-config.yml
-.SANDBOX_BANNED_HOSTS
+apps/setting/models_provider/impl/*/icon/
+tmp/
@@ -0,0 +1,4 @@
+[files]
+extend-exclude = [
+    'apps/setting/models_provider/impl/*/icon/*'
+]
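The four lines added above form a config file for crate-ci/typos; `extend-exclude` removes the globbed paths from the spell check. A sketch of pointing the action at such a file explicitly via its `config` input (the file name here is illustrative; without this input, typos falls back to `typos.toml`, `_typos.toml`, or `.typos.toml` at the repository root):

```yaml
- name: Check spelling
  uses: crate-ci/typos@master
  with:
    # Explicit path to a TOML config like the one shown above.
    config: ./typo-check-config.toml
```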
@@ -27,4 +27,4 @@ When reporting issues, always include:
 * Snapshots or log files if needed
 
 Because the issues are open to the public, when submitting files, be sure to remove any sensitive information, e.g. user name, password, IP address, and company name. You can
 replace those parts with "REDACTED" or other strings like "****".
LICENSE (2 changes)

@@ -671,4 +671,4 @@ into proprietary programs. If your program is a subroutine library, you
 may consider it more useful to permit linking proprietary applications with
 the library. If this is what you want to do, use the GNU Lesser General
 Public License instead of this License. But first, please read
 <https://www.gnu.org/licenses/why-not-lgpl.html>.
README.md (75 changes)

@@ -14,7 +14,7 @@
 MaxKB = Max Knowledge Brain, it is an open-source platform for building enterprise-grade agents. MaxKB integrates Retrieval-Augmented Generation (RAG) pipelines, supports robust workflows, and provides advanced MCP tool-use capabilities. MaxKB is widely applied in scenarios such as intelligent customer service, corporate internal knowledge bases, academic research, and education.
 
 - **RAG Pipeline**: Supports direct uploading of documents / automatic crawling of online documents, with features for automatic text splitting, vectorization. This effectively reduces hallucinations in large models, providing a superior smart Q&A interaction experience.
 - **Agentic Workflow**: Equipped with a powerful workflow engine, function library and MCP tool-use, enabling the orchestration of AI processes to meet the needs of complex business scenarios.
 - **Seamless Integration**: Facilitates zero-coding rapid integration into third-party business systems, quickly equipping existing systems with intelligent Q&A capabilities to enhance user satisfaction.
 - **Model-Agnostic**: Supports various large models, including private models (such as DeepSeek, Llama, Qwen, etc.) and public models (like OpenAI, Claude, Gemini, etc.).
 - **Multi Modal**: Native support for input and output text, image, audio and video.

@@ -24,7 +24,7 @@ MaxKB = Max Knowledge Brain, it is an open-source platform for building enterpri
 Execute the script below to start a MaxKB container using Docker:
 
 ```bash
-docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/opt/maxkb 1panel/maxkb
+docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages 1panel/maxkb
 ```
 
 Access MaxKB web interface at `http://your_server_ip:8080` with default admin credentials:

@@ -32,18 +32,18 @@ Access MaxKB web interface at `http://your_server_ip:8080` with default admin cr
 - username: admin
 - password: MaxKB@123..
 
-中国用户如遇到 Docker 镜像 Pull 失败问题,请参照该 [离线安装文档](https://maxkb.cn/docs/v2/installation/offline_installtion/) 进行安装。
+中国用户如遇到 Docker 镜像 Pull 失败问题,请参照该 [离线安装文档](https://maxkb.cn/docs/installation/offline_installtion/) 进行安装。
 
 ## Screenshots
 
 <table style="border-collapse: collapse; border: 1px solid black;">
   <tr>
-    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/eb285512-a66a-4752-8941-c65ed1592238" alt="MaxKB Demo1" /></td>
-    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/f732f1f5-472c-4fd2-93c1-a277eda83d04" alt="MaxKB Demo2" /></td>
+    <td style="padding: 5px;background-color:#fff;"><img src= "https://maxkb.hk/images/overview.png" alt="MaxKB Demo1" /></td>
+    <td style="padding: 5px;background-color:#fff;"><img src= "https://maxkb.hk/images/screenshot-models.png" alt="MaxKB Demo2" /></td>
   </tr>
   <tr>
-    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/c927474a-9a23-4830-822f-5db26025c9b2" alt="MaxKB Demo3" /></td>
-    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/e6268996-a46d-4e58-9f30-31139df78ad2" alt="MaxKB Demo4" /></td>
+    <td style="padding: 5px;background-color:#fff;"><img src= "https://maxkb.hk/images/screenshot-knowledge.png" alt="MaxKB Demo3" /></td>
+    <td style="padding: 5px;background-color:#fff;"><img src= "https://maxkb.hk/images/screenshot-function.png" alt="MaxKB Demo4" /></td>
   </tr>
 </table>
 

@@ -54,6 +54,67 @@ Access MaxKB web interface at `http://your_server_ip:8080` with default admin cr
 - LLM Framework:[LangChain](https://www.langchain.com/)
 - Database:[PostgreSQL + pgvector](https://www.postgresql.org/)
 
+## Feature Comparison
+
+<table style="width: 100%;">
+  <tr>
+    <th align="center">Feature</th>
+    <th align="center">LangChain</th>
+    <th align="center">Dify.AI</th>
+    <th align="center">Flowise</th>
+    <th align="center">MaxKB <br>(Built upon LangChain)</th>
+  </tr>
+  <tr>
+    <td align="center">Supported LLMs</td>
+    <td align="center">Rich Variety</td>
+    <td align="center">Rich Variety</td>
+    <td align="center">Rich Variety</td>
+    <td align="center">Rich Variety</td>
+  </tr>
+  <tr>
+    <td align="center">RAG Engine</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+  </tr>
+  <tr>
+    <td align="center">Agent</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">❌</td>
+    <td align="center">✅</td>
+  </tr>
+  <tr>
+    <td align="center">Workflow</td>
+    <td align="center">❌</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+  </tr>
+  <tr>
+    <td align="center">Observability</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">❌</td>
+    <td align="center">✅</td>
+  </tr>
+  <tr>
+    <td align="center">SSO/Access control</td>
+    <td align="center">❌</td>
+    <td align="center">✅</td>
+    <td align="center">❌</td>
+    <td align="center">✅ (Pro)</td>
+  </tr>
+  <tr>
+    <td align="center">On-premise Deployment</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+  </tr>
+</table>
+
 ## Star History
 
 [![Star History Chart](https://api.star-history.com/svg?repos=1Panel-dev/MaxKB&type=Date)](https://star-history.com/#1Panel-dev/MaxKB&Date)
README_CN.md (20 changes)

@@ -14,12 +14,12 @@
 </p>
 <hr/>
 
-MaxKB = Max Knowledge Brain,是一个强大易用的企业级智能体平台,致力于解决企业 AI 落地面临的技术门槛高、部署成本高、迭代周期长等问题,助力企业在人工智能时代赢得先机。秉承“开箱即用,伴随成长”的设计理念,MaxKB 支持企业快速接入主流大模型,高效构建专属知识库,并提供从基础问答(RAG)、复杂流程自动化(工作流)到智能体(Agent)的渐进式升级路径,全面赋能智能客服、智能办公助手等多种应用场景。
+MaxKB = Max Knowledge Brain,是一款强大易用的企业级智能体平台,支持 RAG 检索增强生成、工作流编排、MCP 工具调用能力。MaxKB 支持对接各种主流大语言模型,广泛应用于智能客服、企业内部知识库问答、员工助手、学术研究与教育等场景。
 
 - **RAG 检索增强生成**:高效搭建本地 AI 知识库,支持直接上传文档 / 自动爬取在线文档,支持文本自动拆分、向量化,有效减少大模型幻觉,提升问答效果;
 - **灵活编排**:内置强大的工作流引擎、函数库和 MCP 工具调用能力,支持编排 AI 工作过程,满足复杂业务场景下的需求;
 - **无缝嵌入**:支持零编码快速嵌入到第三方业务系统,让已有系统快速拥有智能问答能力,提高用户满意度;
-- **模型中立**:支持对接各种大模型,包括本地私有大模型(DeepSeek R1 / Qwen 3 等)、国内公共大模型(通义千问 / 腾讯混元 / 字节豆包 / 百度千帆 / 智谱 AI / Kimi 等)和国外公共大模型(OpenAI / Claude / Gemini 等)。
+- **模型中立**:支持对接各种大模型,包括本地私有大模型(DeepSeek R1 / Llama 3 / Qwen 2 等)、国内公共大模型(通义千问 / 腾讯混元 / 字节豆包 / 百度千帆 / 智谱 AI / Kimi 等)和国外公共大模型(OpenAI / Claude / Gemini 等)。
 
 MaxKB 三分钟视频介绍:https://www.bilibili.com/video/BV18JypYeEkj/
 

@@ -27,10 +27,10 @@ MaxKB 三分钟视频介绍:https://www.bilibili.com/video/BV18JypYeEkj/
 
 ```
 # Linux 机器
-docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/opt/maxkb registry.fit2cloud.com/maxkb/maxkb
+docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb
 
 # Windows 机器
-docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/opt/maxkb registry.fit2cloud.com/maxkb/maxkb
+docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/var/lib/postgresql/data -v C:/python-packages:/opt/maxkb/app/sandbox/python-packages registry.fit2cloud.com/maxkb/maxkb
 
 # 用户名: admin
 # 密码: MaxKB@123..

@@ -38,8 +38,8 @@ docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/opt/maxkb
 
 - 你也可以通过 [1Panel 应用商店](https://apps.fit2cloud.com/1panel) 快速部署 MaxKB;
 - 如果是内网环境,推荐使用 [离线安装包](https://community.fit2cloud.com/#/products/maxkb/downloads) 进行安装部署;
-- MaxKB 不同产品产品版本的对比请参见:[MaxKB 产品版本对比](https://maxkb.cn/price);
-- 如果您需要向团队介绍 MaxKB,可以使用这个 [官方 PPT 材料](https://fit2cloud.com/maxkb/download/introduce-maxkb_202507.pdf)。
+- MaxKB 产品版本分为社区版和专业版,详情请参见:[MaxKB 产品版本对比](https://maxkb.cn/pricing.html);
+- 如果您需要向团队介绍 MaxKB,可以使用这个 [官方 PPT 材料](https://maxkb.cn/download/introduce-maxkb_202503.pdf)。
 
 如你有更多问题,可以查看使用手册,或者通过论坛与我们交流。
 

@@ -54,12 +54,12 @@ docker run -d --name=maxkb --restart=always -p 8080:8080 -v C:/maxkb:/opt/maxkb
 
 <table style="border-collapse: collapse; border: 1px solid black;">
   <tr>
-    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/eb285512-a66a-4752-8941-c65ed1592238" alt="MaxKB Demo1" /></td>
-    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/f732f1f5-472c-4fd2-93c1-a277eda83d04" alt="MaxKB Demo2" /></td>
+    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/1Panel-dev/MaxKB/assets/52996290/d87395fa-a8d7-401c-82bf-c6e475d10ae9" alt="MaxKB Demo1" /></td>
+    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/1Panel-dev/MaxKB/assets/52996290/47c35ee4-3a3b-4bd4-9f4f-ee20788b2b9a" alt="MaxKB Demo2" /></td>
   </tr>
   <tr>
-    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/c927474a-9a23-4830-822f-5db26025c9b2" alt="MaxKB Demo3" /></td>
-    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/e6268996-a46d-4e58-9f30-31139df78ad2" alt="MaxKB Demo4" /></td>
+    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/9a1043cb-fa62-4f71-b9a3-0b46fa59a70e" alt="MaxKB Demo3" /></td>
+    <td style="padding: 5px;background-color:#fff;"><img src= "https://github.com/user-attachments/assets/3407ce9a-779c-4eb4-858e-9441a2ddc664" alt="MaxKB Demo4" /></td>
   </tr>
 </table>
 
@@ -36,4 +36,4 @@
 - [MaxKB 应用案例:重磅!陕西广电网络“秦岭云”平台实现DeepSeek本地化部署](https://mp.weixin.qq.com/s/ZKmEU_wWShK1YDomKJHQeA)
 - [MaxKB 应用案例:粤海集团完成DeepSeek私有化部署,助力集团智能化管理](https://mp.weixin.qq.com/s/2JbVp0-kr9Hfp-0whH4cvg)
 - [MaxKB 应用案例:建筑材料工业信息中心完成DeepSeek本地化部署,推动行业数智化转型新发展](https://mp.weixin.qq.com/s/HThGSnND3qDF8ySEqiM4jw)
 - [MaxKB 应用案例:一起DeepSeek!福建设计以AI大模型开启新篇章](https://mp.weixin.qq.com/s/m67e-H7iQBg3d24NM82UjA)
@@ -1,35 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: application_access_token.py
-@date:2025/6/9 17:46
-@desc:
-"""
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiParameter
-
-from application.serializers.application_access_token import AccessTokenEditSerializer
-from common.mixins.api_mixin import APIMixin
-
-
-class ApplicationAccessTokenAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return [OpenApiParameter(
-            name="workspace_id",
-            description="工作空间id",
-            type=OpenApiTypes.STR,
-            location='path',
-            required=True,
-        ), OpenApiParameter(
-            name="application_id",
-            description="应用id",
-            type=OpenApiTypes.STR,
-            location='path',
-            required=True,
-        )]
-
-    @staticmethod
-    def get_request():
-        return AccessTokenEditSerializer
@@ -1,218 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: application.py
-@date:2025/5/26 16:59
-@desc:
-"""
-from django.utils.translation import gettext_lazy as _
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiParameter
-from rest_framework import serializers
-
-from application.serializers.application import ApplicationCreateSerializer, ApplicationListResponse, \
-    ApplicationImportRequest, ApplicationEditSerializer, TextToSpeechRequest, SpeechToTextRequest, PlayDemoTextRequest
-from common.mixins.api_mixin import APIMixin
-from common.result import ResultSerializer, ResultPageSerializer, DefaultResultSerializer
-
-
-class ApplicationCreateRequest(ApplicationCreateSerializer.SimplateRequest):
-    work_flow = serializers.DictField(required=True, label=_("Workflow Objects"))
-
-
-class ApplicationCreateResponse(ResultSerializer):
-    def get_data(self):
-        return ApplicationCreateSerializer.ApplicationResponse()
-
-
-class ApplicationListResult(ResultSerializer):
-    def get_data(self):
-        return ApplicationListResponse(many=True)
-
-
-class ApplicationPageResult(ResultPageSerializer):
-    def get_data(self):
-        return ApplicationListResponse(many=True)
-
-
-class ApplicationQueryAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="workspace_id",
-                description="工作空间id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="current_page",
-                description=_("Current page"),
-                type=OpenApiTypes.INT,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="page_size",
-                description=_("Page size"),
-                type=OpenApiTypes.INT,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="folder_id",
-                description=_("folder id"),
-                type=OpenApiTypes.STR,
-                location='query',
-                required=False,
-            ),
-            OpenApiParameter(
-                name="name",
-                description=_("Application Name"),
-                type=OpenApiTypes.STR,
-                location='query',
-                required=False,
-            ),
-            OpenApiParameter(
-                name="desc",
-                description=_("Application Description"),
-                type=OpenApiTypes.STR,
-                location='query',
-                required=False,
-            ),
-            OpenApiParameter(
-                name="user_id",
-                description=_("User ID"),
-                type=OpenApiTypes.STR,
-                location='query',
-                required=False,
-            ),
-            OpenApiParameter(
-                name="publish_status",
-                description=_("Publish status") + '(published|unpublished)',
-                type=OpenApiTypes.STR,
-                location='query',
-                required=False,
-            )
-        ]
-
-    @staticmethod
-    def get_response():
-        return ApplicationListResult
-
-    @staticmethod
-    def get_page_response():
-        return ApplicationPageResult
-
-
-class ApplicationCreateAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="workspace_id",
-                description="工作空间id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            )
-        ]
-
-    @staticmethod
-    def get_request():
-        return ApplicationCreateRequest
-
-    @staticmethod
-    def get_response():
-        return ApplicationCreateResponse
-
-
-class ApplicationImportAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        ApplicationCreateAPI.get_parameters()
-
-    @staticmethod
-    def get_request():
-        return ApplicationImportRequest
-
-
-class ApplicationOperateAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="workspace_id",
-                description="工作空间id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="application_id",
-                description="应用id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            )
-        ]
-
-
-class ApplicationExportAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return ApplicationOperateAPI.get_parameters()
-
-    @staticmethod
-    def get_response():
-        return DefaultResultSerializer
-
-
-class ApplicationEditAPI(APIMixin):
-    @staticmethod
-    def get_request():
-        return ApplicationEditSerializer
-
-
-class TextToSpeechAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return ApplicationOperateAPI.get_parameters()
-
-    @staticmethod
-    def get_request():
-        return TextToSpeechRequest
-
-    @staticmethod
-    def get_response():
-        return DefaultResultSerializer
-
-
-class SpeechToTextAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return ApplicationOperateAPI.get_parameters()
-
-    @staticmethod
-    def get_request():
-        return SpeechToTextRequest
-
-    @staticmethod
-    def get_response():
-        return DefaultResultSerializer
-
-
-class PlayDemoTextAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return ApplicationOperateAPI.get_parameters()
-
-    @staticmethod
-    def get_request():
-        return PlayDemoTextRequest
-
-    @staticmethod
-    def get_response():
-        return DefaultResultSerializer
@@ -1,61 +0,0 @@
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiParameter
-
-from application.serializers.application_api_key import EditApplicationKeySerializer, ApplicationKeySerializerModel
-from common.mixins.api_mixin import APIMixin
-from common.result import ResultSerializer
-
-
-class ApplicationKeyListResult(ResultSerializer):
-    def get_data(self):
-        return ApplicationKeySerializerModel(many=True)
-
-
-class ApplicationKeyResult(ResultSerializer):
-    def get_data(self):
-        return ApplicationKeySerializerModel()
-
-
-class ApplicationKeyAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="workspace_id",
-                description="工作空间id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="application_id",
-                description="application ID",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            )
-        ]
-
-    @staticmethod
-    def get_response():
-        return ApplicationKeyResult
-
-    class List(APIMixin):
-        @staticmethod
-        def get_response():
-            return ApplicationKeyListResult
-
-    class Operate(APIMixin):
-        @staticmethod
-        def get_parameters():
-            return [*ApplicationKeyAPI.get_parameters(), OpenApiParameter(
-                name="api_key_id",
-                description="ApiKeyId",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            )]
-
-        @staticmethod
-        def get_request():
-            return EditApplicationKeySerializer
@@ -1,141 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: application_chat.py
-@date:2025/6/10 13:54
-@desc:
-"""
-from django.utils.translation import gettext_lazy as _
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiParameter
-
-from application.serializers.application_chat import ApplicationChatQuerySerializers, \
-    ApplicationChatResponseSerializers, ApplicationChatRecordExportRequest
-from common.mixins.api_mixin import APIMixin
-from common.result import ResultSerializer, ResultPageSerializer
-
-
-class ApplicationChatListResponseSerializers(ResultSerializer):
-    def get_data(self):
-        return ApplicationChatResponseSerializers(many=True)
-
-
-class ApplicationChatPageResponseSerializers(ResultPageSerializer):
-    def get_data(self):
-        return ApplicationChatResponseSerializers(many=True)
-
-
-class ApplicationChatQueryAPI(APIMixin):
-    @staticmethod
-    def get_request():
-        return ApplicationChatQuerySerializers
-
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="workspace_id",
-                description="工作空间id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="application_id",
-                description="application ID",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ), OpenApiParameter(
-                name="start_time",
-                description="start Time",
-                type=OpenApiTypes.STR,
-                required=True,
-            ),
-            OpenApiParameter(
-                name="end_time",
-                description="end Time",
-                type=OpenApiTypes.STR,
-                required=True,
-            ),
-            OpenApiParameter(
-                name="abstract",
-                description="summary",
-                type=OpenApiTypes.STR,
-                required=False,
-            ),
-            OpenApiParameter(
-                name="username",
-                description="username",
-                type=OpenApiTypes.STR,
-                required=False,
-            ),
-            OpenApiParameter(
-                name="min_star",
-                description=_("Minimum number of likes"),
-                type=OpenApiTypes.INT,
-                required=False,
-            ),
-            OpenApiParameter(
-                name="min_trample",
-                description=_("Minimum number of clicks"),
-                type=OpenApiTypes.INT,
-                required=False,
-            ),
-            OpenApiParameter(
-                name="comparer",
-                description=_("Comparator"),
-                type=OpenApiTypes.STR,
-                required=False,
-            ),
-        ]
-
-    @staticmethod
-    def get_response():
-        return ApplicationChatListResponseSerializers
-
-
-class ApplicationChatQueryPageAPI(APIMixin):
-    @staticmethod
-    def get_request():
-        return ApplicationChatQueryAPI.get_request()
-
-    @staticmethod
-    def get_parameters():
-        return [
-            *ApplicationChatQueryAPI.get_parameters(),
-            OpenApiParameter(
-                name="current_page",
-                description=_("Current page"),
-                type=OpenApiTypes.INT,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="page_size",
-                description=_("Page size"),
-                type=OpenApiTypes.INT,
-                location='path',
-                required=True,
-            ),
-        ]
-
-    @staticmethod
-    def get_response():
-        return ApplicationChatPageResponseSerializers
-
-
-class ApplicationChatExportAPI(APIMixin):
-    @staticmethod
-    def get_request():
-        return ApplicationChatRecordExportRequest
-
-    @staticmethod
-    def get_parameters():
-        return ApplicationChatQueryAPI.get_parameters()
-
-    @staticmethod
-    def get_response():
-        return None
@@ -1,180 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: application_chat_record.py
-@date:2025/6/10 15:19
-@desc:
-"""
-from django.utils.translation import gettext_lazy as _
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiParameter
-
-from application.serializers.application_chat_record import ApplicationChatRecordAddKnowledgeSerializer, \
-    ApplicationChatRecordImproveInstanceSerializer
-from common.mixins.api_mixin import APIMixin
-
-
-class ApplicationChatRecordQueryAPI(APIMixin):
-    @staticmethod
-    def get_response():
-        pass
-
-    @staticmethod
-    def get_request():
-        pass
-
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="workspace_id",
-                description="工作空间id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="application_id",
-                description="Application ID",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="chat_id",
-                description=_("Chat ID"),
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="order_asc",
-                description=_("Is it in order"),
-                type=OpenApiTypes.BOOL,
-                required=True,
-            )
-        ]
-
-
-class ApplicationChatRecordPageQueryAPI(APIMixin):
-    @staticmethod
-    def get_response():
-        pass
-
-    @staticmethod
-    def get_request():
-        pass
-
-    @staticmethod
-    def get_parameters():
-        return [*ApplicationChatRecordQueryAPI.get_parameters(),
-                OpenApiParameter(
-                    name="current_page",
-                    description=_("Current page"),
-                    type=OpenApiTypes.INT,
-                    location='path',
-                    required=True,
-                ),
-                OpenApiParameter(
-                    name="page_size",
-                    description=_("Page size"),
-                    type=OpenApiTypes.INT,
-                    location='path',
-                    required=True,
-                )]
-
-
-class ApplicationChatRecordImproveParagraphAPI(APIMixin):
-    @staticmethod
-    def get_response():
-        pass
-
-    @staticmethod
-    def get_request():
-        return ApplicationChatRecordImproveInstanceSerializer
-
-    @staticmethod
-    def get_parameters():
-        return [OpenApiParameter(
-            name="workspace_id",
-            description="工作空间id",
-            type=OpenApiTypes.STR,
-            location='path',
-            required=True,
-        ),
-            OpenApiParameter(
-                name="application_id",
-                description="Application ID",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="chat_id",
-                description=_("Chat ID"),
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="chat_record_id",
-                description=_("Chat Record ID"),
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="knowledge_id",
-                description=_("Knowledge ID"),
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="document_id",
-                description=_("Document ID"),
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            )
-        ]
-
-    class Operate(APIMixin):
-        @staticmethod
-        def get_parameters():
-            return [*ApplicationChatRecordImproveParagraphAPI.get_parameters(), OpenApiParameter(
-                name="paragraph_id",
-                description=_("Paragraph ID"),
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            )]
-
-
-class ApplicationChatRecordAddKnowledgeAPI(APIMixin):
-    @staticmethod
-    def get_request():
-        return ApplicationChatRecordAddKnowledgeSerializer
-
-    @staticmethod
-    def get_response():
-        return None
-
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="workspace_id",
-                description="工作空间id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="application_id",
-                description="Application ID",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            )]
@@ -1,55 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: application_stats.py
-@date:2025/6/9 20:45
-@desc:
-"""
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiParameter
-
-from application.serializers.application_stats import ApplicationStatsSerializer
-from common.mixins.api_mixin import APIMixin
-from common.result import ResultSerializer
-
-
-class ApplicationStatsResult(ResultSerializer):
-    def get_data(self):
-        return ApplicationStatsSerializer(many=True)
-
-
-class ApplicationStatsAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return [OpenApiParameter(
-            name="workspace_id",
-            description="工作空间id",
-            type=OpenApiTypes.STR,
-            location='path',
-            required=True,
-        ),
-            OpenApiParameter(
-                name="application_id",
-                description="application ID",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="start_time",
-                description="start Time",
-                type=OpenApiTypes.STR,
-                required=True,
-            ),
-            OpenApiParameter(
-                name="end_time",
-                description="end Time",
-                type=OpenApiTypes.STR,
-                required=True,
-            ),
-        ]
-
-    @staticmethod
-    def get_response():
-        return ApplicationStatsResult
@@ -1,96 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: application_version.py
-@date:2025/6/4 17:33
-@desc:
-"""
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiParameter
-
-from application.serializers.application_version import ApplicationVersionModelSerializer
-from common.mixins.api_mixin import APIMixin
-from common.result import ResultSerializer, PageDataResponse, ResultPageSerializer
-
-
-class ApplicationListVersionResult(ResultSerializer):
-    def get_data(self):
-        return ApplicationVersionModelSerializer(many=True)
-
-
-class ApplicationPageVersionResult(ResultPageSerializer):
-    def get_data(self):
-        return ApplicationVersionModelSerializer(many=True)
-
-
-class ApplicationWorkflowVersionResult(ResultSerializer):
-    def get_data(self):
-        return ApplicationVersionModelSerializer()
-
-
-class ApplicationVersionAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="workspace_id",
-                description="工作空间id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            ),
-            OpenApiParameter(
-                name="application_id",
-                description="application ID",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            )
-        ]
-
-
-class ApplicationVersionOperateAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="application_version_id",
-                description="工作流版本id",
-                type=OpenApiTypes.STR,
-                location='path',
-                required=True,
-            )
-            , *ApplicationVersionAPI.get_parameters()
-        ]
-
-    @staticmethod
-    def get_response():
-        return ApplicationWorkflowVersionResult
-
-
-class ApplicationVersionListAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return [
-            OpenApiParameter(
-                name="name",
-                description="Version Name",
-                type=OpenApiTypes.STR,
-                required=False,
-            )
-            , *ApplicationVersionAPI.get_parameters()]
-
-    @staticmethod
-    def get_response():
-        return ApplicationListVersionResult
-
-
-class ApplicationVersionPageAPI(APIMixin):
-    @staticmethod
-    def get_parameters():
-        return ApplicationVersionListAPI.get_parameters()
-
-    @staticmethod
-    def get_response():
-        return ApplicationPageVersionResult
@@ -12,45 +12,42 @@ from typing import Type

 from rest_framework import serializers

-from knowledge.models import Paragraph
+from dataset.models import Paragraph


 class ParagraphPipelineModel:

-    def __init__(self, _id: str, document_id: str, knowledge_id: str, content: str, title: str, status: str,
-                 is_active: bool, comprehensive_score: float, similarity: float, knowledge_name: str,
-                 document_name: str,
-                 hit_handling_method: str, directly_return_similarity: float, knowledge_type, meta: dict = None):
+    def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
+                 is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
+                 hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
         self.id = _id
         self.document_id = document_id
-        self.knowledge_id = knowledge_id
+        self.dataset_id = dataset_id
         self.content = content
         self.title = title
         self.status = status,
         self.is_active = is_active
         self.comprehensive_score = comprehensive_score
         self.similarity = similarity
-        self.knowledge_name = knowledge_name
+        self.dataset_name = dataset_name
         self.document_name = document_name
         self.hit_handling_method = hit_handling_method
         self.directly_return_similarity = directly_return_similarity
         self.meta = meta
-        self.knowledge_type = knowledge_type

     def to_dict(self):
         return {
             'id': self.id,
             'document_id': self.document_id,
-            'knowledge_id': self.knowledge_id,
+            'dataset_id': self.dataset_id,
             'content': self.content,
             'title': self.title,
             'status': self.status,
             'is_active': self.is_active,
             'comprehensive_score': self.comprehensive_score,
             'similarity': self.similarity,
-            'knowledge_name': self.knowledge_name,
+            'dataset_name': self.dataset_name,
             'document_name': self.document_name,
-            'knowledge_type': self.knowledge_type,
             'meta': self.meta,
         }

@@ -60,8 +57,7 @@ class ParagraphPipelineModel:
         self.paragraph = {}
         self.comprehensive_score = None
         self.document_name = None
-        self.knowledge_name = None
-        self.knowledge_type = None
+        self.dataset_name = None
         self.hit_handling_method = None
         self.directly_return_similarity = 0.9
         self.meta = {}

@@ -70,7 +66,7 @@ class ParagraphPipelineModel:
         if isinstance(paragraph, Paragraph):
             self.paragraph = {'id': paragraph.id,
                               'document_id': paragraph.document_id,
-                              'knowledge_id': paragraph.knowledge_id,
+                              'dataset_id': paragraph.dataset_id,
                               'content': paragraph.content,
                               'title': paragraph.title,
                               'status': paragraph.status,

@@ -80,12 +76,8 @@ class ParagraphPipelineModel:
             self.paragraph = paragraph
         return self

-    def add_knowledge_name(self, knowledge_name):
-        self.knowledge_name = knowledge_name
-        return self
-
-    def add_knowledge_type(self, knowledge_type):
-        self.knowledge_type = knowledge_type
+    def add_dataset_name(self, dataset_name):
+        self.dataset_name = dataset_name
         return self

     def add_document_name(self, document_name):

@@ -114,13 +106,12 @@ class ParagraphPipelineModel:

     def build(self):
         return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
-                                      str(self.paragraph.get('knowledge_id')),
+                                      str(self.paragraph.get('dataset_id')),
                                       self.paragraph.get('content'), self.paragraph.get('title'),
                                       self.paragraph.get('status'),
                                       self.paragraph.get('is_active'),
-                                      self.comprehensive_score, self.similarity, self.knowledge_name,
+                                      self.comprehensive_score, self.similarity, self.dataset_name,
                                       self.document_name, self.hit_handling_method, self.directly_return_similarity,
-                                      self.knowledge_type,
                                       self.meta)
@@ -17,14 +17,12 @@ from common.handle.impl.response.system_to_response import SystemToResponse

 class PipelineManage:
     def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
-                 base_to_response: BaseToResponse = SystemToResponse(),
-                 debug=False):
+                 base_to_response: BaseToResponse = SystemToResponse()):
         # step executors
         self.step_list = [step() for step in step_list]
         # shared context
         self.context = {'message_tokens': 0, 'answer_tokens': 0}
         self.base_to_response = base_to_response
-        self.debug = debug

     def run(self, context: Dict = None):
         self.context['start_time'] = time.time()

@@ -46,7 +44,6 @@ class PipelineManage:
     def __init__(self):
         self.step_list: List[Type[IBaseChatPipelineStep]] = []
         self.base_to_response = SystemToResponse()
-        self.debug = False

     def append_step(self, step: Type[IBaseChatPipelineStep]):
         self.step_list.append(step)

@@ -56,9 +53,5 @@ class PipelineManage:
         self.base_to_response = base_to_response
         return self

-    def add_debug(self, debug):
-        self.debug = debug
-        return self
-
     def build(self):
-        return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response, debug=self.debug)
+        return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
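The hunks above drop the debug flag from PipelineManage and its fluent builder. A usage sketch of the builder pattern the class keeps, assuming the nested builder shown in the hunk is exposed as PipelineManage.builder() (the step classes appear in later hunks of this diff; the context key is illustrative):

    pipeline = (PipelineManage.builder()
                .append_step(BaseResetProblemStep)
                .append_step(BaseSearchDatasetStep)
                .append_step(BaseGenerateHumanMessageStep)
                .append_step(BaseChatStep)
                .build())
    # run() merges the given dict into the shared context each step reads from
    pipeline.run({'problem_text': 'example question'})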
@@ -16,8 +16,9 @@ from rest_framework import serializers

 from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
-from application.serializers.application import NoReferencesSetting
+from application.serializers.application_serializers import NoReferencesSetting
 from common.field.common import InstanceField
+from common.util.field_message import ErrMessage


 class ModelField(serializers.Field):

@@ -44,7 +45,7 @@ class PostResponseHandler:
     @abstractmethod
     def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
                 answer_text,
-                manage, step, padding_problem_text: str = None, **kwargs):
+                manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
         pass
@@ -52,43 +53,35 @@ class IChatStep(IBaseChatPipelineStep):
     class InstanceSerializer(serializers.Serializer):
         # conversation list
         message_list = serializers.ListField(required=True, child=MessageField(required=True),
-                                             label=_("Conversation list"))
-        model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id"))
+                                             error_messages=ErrMessage.list(_("Conversation list")))
+        model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
         # paragraph list
-        paragraph_list = serializers.ListField(label=_("Paragraph List"))
+        paragraph_list = serializers.ListField(error_messages=ErrMessage.list(_("Paragraph List")))
         # conversation id
-        chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
+        chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("Conversation ID")))
         # user question
-        problem_text = serializers.CharField(required=True, label=_("User Questions"))
+        problem_text = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_("User Questions")))
         # post-processor
         post_response_handler = InstanceField(model_type=PostResponseHandler,
-                                              label=_("Post-processor"))
+                                              error_messages=ErrMessage.base(_("Post-processor")))
         # completion question
         padding_problem_text = serializers.CharField(required=False,
-                                                     label=_("Completion Question"))
+                                                     error_messages=ErrMessage.base(_("Completion Question")))
         # whether to output as a stream
-        stream = serializers.BooleanField(required=False, label=_("Streaming Output"))
-        chat_user_id = serializers.CharField(required=True, label=_("Chat user id"))
-        chat_user_type = serializers.CharField(required=True, label=_("Chat user Type"))
+        stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
+        client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
+        client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
         # no reference segments found
         no_references_setting = NoReferencesSetting(required=True,
-                                                    label=_("No reference segment settings"))
+                                                    error_messages=ErrMessage.base(_("No reference segment settings")))

-        workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
+        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))

         model_setting = serializers.DictField(required=True, allow_null=True,
-                                              label=_("Model settings"))
+                                              error_messages=ErrMessage.dict(_("Model settings")))

         model_params_setting = serializers.DictField(required=False, allow_null=True,
-                                                     label=_("Model parameter settings"))
-        mcp_enable = serializers.BooleanField(label="MCP否启用", required=False, default=False)
-        mcp_tool_ids = serializers.JSONField(label="MCP工具ID列表", required=False, default=list)
-        mcp_servers = serializers.JSONField(label="MCP服务列表", required=False, default=dict)
-        mcp_source = serializers.CharField(label="MCP Source", required=False, default="referencing")
-        tool_enable = serializers.BooleanField(label="工具是否启用", required=False, default=False)
-        tool_ids = serializers.JSONField(label="工具ID列表", required=False, default=list)
-        mcp_output_enable = serializers.BooleanField(label="MCP输出是否启用", required=False, default=True)
+                                                     error_messages=ErrMessage.dict(_("Model parameter settings")))

         def is_valid(self, *, raise_exception=False):
             super().is_valid(raise_exception=True)
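Both sides validate the step arguments with a DRF serializer; the change is from label= (a display name only) to the project's ErrMessage helper, which fills DRF's error_messages dict. A minimal self-contained sketch of what error_messages does — the serializer and message texts below are illustrative, not MaxKB code:

    from rest_framework import serializers


    class ChatArgsSketch(serializers.Serializer):
        # error_messages overrides the text DRF emits per failure code:
        # 'required' fires when the key is missing, 'invalid' on a bad UUID.
        chat_id = serializers.UUIDField(
            required=True,
            error_messages={'required': 'Conversation ID is required',
                            'invalid': 'Conversation ID must be a UUID'})


    s = ChatArgsSketch(data={})
    s.is_valid()
    print(s.errors)  # chat_id -> 'Conversation ID is required'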
@@ -109,12 +102,9 @@ class IChatStep(IBaseChatPipelineStep):
                 chat_id, problem_text,
                 post_response_handler: PostResponseHandler,
                 model_id: str = None,
-                workspace_id: str = None,
+                user_id: str = None,
                 paragraph_list=None,
                 manage: PipelineManage = None,
-                padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
-                no_references_setting=None, model_params_setting=None, model_setting=None,
-                mcp_enable=False, mcp_tool_ids=None, mcp_servers='', mcp_source="referencing",
-                tool_enable=False, tool_ids=None, mcp_output_enable=True,
-                **kwargs):
+                padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
+                no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
         pass
@@ -6,11 +6,10 @@
 @date:2024/1/9 18:25
 @desc: base implementation of the chat step
 """
-import json
-import os
+import logging
 import time
 import traceback
-import uuid_utils.compat as uuid
+import uuid
 from typing import List

 from django.db.models import QuerySet
@@ -19,28 +18,22 @@ from django.utils.translation import gettext as _
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage
 from langchain.schema.messages import HumanMessage, AIMessage
-from langchain_core.messages import AIMessageChunk, SystemMessage
+from langchain_core.messages import AIMessageChunk
 from rest_framework import status

 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
-from application.flow.tools import Reasoning, mcp_response_generator
-from application.models import ApplicationChatUserStats, ChatUserType
-from common.utils.logger import maxkb_logger
-from common.utils.rsa_util import rsa_long_decrypt
-from common.utils.tool_code import ToolExecutor
-from maxkb.const import CONFIG
-from models_provider.tools import get_model_instance_by_model_workspace_id
-from tools.models import Tool
+from application.flow.tools import Reasoning
+from application.models.api_key_model import ApplicationPublicAccessClient
+from common.constants.authentication_type import AuthenticationType
+from setting.models_provider.tools import get_model_instance_by_model_user_id


-def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
-    if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
-            chat_user_type) and application_id is not None:
-        application_public_access_client = (QuerySet(ApplicationChatUserStats).filter(chat_user_id=chat_user_id,
-                                                                                      chat_user_type=chat_user_type,
-                                                                                      application_id=application_id)
+def add_access_num(client_id=None, client_type=None, application_id=None):
+    if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None:
+        application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id,
+                                                                                           application_id=application_id)
                                            .first())
         if application_public_access_client is not None:
             application_public_access_client.access_num = application_public_access_client.access_num + 1
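Both versions of add_access_num above read the row, bump access_num in Python, and save it back. As a hedged alternative sketch, Django's F() expression pushes the increment into a single SQL UPDATE, which avoids lost updates when two chats hit the same client row concurrently (add_access_num_atomic is a hypothetical name, not MaxKB code; it reuses the ApplicationPublicAccessClient model imported above):

    from django.db.models import F


    def add_access_num_atomic(client_id, application_id):
        # UPDATE ... SET access_num = access_num + 1 in one statement,
        # so concurrent increments cannot overwrite each other
        ApplicationPublicAccessClient.objects.filter(
            client_id=client_id, application_id=application_id,
        ).update(access_num=F('access_num') + 1)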
@@ -59,7 +52,6 @@ def write_context(step, manage, request_token, response_token, all_text):
     manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
-


 def event_content(response,
                   chat_id,
                   chat_record_id,
@@ -71,7 +63,7 @@ def event_content(response,
                   message_list: List[BaseMessage],
                   problem_text: str,
                   padding_problem_text: str = None,
-                  chat_user_id=None, chat_user_type=None,
+                  client_id=None, client_type=None,
                   is_ai_chat: bool = None,
                   model_setting=None):
     if model_setting is None:
@@ -93,7 +85,6 @@ def event_content(response,
                 reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
             else:
                 reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
-            content_chunk = reasoning._normalize_content(content_chunk)
             all_text += content_chunk
             if reasoning_content_chunk is None:
                 reasoning_content_chunk = ''
@@ -133,24 +124,26 @@ def event_content(response,
             request_token = 0
             response_token = 0
         write_context(step, manage, request_token, response_token, all_text)
+        asker = manage.context.get('form_data', {}).get('asker', None)
         post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
-                                      all_text, manage, step, padding_problem_text,
-                                      reasoning_content=reasoning_content if reasoning_content_enable else '')
+                                      all_text, manage, step, padding_problem_text, client_id,
+                                      reasoning_content=reasoning_content if reasoning_content_enable else ''
+                                      , asker=asker)
         yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
                                                                      [], '', True,
                                                                      request_token, response_token,
                                                                      {'node_is_end': True, 'view_type': 'many_view',
                                                                       'node_type': 'ai-chat-node'})
-        if not manage.debug:
-            add_access_num(chat_user_id, chat_user_type, manage.context.get('application_id'))
+        add_access_num(client_id, client_type, manage.context.get('application_id'))
     except Exception as e:
-        maxkb_logger.error(f'{str(e)}:{traceback.format_exc()}')
+        logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
         all_text = 'Exception:' + str(e)
         write_context(step, manage, 0, 0, all_text)
+        asker = manage.context.get('form_data', {}).get('asker', None)
         post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
-                                      all_text, manage, step, padding_problem_text, reasoning_content='')
-        if not manage.debug:
-            add_access_num(chat_user_id, chat_user_type, manage.context.get('application_id'))
+                                      all_text, manage, step, padding_problem_text, client_id, reasoning_content='',
+                                      asker=asker)
+        add_access_num(client_id, client_type, manage.context.get('application_id'))
         yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
                                                                      [], all_text,
                                                                      False,
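event_content is a generator; the streaming branch (shown in a later hunk) feeds it into Django's StreamingHttpResponse as a server-sent-event stream. A stripped-down sketch of that pattern, independent of MaxKB's response classes — the function name and chunk source are illustrative:

    from django.http import StreamingHttpResponse


    def sse_response(chunks):
        def stream():
            for chunk in chunks:
                # one SSE frame per chunk: a "data:" line plus a blank line
                yield f'data: {chunk}\n\n'

        response = StreamingHttpResponse(stream(),
                                         content_type='text/event-stream;charset=utf-8')
        response['Cache-Control'] = 'no-cache'  # keep proxies from buffering the stream
        return response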
@@ -167,40 +160,28 @@ class BaseChatStep(IChatStep):
                 problem_text,
                 post_response_handler: PostResponseHandler,
                 model_id: str = None,
-                workspace_id: str = None,
+                user_id: str = None,
                 paragraph_list=None,
                 manage: PipelineManage = None,
                 padding_problem_text: str = None,
                 stream: bool = True,
-                chat_user_id=None, chat_user_type=None,
+                client_id=None, client_type=None,
                 no_references_setting=None,
                 model_params_setting=None,
                 model_setting=None,
-                mcp_enable=False,
-                mcp_tool_ids=None,
-                mcp_servers='',
-                mcp_source="referencing",
-                tool_enable=False,
-                tool_ids=None,
-                mcp_output_enable=True,
                 **kwargs):
-        chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
+        chat_model = get_model_instance_by_model_user_id(model_id, user_id,
                                                          **model_params_setting) if model_id is not None else None
         if stream:
             return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                        paragraph_list,
-                                       manage, padding_problem_text, chat_user_id, chat_user_type,
-                                       no_references_setting,
-                                       model_setting,
-                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
-                                       mcp_output_enable)
+                                       manage, padding_problem_text, client_id, client_type, no_references_setting,
+                                       model_setting)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
-                                      manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
-                                      model_setting,
-                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
-                                      mcp_output_enable)
+                                      manage, padding_problem_text, client_id, client_type, no_references_setting,
+                                      model_setting)

     def get_details(self, manage, **kwargs):
         return {
@@ -216,69 +197,19 @@ class BaseChatStep(IChatStep):

     @staticmethod
     def reset_message_list(message_list: List[BaseMessage], answer_text):
-        result = [{'role': 'user' if isinstance(message, HumanMessage) else (
-            'system' if isinstance(message, SystemMessage) else 'ai'), 'content': message.content} for
+        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                   message
                   in
                   message_list]
         result.append({'role': 'ai', 'content': answer_text})
         return result

-    def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
-                            mcp_output_enable, chat_model, message_list):
-        if not mcp_enable and not tool_enable:
-            return None
-
-        mcp_servers_config = {}
-
-        # mcp_source is None on data migrated from older versions
-        if mcp_source is None:
-            mcp_source = 'custom'
-        if mcp_enable:
-            # compatibility with old data
-            if not mcp_tool_ids:
-                mcp_tool_ids = []
-            if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
-                mcp_servers_config = json.loads(mcp_servers)
-            elif mcp_tool_ids:
-                mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values()
-                for mcp_tool in mcp_tools:
-                    if mcp_tool and mcp_tool['is_active']:
-                        mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
-
-        if tool_enable:
-            if tool_ids and len(tool_ids) > 0:  # if there are tool IDs, convert them to MCP
-                self.context['tool_ids'] = tool_ids
-                for tool_id in tool_ids:
-                    tool = QuerySet(Tool).filter(id=tool_id).first()
-                    if tool is None or tool.is_active is False:
-                        continue
-                    executor = ToolExecutor()
-                    if tool.init_params is not None:
-                        params = json.loads(rsa_long_decrypt(tool.init_params))
-                    else:
-                        params = {}
-                    tool_config = executor.get_tool_mcp_config(tool.code, params)
-
-                    mcp_servers_config[str(tool.id)] = tool_config
-
-        if len(mcp_servers_config) > 0:
-            return mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable)
-
-        return None
-
-    def get_stream_result(self, message_list: List[BaseMessage],
+    @staticmethod
+    def get_stream_result(message_list: List[BaseMessage],
                           chat_model: BaseChatModel = None,
                           paragraph_list=None,
                           no_references_setting=None,
-                          problem_text=None,
-                          mcp_enable=False,
-                          mcp_tool_ids=None,
-                          mcp_servers='',
-                          mcp_source="referencing",
-                          tool_enable=False,
-                          tool_ids=None,
-                          mcp_output_enable=True):
+                          problem_text=None):
         if paragraph_list is None:
             paragraph_list = []
         directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -294,13 +225,6 @@ class BaseChatStep(IChatStep):
             return iter([AIMessageChunk(
                 _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
         else:
-            # handle the MCP request
-            mcp_result = self._handle_mcp_request(
-                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model,
-                message_list,
-            )
-            if mcp_result:
-                return mcp_result, True
             return chat_model.stream(message_list), True

     def execute_stream(self, message_list: List[BaseMessage],
@@ -311,44 +235,27 @@ class BaseChatStep(IChatStep):
                        paragraph_list=None,
                        manage: PipelineManage = None,
                        padding_problem_text: str = None,
-                       chat_user_id=None, chat_user_type=None,
+                       client_id=None, client_type=None,
                        no_references_setting=None,
-                       model_setting=None,
-                       mcp_enable=False,
-                       mcp_tool_ids=None,
-                       mcp_servers='',
-                       mcp_source="referencing",
-                       tool_enable=False,
-                       tool_ids=None,
-                       mcp_output_enable=True):
+                       model_setting=None):
         chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
-                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids,
-                                                         mcp_servers, mcp_source, tool_enable, tool_ids,
-                                                         mcp_output_enable)
-        chat_record_id = uuid.uuid7()
+                                                         no_references_setting, problem_text)
+        chat_record_id = uuid.uuid1()
         r = StreamingHttpResponse(
             streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
                                             post_response_handler, manage, self, chat_model, message_list, problem_text,
-                                            padding_problem_text, chat_user_id, chat_user_type, is_ai_chat,
-                                            model_setting),
+                                            padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
             content_type='text/event-stream;charset=utf-8')

         r['Cache-Control'] = 'no-cache'
         return r

-    def get_block_result(self, message_list: List[BaseMessage],
+    @staticmethod
+    def get_block_result(message_list: List[BaseMessage],
                          chat_model: BaseChatModel = None,
                          paragraph_list=None,
                          no_references_setting=None,
-                         problem_text=None,
-                         mcp_enable=False,
-                         mcp_tool_ids=None,
-                         mcp_servers='',
-                         mcp_source="referencing",
-                         tool_enable=False,
-                         tool_ids=None,
-                         mcp_output_enable=True
-                         ):
+                         problem_text=None):
         if paragraph_list is None:
             paragraph_list = []
         directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
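Both sides pick a time-ordered UUID for chat_record_id: the left uses uuid7 via the uuid-utils compatibility shim it imports, the right the standard library's uuid1. A small sketch of the difference (output values vary per run):

    import uuid

    # uuid1 is time-based but embeds the host MAC address by default
    print(uuid.uuid1())

    # uuid7 (RFC 9562) is time-ordered with random node bits, so ids still sort
    # by creation time without leaking host identity; older stdlib Python lacks
    # it, hence the third-party shim on the left side of the diff:
    # import uuid_utils.compat as uuid
    # print(uuid.uuid7())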
@@ -363,13 +270,6 @@ class BaseChatStep(IChatStep):
             return AIMessage(
                 _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
         else:
-            # handle the MCP request
-            mcp_result = self._handle_mcp_request(
-                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable,
-                chat_model, message_list,
-            )
-            if mcp_result:
-                return mcp_result, True
             return chat_model.invoke(message_list), True

     def execute_block(self, message_list: List[BaseMessage],
@@ -380,27 +280,18 @@ class BaseChatStep(IChatStep):
                       paragraph_list=None,
                       manage: PipelineManage = None,
                       padding_problem_text: str = None,
-                      chat_user_id=None, chat_user_type=None, no_references_setting=None,
-                      model_setting=None,
-                      mcp_enable=False,
-                      mcp_tool_ids=None,
-                      mcp_servers='',
-                      mcp_source="referencing",
-                      tool_enable=False,
-                      tool_ids=None,
-                      mcp_output_enable=True):
+                      client_id=None, client_type=None, no_references_setting=None,
+                      model_setting=None):
         reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
         reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
         reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
         reasoning = Reasoning(reasoning_content_start,
                               reasoning_content_end)
-        chat_record_id = uuid.uuid7()
+        chat_record_id = uuid.uuid1()
         # invoke the model
         try:
             chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
-                                                            no_references_setting, problem_text, mcp_enable,
-                                                            mcp_tool_ids, mcp_servers, mcp_source, tool_enable,
-                                                            tool_ids, mcp_output_enable)
+                                                            no_references_setting, problem_text)
             if is_ai_chat:
                 request_token = chat_model.get_num_tokens_from_messages(message_list)
                 response_token = chat_model.get_num_tokens(chat_result.content)
@@ -412,15 +303,16 @@ class BaseChatStep(IChatStep):
             reasoning_result_end = reasoning.get_end_reasoning_content()
             content = reasoning_result.get('content') + reasoning_result_end.get('content')
             if 'reasoning_content' in chat_result.response_metadata:
-                reasoning_content = (chat_result.response_metadata.get('reasoning_content', '') or '')
+                reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
             else:
-                reasoning_content = (reasoning_result.get('reasoning_content') or "") + (reasoning_result_end.get(
-                    'reasoning_content') or "")
+                reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get(
+                    'reasoning_content')
+            asker = manage.context.get('form_data', {}).get('asker', None)
             post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
-                                          content, manage, self, padding_problem_text,
-                                          reasoning_content=reasoning_content)
-            if not manage.debug:
-                add_access_num(chat_user_id, chat_user_type, manage.context.get('application_id'))
+                                          content, manage, self, padding_problem_text, client_id,
+                                          reasoning_content=reasoning_content if reasoning_content_enable else '',
+                                          asker=asker)
+            add_access_num(client_id, client_type, manage.context.get('application_id'))
             return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
                                                                    content, True,
                                                                    request_token, response_token,
@@ -433,9 +325,10 @@ class BaseChatStep(IChatStep):
         except Exception as e:
             all_text = 'Exception:' + str(e)
             write_context(self, manage, 0, 0, all_text)
+            asker = manage.context.get('form_data', {}).get('asker', None)
             post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
-                                          all_text, manage, self, padding_problem_text, reasoning_content='')
-            if not manage.debug:
-                add_access_num(chat_user_id, chat_user_type, manage.context.get('application_id'))
+                                          all_text, manage, self, padding_problem_text, client_id, reasoning_content='',
+                                          asker=asker)
+            add_access_num(client_id, client_type, manage.context.get('application_id'))
             return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0,
                                                                    0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@@ -16,35 +16,34 @@ from rest_framework import serializers
 from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from application.models import ChatRecord
-from application.serializers.application import NoReferencesSetting
+from application.serializers.application_serializers import NoReferencesSetting
 from common.field.common import InstanceField
+from common.util.field_message import ErrMessage


 class IGenerateHumanMessageStep(IBaseChatPipelineStep):
     class InstanceSerializer(serializers.Serializer):
         # question
-        problem_text = serializers.CharField(required=True, label=_("question"))
+        problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
         # paragraph list
         paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True),
-                                               label=_("Paragraph List"))
+                                               error_messages=ErrMessage.list(_("Paragraph List")))
         # historical Q&A records
         history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
-                                                    label=_("History Questions"))
+                                                    error_messages=ErrMessage.list(_("History Questions")))
         # number of multi-round conversations
-        dialogue_number = serializers.IntegerField(required=True, label=_("Number of multi-round conversations"))
+        dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
         # maximum length of knowledge base paragraphs to carry
-        max_paragraph_char_number = serializers.IntegerField(required=True,
-                                                             label=_("Maximum length of the knowledge base paragraph"))
+        max_paragraph_char_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
+            _("Maximum length of the knowledge base paragraph")))
         # template
-        prompt = serializers.CharField(required=True, label=_("Prompt word"))
+        prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
         system = serializers.CharField(required=False, allow_null=True, allow_blank=True,
-                                       label=_("System prompt words (role)"))
+                                       error_messages=ErrMessage.char(_("System prompt words (role)")))
         # completion question
-        padding_problem_text = serializers.CharField(required=False,
-                                                     label=_("Completion problem"))
+        padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Completion problem")))
         # no reference segments found
-        no_references_setting = NoReferencesSetting(required=True,
-                                                    label=_("No reference segment settings"))
+        no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings")))

     def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
         return self.InstanceSerializer
@@ -15,7 +15,7 @@ from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
     IGenerateHumanMessageStep
 from application.models import ChatRecord
-from common.utils.common import flat_map
+from common.util.split_model import flat_map


 class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
@@ -16,20 +16,22 @@ from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from application.models import ChatRecord
 from common.field.common import InstanceField
+from common.util.field_message import ErrMessage


 class IResetProblemStep(IBaseChatPipelineStep):
     class InstanceSerializer(serializers.Serializer):
         # question text
-        problem_text = serializers.CharField(required=True, label=_("question"))
+        problem_text = serializers.CharField(required=True, error_messages=ErrMessage.float(_("question")))
         # historical Q&A records
         history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
-                                                    label=_("History Questions"))
+                                                    error_messages=ErrMessage.list(_("History Questions")))
         # large language model
-        model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id"))
-        workspace_id = serializers.CharField(required=True, label=_("User ID"))
+        model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid(_("Model id")))
+        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
         problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
-                                                            label=_("Question completion prompt"))
+                                                            error_messages=ErrMessage.char(
+                                                                _("Question completion prompt")))

     def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
         return self.InstanceSerializer
@@ -50,6 +52,6 @@ class IResetProblemStep(IBaseChatPipelineStep):
     @abstractmethod
     def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
                 problem_optimization_prompt=None,
-                workspace_id=None,
+                user_id=None,
                 **kwargs):
         pass
@@ -13,8 +13,8 @@ from langchain.schema import HumanMessage

 from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
 from application.models import ChatRecord
-from common.utils.split_model import flat_map
-from models_provider.tools import get_model_instance_by_model_workspace_id
+from common.util.split_model import flat_map
+from setting.models_provider.tools import get_model_instance_by_model_user_id

 prompt = _(
     "() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag")
@@ -23,9 +23,9 @@ prompt = _(

 class BaseResetProblemStep(IResetProblemStep):
     def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
                 problem_optimization_prompt=None,
-                workspace_id=None,
+                user_id=None,
                 **kwargs) -> str:
-        chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id) if model_id is not None else None
+        chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
         if chat_model is None:
             return problem_text
         start_index = len(history_chat_record) - 3
@@ -16,62 +16,62 @@ from rest_framework import serializers

 from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
+from common.util.field_message import ErrMessage


 class ISearchDatasetStep(IBaseChatPipelineStep):
     class InstanceSerializer(serializers.Serializer):
         # original question text
-        problem_text = serializers.CharField(required=True, label=_("question"))
+        problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char(_("question")))
         # question text completed by the system
         padding_problem_text = serializers.CharField(required=False,
-                                                     label=_("System completes question text"))
+                                                     error_messages=ErrMessage.char(_("System completes question text")))
         # list of dataset ids to query
-        knowledge_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
-                                                  label=_("Dataset id list"))
+        dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
+                                                error_messages=ErrMessage.list(_("Dataset id list")))
         # document ids to exclude
         exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
-                                                         label=_("List of document ids to exclude"))
+                                                         error_messages=ErrMessage.list(_("List of document ids to exclude")))
         # vector ids to exclude
         exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
-                                                          label=_("List of exclusion vector ids"))
+                                                          error_messages=ErrMessage.list(_("List of exclusion vector ids")))
         # number of records to query
         top_n = serializers.IntegerField(required=True,
-                                         label=_("Reference segment number"))
+                                         error_messages=ErrMessage.integer(_("Reference segment number")))
         # similarity, between 0 and 1
         similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
-                                            label=_("Similarity"))
+                                            error_messages=ErrMessage.float(_("Similarity")))
         search_mode = serializers.CharField(required=True, validators=[
             validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
                                       message=_("The type only supports embedding|keywords|blend"), code=500)
-        ], label=_("Retrieval Mode"))
-        workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
+        ], error_messages=ErrMessage.char(_("Retrieval Mode")))
+        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))

     def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
         return self.InstanceSerializer

     def _run(self, manage: PipelineManage):
-        paragraph_list = self.execute(**self.context['step_args'], manage=manage)
+        paragraph_list = self.execute(**self.context['step_args'])
         manage.context['paragraph_list'] = paragraph_list
         self.context['paragraph_list'] = paragraph_list

     @abstractmethod
-    def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
+    def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                 exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
                 search_mode: str = None,
-                workspace_id=None,
-                manage: PipelineManage = None,
+                user_id=None,
                 **kwargs) -> List[ParagraphPipelineModel]:
         """
        Note on the user question vs. the completion question: if a completion question exists it is used for
        the query, otherwise the user's original question is used.
        :param similarity: relevance threshold
        :param top_n: number of records to query
        :param problem_text: user question
-       :param knowledge_id_list: list of dataset ids to query
+       :param dataset_id_list: list of dataset ids to query
        :param exclude_document_id_list: document ids to exclude
        :param exclude_paragraph_id_list: paragraph ids to exclude
        :param padding_problem_text: completion question
        :param search_mode: retrieval mode
-       :param workspace_id: workspace id
+       :param user_id: user id
        :return: list of paragraphs
        """
         pass
@@ -16,59 +16,51 @@ from rest_framework.utils.formatting import lazy_format
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
 from common.config.embedding_config import VectorStore, ModelManage
-from common.constants.permission_constants import RoleConstants
-from common.database_model_manage.database_model_manage import DatabaseModelManage
 from common.db.search import native_search
-from common.utils.common import get_file_content
-from knowledge.models import Paragraph, Knowledge
-from knowledge.models import SearchMode
-from maxkb.conf import PROJECT_DIR
-from models_provider.models import Model
-from models_provider.tools import get_model, get_model_by_id, get_model_default_params
+from common.util.file_util import get_file_content
+from dataset.models import Paragraph, DataSet
+from embedding.models import SearchMode
+from setting.models import Model
+from setting.models_provider import get_model
+from smartdoc.conf import PROJECT_DIR


-def reset_meta(meta):
-    if not meta.get('allow_download', False):
-        return {'allow_download': False}
-    return meta
+def get_model_by_id(_id, user_id):
+    model = QuerySet(Model).filter(id=_id).first()
+    if model is None:
+        raise Exception(_("Model does not exist"))
+    if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
+        message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name)
+        raise Exception(message)
+    return model


-def get_embedding_id(knowledge_id_list):
-    knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
-    if len(set([knowledge.embedding_model_id for knowledge in knowledge_list])) > 1:
-        raise Exception(
-            _("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
-    if len(knowledge_list) == 0:
+def get_embedding_id(dataset_id_list):
+    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
+    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
+        raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
+    if len(dataset_list) == 0:
         raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
-    return knowledge_list[0].embedding_model_id
+    return dataset_list[0].embedding_mode_id


 class BaseSearchDatasetStep(ISearchDatasetStep):

-    def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
+    def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                 exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
                 search_mode: str = None,
-                workspace_id=None,
-                manage=None,
+                user_id=None,
                 **kwargs) -> List[ParagraphPipelineModel]:
-        get_knowledge_list_of_authorized = DatabaseModelManage.get_model('get_knowledge_list_of_authorized')
-        chat_user_type = manage.context.get('chat_user_type')
-        if get_knowledge_list_of_authorized is not None and RoleConstants.CHAT_USER.value.name == chat_user_type:
-            knowledge_id_list = get_knowledge_list_of_authorized(manage.context.get('chat_user_id'),
-                                                                 knowledge_id_list)
-        if len(knowledge_id_list) == 0:
+        if len(dataset_id_list) == 0:
             return []
         exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
-        model_id = get_embedding_id(knowledge_id_list)
-        model = get_model_by_id(model_id, workspace_id)
-        if model.model_type != "EMBEDDING":
-            raise Exception(_("Model does not exist"))
+        model_id = get_embedding_id(dataset_id_list)
+        model = get_model_by_id(model_id, user_id)
         self.context['model_name'] = model.name
-        default_params = get_model_default_params(model)
-        embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
+        embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
         embedding_value = embedding_model.embed_query(exec_problem_text)
         vector = VectorStore.get_embedding_vector()
-        embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, None, exclude_document_id_list,
+        embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
                                       exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
         if embedding_list is None:
             return []
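
Note on the hunk above: both sides resolve a single embedding model for all selected knowledge bases (mixed vector models are rejected by get_embedding_id) and cache the instantiated model per id via ModelManage.get_model(model_id, lambda _id: get_model(model)). Below is a minimal, self-contained sketch of that id-keyed lazy-cache pattern; LazyModelCache and its factory are illustrative stand-ins, not MaxKB's actual ModelManage API.

    from threading import Lock
    from typing import Any, Callable, Dict


    class LazyModelCache:
        """Id-keyed lazy cache: build a model on first request, reuse it afterwards."""

        def __init__(self) -> None:
            self._models: Dict[str, Any] = {}
            self._lock = Lock()

        def get_model(self, model_id: str, factory: Callable[[str], Any]) -> Any:
            with self._lock:
                if model_id not in self._models:
                    # The factory receives the id, mirroring `lambda _id: get_model(model)` above.
                    self._models[model_id] = factory(model_id)
                return self._models[model_id]


    cache = LazyModelCache()
    first = cache.get_model("embedding-1", lambda _id: object())
    assert first is cache.get_model("embedding-1", lambda _id: object())  # factory not called again
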
@@ -86,12 +78,11 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
                 .add_paragraph(paragraph)
                 .add_similarity(find_embedding.get('similarity'))
                 .add_comprehensive_score(find_embedding.get('comprehensive_score'))
-                .add_knowledge_name(paragraph.get('knowledge_name'))
-                .add_knowledge_type(paragraph.get('knowledge_type'))
+                .add_dataset_name(paragraph.get('dataset_name'))
                 .add_document_name(paragraph.get('document_name'))
                 .add_hit_handling_method(paragraph.get('hit_handling_method'))
                 .add_directly_return_similarity(paragraph.get('directly_return_similarity'))
-                .add_meta(reset_meta(paragraph.get('meta')))
+                .add_meta(paragraph.get('meta'))
                 .build())

     @staticmethod
@@ -111,7 +102,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
         paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
                                        get_file_content(
                                            os.path.join(PROJECT_DIR, "apps", "application", 'sql',
-                                                        'list_knowledge_paragraph_by_paragraph_id.sql')),
+                                                        'list_dataset_paragraph_by_paragraph_id.sql')),
                                        with_table_name=True)
         # If the vector store contains dirty data, delete it directly
         if len(paragraph_list) != len(paragraph_id_list):
@@ -6,22 +6,6 @@
 @date:2024/12/11 17:57
 @desc:
 """
-from enum import Enum
-from typing import List, Dict

-from django.db.models import QuerySet
-from django.utils.translation import gettext as _
-from rest_framework.exceptions import ErrorDetail, ValidationError

-from common.exception.app_exception import AppApiException
-from common.utils.common import group_by
-from models_provider.models import Model
-from models_provider.tools import get_model_credential
-from tools.models.tool import Tool

-end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
-             'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node',
-             'variable-assign-node']


 class Answer:
@@ -58,220 +42,3 @@ class NodeChunk:

     def is_end(self):
         return self.status == 200
-
-
-class Edge:
-    def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
-        self.id = _id
-        self.type = _type
-        self.sourceNodeId = sourceNodeId
-        self.targetNodeId = targetNodeId
-        for keyword in keywords:
-            self.__setattr__(keyword, keywords.get(keyword))
-
-
-class Node:
-    def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
-        self.id = _id
-        self.type = _type
-        self.x = x
-        self.y = y
-        self.properties = properties
-        for keyword in kwargs:
-            self.__setattr__(keyword, kwargs.get(keyword))
-
-
-class EdgeNode:
-    edge: Edge
-    node: Node
-
-    def __init__(self, edge, node):
-        self.edge = edge
-        self.node = node
-
-
-class WorkflowMode(Enum):
-    APPLICATION = "application"
-    APPLICATION_LOOP = "application-loop"
-    KNOWLEDGE = "knowledge"
-    KNOWLEDGE_LOOP = "knowledge-loop"
-
-
-class Workflow:
-    """
-    Node list
-    """
-    nodes: List[Node]
-    """
-    Edge list
-    """
-    edges: List[Edge]
-    """
-    node id -> node
-    """
-    node_map: Dict[str, Node]
-    """
-    node id -> all nodes upstream of the node
-    """
-    up_node_map: Dict[str, List[EdgeNode]]
-    """
-    node id -> all nodes downstream of the node
-    """
-    next_node_map: Dict[str, List[EdgeNode]]
-
-    workflow_mode: WorkflowMode
-
-    def __init__(self, nodes: List[Node], edges: List[Edge],
-                 workflow_mode: WorkflowMode = WorkflowMode.APPLICATION.value):
-        self.nodes = nodes
-        self.edges = edges
-        self.node_map = {node.id: node for node in nodes}
-        self.up_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.sourceNodeId)) for
-                                  edge in edges] for
-                            key, edges in
-                            group_by(edges, key=lambda edge: edge.targetNodeId).items()}
-        self.next_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.targetNodeId)) for edge in edges] for
-                              key, edges in
-                              group_by(edges, key=lambda edge: edge.sourceNodeId).items()}
-        self.workflow_mode = workflow_mode
-
-    def get_node(self, node_id):
-        """
-        Get node information by node_id
-        @param node_id: node_id
-        @return: node information
-        """
-        return self.node_map.get(node_id)
-
-    def get_up_edge_nodes(self, node_id) -> List[EdgeNode]:
-        """
-        Get the upstream nodes and edges connected to the given node id
-        @param node_id: node id
-        @return: list of node edges
-        """
-        return self.up_node_map.get(node_id)
-
-    def get_next_edge_nodes(self, node_id) -> List[EdgeNode]:
-        """
-        Get the downstream nodes and edges connected to the given node id
-        @param node_id: node id
-        @return: list of node edges
-        """
-        return self.next_node_map.get(node_id)
-
-    def get_up_nodes(self, node_id) -> List[Node]:
-        """
-        Get the upstream nodes connected to the given node id
-        @param node_id: node id
-        @return: list of nodes
-        """
-        return [en.node for en in self.up_node_map.get(node_id)]
-
-    def get_next_nodes(self, node_id) -> List[Node]:
-        """
-        Get the downstream nodes connected to the given node id
-        @param node_id: node id
-        @return: list of nodes
-        """
-        return [en.node for en in self.next_node_map.get(node_id, [])]
-
-    @staticmethod
-    def new_instance(flow_obj: Dict, workflow_mode: WorkflowMode = WorkflowMode.APPLICATION):
-        nodes = flow_obj.get('nodes')
-        edges = flow_obj.get('edges')
-        nodes = [Node(node.get('id'), node.get('type'), **node)
-                 for node in nodes]
-        edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges]
-        return Workflow(nodes, edges, workflow_mode)
-
-    def get_start_node(self):
-        return self.get_node('start-node')
-
-    def get_search_node(self):
-        return [node for node in self.nodes if node.type == 'search-dataset-node']
-
-    def is_valid(self):
-        """
-        Validate the workflow data
-        """
-        self.is_valid_model_params()
-        self.is_valid_start_node()
-        self.is_valid_base_node()
-        self.is_valid_work_flow()
-
-    def is_valid_node_params(self, node: Node):
-        from application.flow.step_node import get_node
-        get_node(node.type, self.workflow_mode)(node, None, None)
-
-    def is_valid_node(self, node: Node):
-        self.is_valid_node_params(node)
-        if node.type == 'condition-node':
-            branch_list = node.properties.get('node_data').get('branch')
-            for branch in branch_list:
-                source_anchor_id = f"{node.id}_{branch.get('id')}_right"
-                edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id]
-                if len(edge_list) == 0:
-                    raise AppApiException(500,
-                                          _('The branch {branch} of the {node} node needs to be connected').format(
-                                              node=node.properties.get("stepName"), branch=branch.get("type")))
-        else:
-            edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
-            if len(edge_list) == 0 and not end_nodes.__contains__(node.type):
-                raise AppApiException(500, _("{node} Nodes cannot be considered as end nodes").format(
-                    node=node.properties.get("stepName")))
-
-    def is_valid_work_flow(self, up_node=None):
-        if up_node is None:
-            up_node = self.get_start_node()
-        self.is_valid_node(up_node)
-        next_nodes = self.get_next_nodes(up_node)
-        for next_node in next_nodes:
-            self.is_valid_work_flow(next_node)
-
-    def is_valid_start_node(self):
-        start_node_list = [node for node in self.nodes if node.id == 'start-node']
-        if len(start_node_list) == 0:
-            raise AppApiException(500, _('The starting node is required'))
-        if len(start_node_list) > 1:
-            raise AppApiException(500, _('There can only be one starting node'))
-
-    def is_valid_model_params(self):
-        node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')]
-        for node in node_list:
-            model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first()
-            if model is None:
-                raise ValidationError(ErrorDetail(
-                    _('The node {node} model does not exist').format(node=node.properties.get("stepName"))))
-            credential = get_model_credential(model.provider, model.model_type, model.model_name)
-            model_params_setting = node.properties.get('node_data', {}).get('model_params_setting')
-            model_params_setting_form = credential.get_model_params_setting_form(
-                model.model_name)
-            if model_params_setting is None:
-                model_params_setting = model_params_setting_form.get_default_form_data()
-                node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
-            if node.properties.get('status', 200) != 200:
-                raise ValidationError(
-                    ErrorDetail(_("Node {node} is unavailable").format(node.properties.get("stepName"))))
-        node_list = [node for node in self.nodes if (node.type == 'function-lib-node')]
-        for node in node_list:
-            function_lib_id = node.properties.get('node_data', {}).get('function_lib_id')
-            if function_lib_id is None:
-                raise ValidationError(ErrorDetail(
-                    _('The library ID of node {node} cannot be empty').format(node=node.properties.get("stepName"))))
-            f_lib = QuerySet(Tool).filter(id=function_lib_id).first()
-            if f_lib is None:
-                raise ValidationError(ErrorDetail(_("The function library for node {node} is not available").format(
-                    node=node.properties.get("stepName"))))
-
-    def is_valid_base_node(self):
-        base_node_list = [node for node in self.nodes if node.id == 'base-node']
-        if len(base_node_list) == 0:
-            raise AppApiException(500, _('Basic information node is required'))
-        if len(base_node_list) > 1:
-            raise AppApiException(500, _('There can only be one basic information node'))
@@ -1,22 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: start_with.py
-@date:2025/10/20 10:37
-@desc:
-"""
-from typing import List
-
-from application.flow.compare import Compare
-
-
-class EndWithCompare(Compare):
-
-    def support(self, node_id, fields: List[str], source_value, compare, target_value):
-        if compare == 'end_with':
-            return True
-
-    def compare(self, source_value, compare, target_value):
-        source_value = str(source_value)
-        return source_value.endswith(str(target_value))
@@ -1,22 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: start_with.py
-@date:2025/10/20 10:37
-@desc:
-"""
-from typing import List
-
-from application.flow.compare import Compare
-
-
-class StartWithCompare(Compare):
-
-    def support(self, node_id, fields: List[str], source_value, compare, target_value):
-        if compare == 'start_with':
-            return True
-
-    def compare(self, source_value, compare, target_value):
-        source_value = str(source_value)
-        return source_value.startswith(str(target_value))
@@ -18,12 +18,13 @@ from rest_framework import serializers
 from rest_framework.exceptions import ValidationError, ErrorDetail

 from application.flow.common import Answer, NodeChunk
-from application.models import ApplicationChatUserStats
-from application.models import ChatRecord, ChatUserType
+from application.models import ChatRecord
+from application.models.api_key_model import ApplicationPublicAccessClient
+from common.constants.authentication_type import AuthenticationType
 from common.field.common import InstanceField
-from knowledge.models.knowledge_action import KnowledgeAction, State
+from common.util.field_message import ErrMessage

-chat_cache = cache
+chat_cache = cache.caches['chat_cache']


 def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
@@ -45,14 +46,16 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict):


 class WorkFlowPostHandler:
-    def __init__(self, chat_info):
+    def __init__(self, chat_info, client_id, client_type):
         self.chat_info = chat_info
+        self.client_id = client_id
+        self.client_type = client_type

-    def handler(self, workflow):
-        workflow_body = workflow.get_body()
-        question = workflow_body.get('question')
-        chat_record_id = workflow_body.get('chat_record_id')
-        chat_id = workflow_body.get('chat_id')
+    def handler(self, chat_id,
+                chat_record_id,
+                answer,
+                workflow):
+        question = workflow.params['question']
         details = workflow.get_runtime_details()
         message_tokens = sum([row.get('message_tokens') for row in details.values() if
                               'message_tokens' in row and row.get('message_tokens') is not None])
@@ -79,60 +82,21 @@ class WorkFlowPostHandler:
                                  message_tokens=message_tokens,
                                  answer_tokens=answer_tokens,
                                  answer_text_list=answer_text_list,
-                                 run_time=time.time() - workflow.context.get('start_time') if workflow.context.get(
-                                     'start_time') is not None else 0,
+                                 run_time=time.time() - workflow.context['start_time'],
                                  index=0)
-        self.chat_info.append_chat_record(chat_record)
-        self.chat_info.set_cache()
+        asker = workflow.context.get('asker', None)
+        self.chat_info.append_chat_record(chat_record, self.client_id, asker)
+        # Reset the cache
+        chat_cache.set(chat_id,
+                       self.chat_info, timeout=60 * 30)
-        if not self.chat_info.debug and [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
-                workflow_body.get('chat_user_type')):
-            application_public_access_client = (QuerySet(ApplicationChatUserStats)
-                                                .filter(chat_user_id=workflow_body.get('chat_user_id'),
-                                                        chat_user_type=workflow_body.get('chat_user_type'),
-                                                        application_id=self.chat_info.application_id).first())
+        if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
+            application_public_access_client = (QuerySet(ApplicationPublicAccessClient)
+                                                .filter(client_id=self.client_id,
+                                                        application_id=self.chat_info.application.id).first())
             if application_public_access_client is not None:
                 application_public_access_client.access_num = application_public_access_client.access_num + 1
                 application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
                 application_public_access_client.save()
-        self.chat_info = None
-
-
-class KnowledgeWorkflowPostHandler(WorkFlowPostHandler):
-    def __init__(self, chat_info, knowledge_action_id):
-        super().__init__(chat_info)
-        self.knowledge_action_id = knowledge_action_id
-
-    def handler(self, workflow):
-        state = get_workflow_state(workflow)
-        QuerySet(KnowledgeAction).filter(id=self.knowledge_action_id).update(
-            state=state,
-            run_time=time.time() - workflow.context.get('start_time') if workflow.context.get(
-                'start_time') is not None else 0)
-
-
-def get_loop_workflow_node(node_list):
-    result = []
-    for item in node_list:
-        if item.get('type') == 'loop-node':
-            for loop_item in item.get('loop_node_data') or []:
-                for inner_item in loop_item.values():
-                    result.append(inner_item)
-    return result
-
-
-def get_workflow_state(workflow):
-    details = workflow.get_runtime_details()
-    node_list = details.values()
-    all_node = [*node_list, *get_loop_workflow_node(node_list)]
-    err = any([True for value in all_node if value.get('status') == 500])
-    if err:
-        return State.FAILURE
-    write_is_exist = any([True for value in all_node if value.get('type') == 'knowledge-write-node'])
-    if not write_is_exist:
-        return State.FAILURE
-    return State.SUCCESS


 class NodeResult:
@@ -159,44 +123,31 @@ class NodeResult:


 class ReferenceAddressSerializer(serializers.Serializer):
-    node_id = serializers.CharField(required=True, label="node id")
+    node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("node id"))
     fields = serializers.ListField(
-        child=serializers.CharField(required=True, label="node field"), required=True,
-        label="node field array")
+        child=serializers.CharField(required=True, error_messages=ErrMessage.char("node field")), required=True,
+        error_messages=ErrMessage.list("node field array"))


 class FlowParamsSerializer(serializers.Serializer):
     # Chat history
     history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
-                                                label="chat history")
+                                                error_messages=ErrMessage.list("chat history"))

-    question = serializers.CharField(required=True, label="user question")
+    question = serializers.CharField(required=True, error_messages=ErrMessage.list("user question"))

-    chat_id = serializers.CharField(required=True, label="chat id")
+    chat_id = serializers.CharField(required=True, error_messages=ErrMessage.list("chat id"))

-    chat_record_id = serializers.CharField(required=True, label="chat record id")
+    chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("chat record id"))

-    stream = serializers.BooleanField(required=True, label="streaming output")
+    stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("streaming output"))

-    chat_user_id = serializers.CharField(required=False, label="chat user id")
+    client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("client id"))

-    chat_user_type = serializers.CharField(required=False, label="chat user type")
+    client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("client type"))

-    workspace_id = serializers.CharField(required=True, label="workspace id")
-
-    application_id = serializers.CharField(required=True, label="application id")
-
-    re_chat = serializers.BooleanField(required=True, label="regenerate answer")
-
-    debug = serializers.BooleanField(required=True, label="debug or not")
-
-
-class KnowledgeFlowParamsSerializer(serializers.Serializer):
-    knowledge_id = serializers.UUIDField(required=True, label="knowledge base id")
-    workspace_id = serializers.CharField(required=True, label="workspace id")
-    knowledge_action_id = serializers.UUIDField(required=True, label="knowledge action executor id")
-    data_source = serializers.DictField(required=True, label="data source")
-    knowledge_base = serializers.DictField(required=False, label="knowledge base settings")
+    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("user id"))
+    re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("regenerate answer"))


 class INode:
@@ -211,12 +162,11 @@ class INode:
             return None
         reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False)
         return [
-            Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params.get('chat_record_id'),
-                   {},
+            Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
                    self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]

     def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
-                 get_node_params=lambda node: node.properties.get('node_data'), salt=None):
+                 get_node_params=lambda node: node.properties.get('node_data')):
         # Context of the current step, used to store information about the current step
         self.status = 200
         self.err_message = ''
@@ -236,8 +186,7 @@ class INode:
         self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
                                                                                     "".join([*sorted(up_node_id_list),
                                                                                              node.id]))),
-                                    "utf-8")).hexdigest() + (
-                "__" + str(salt) if salt is not None else '')
+                                    "utf-8")).hexdigest()

     def valid_args(self, node_params, flow_params):
         flow_params_serializer_class = self.get_flow_params_serializer_class()
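
Note on the hunk above: the minus side extends the deterministic runtime_node_id with an optional salt suffix so a node that executes more than once (for example, inside a loop iteration) still gets a unique id, while the base digest stays stable for the same upstream path. A standalone sketch of that derivation, assuming only the standard library (the function name is illustrative):

    import uuid
    from hashlib import sha1
    from typing import List, Optional


    def runtime_node_id(up_node_id_list: List[str], node_id: str, salt: Optional[str] = None) -> str:
        # Sort upstream ids so the same set of predecessors always hashes the same.
        name = "".join([*sorted(up_node_id_list), node_id])
        digest = sha1(uuid.NAMESPACE_DNS.bytes
                      + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS, name)), "utf-8")).hexdigest()
        # The optional salt keeps repeated runs of the same node distinct.
        return digest + ("__" + str(salt) if salt is not None else "")


    assert runtime_node_id(["b", "a"], "n1") == runtime_node_id(["a", "b"], "n1")
    assert runtime_node_id(["a"], "n1", salt="2") != runtime_node_id(["a"], "n1")
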
@@ -269,14 +218,13 @@ class INode:
         pass

     def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
-        return self.workflow_manage.get_params_serializer_class()
+        return FlowParamsSerializer

     def get_write_error_context(self, e):
         self.status = 500
         self.answer_text = str(e)
         self.err_message = str(e)
-        current_time = time.time()
-        self.context['run_time'] = current_time - (self.context.get('start_time') or current_time)
+        self.context['run_time'] = time.time() - self.context['start_time']

     def write_error_context(answer, status=200):
         pass
@@ -1,15 +0,0 @@
-# coding=utf-8
-"""
-@project: maxkb
-@Author:虎
-@file: workflow_manage.py
-@date:2024/1/9 17:40
-@desc:
-"""
-from application.flow.i_step_node import KnowledgeFlowParamsSerializer
-from application.flow.loop_workflow_manage import LoopWorkflowManage
-
-
-class KnowledgeLoopWorkflowManage(LoopWorkflowManage):
-    def get_params_serializer_class(self):
-        return KnowledgeFlowParamsSerializer
@@ -1,104 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: Knowledge_workflow_manage.py
-@date:2025/11/13 19:02
-@desc:
-"""
-import time
-import traceback
-from concurrent.futures import ThreadPoolExecutor
-
-from django.db.models import QuerySet
-from django.utils.translation import get_language
-
-from application.flow.common import Workflow
-from application.flow.i_step_node import WorkFlowPostHandler, KnowledgeFlowParamsSerializer
-from application.flow.workflow_manage import WorkflowManage
-from common.handle.base_to_response import BaseToResponse
-from common.handle.impl.response.system_to_response import SystemToResponse
-from knowledge.models.knowledge_action import KnowledgeAction, State
-
-executor = ThreadPoolExecutor(max_workers=200)
-
-
-class KnowledgeWorkflowManage(WorkflowManage):
-
-    def __init__(self, flow: Workflow,
-                 params,
-                 work_flow_post_handler: WorkFlowPostHandler,
-                 base_to_response: BaseToResponse = SystemToResponse(),
-                 start_node_id=None,
-                 start_node_data=None, chat_record=None, child_node=None):
-        super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
-                         None,
-                         None, None, start_node_id, start_node_data, chat_record, child_node)
-
-    def get_params_serializer_class(self):
-        return KnowledgeFlowParamsSerializer
-
-    def get_start_node(self):
-        start_node_list = [node for node in self.flow.nodes if
-                           self.params.get('data_source', {}).get('node_id') == node.id]
-        return start_node_list[0]
-
-    def run(self):
-        self.context['start_time'] = time.time()
-        executor.submit(self._run)
-
-    def _run(self):
-        QuerySet(KnowledgeAction).filter(id=self.params.get('knowledge_action_id')).update(
-            state=State.STARTED)
-        language = get_language()
-        self.run_chain_async(self.start_node, None, language)
-        while self.is_run():
-            pass
-        self.work_flow_post_handler.handler(self)
-
-    @staticmethod
-    def get_node_details(current_node, node, index):
-        if current_node == node:
-            return {
-                'name': node.node.properties.get('stepName'),
-                "index": index,
-                'run_time': 0,
-                'type': node.type,
-                'status': 202,
-                'err_message': ""
-            }
-        return node.get_details(index)
-
-    def run_chain(self, current_node, node_result_future=None):
-        QuerySet(KnowledgeAction).filter(id=self.params.get('knowledge_action_id')).update(
-            details=self.get_runtime_details(lambda node, index: self.get_node_details(current_node, node, index)))
-        if node_result_future is None:
-            node_result_future = self.run_node_future(current_node)
-        try:
-            result = self.hand_node_result(current_node, node_result_future)
-            return result
-        except Exception as e:
-            traceback.print_exc()
-            return None
-
-    def hand_node_result(self, current_node, node_result_future):
-        try:
-            current_result = node_result_future.result()
-            result = current_result.write_context(current_node, self)
-            if result is not None:
-                # Block to fetch the result
-                list(result)
-            return current_result
-        except Exception as e:
-            traceback.print_exc()
-            self.status = 500
-            current_node.get_write_error_context(e)
-            self.answer += str(e)
-            QuerySet(KnowledgeAction).filter(id=self.params.get('knowledge_action_id')).update(
-                details=self.get_runtime_details(),
-                state=State.FAILURE)
-        finally:
-            current_node.node_chunk.end()
-            QuerySet(KnowledgeAction).filter(id=self.params.get('knowledge_action_id')).update(
-                details=self.get_runtime_details())
@@ -1,193 +0,0 @@
-# coding=utf-8
-"""
-@project: maxkb
-@Author:虎
-@file: workflow_manage.py
-@date:2024/1/9 17:40
-@desc:
-"""
-from concurrent.futures import ThreadPoolExecutor
-from typing import List
-
-from django.db import close_old_connections
-from django.utils.translation import get_language
-from langchain_core.prompts import PromptTemplate
-
-from application.flow.common import Workflow
-from application.flow.i_step_node import WorkFlowPostHandler, INode
-from application.flow.step_node import get_node
-from application.flow.workflow_manage import WorkflowManage
-from common.handle.base_to_response import BaseToResponse
-from common.handle.impl.response.system_to_response import SystemToResponse
-
-executor = ThreadPoolExecutor(max_workers=200)
-
-
-class NodeResultFuture:
-    def __init__(self, r, e, status=200):
-        self.r = r
-        self.e = e
-        self.status = status
-
-    def result(self):
-        if self.status == 200:
-            return self.r
-        else:
-            raise self.e
-
-
-def await_result(result, timeout=1):
-    try:
-        result.result(timeout)
-        return False
-    except Exception as e:
-        return True
-
-
-class NodeChunkManage:
-
-    def __init__(self, work_flow):
-        self.node_chunk_list = []
-        self.current_node_chunk = None
-        self.work_flow = work_flow
-
-    def add_node_chunk(self, node_chunk):
-        self.node_chunk_list.append(node_chunk)
-
-    def contains(self, node_chunk):
-        return self.node_chunk_list.__contains__(node_chunk)
-
-    def pop(self):
-        if self.current_node_chunk is None:
-            try:
-                current_node_chunk = self.node_chunk_list.pop(0)
-                self.current_node_chunk = current_node_chunk
-            except IndexError as e:
-                pass
-        if self.current_node_chunk is not None:
-            try:
-                chunk = self.current_node_chunk.chunk_list.pop(0)
-                return chunk
-            except IndexError as e:
-                if self.current_node_chunk.is_end():
-                    self.current_node_chunk = None
-                    if self.work_flow.answer_is_not_empty():
-                        chunk = self.work_flow.base_to_response.to_stream_chunk_response(
-                            self.work_flow.params['chat_id'],
-                            self.work_flow.params['chat_record_id'],
-                            '\n\n', False, 0, 0)
-                        self.work_flow.append_answer('\n\n')
-                        return chunk
-                return self.pop()
-        return None
-
-
-class LoopWorkflowManage(WorkflowManage):
-
-    def __init__(self, flow: Workflow,
-                 params,
-                 work_flow_post_handler: WorkFlowPostHandler,
-                 parentWorkflowManage,
-                 loop_params,
-                 get_loop_context,
-                 base_to_response: BaseToResponse = SystemToResponse(),
-                 start_node_id=None,
-                 start_node_data=None, chat_record=None, child_node=None):
-        self.parentWorkflowManage = parentWorkflowManage
-        self.loop_params = loop_params
-        self.get_loop_context = get_loop_context
-        self.loop_field_list = []
-        super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
-                         None,
-                         None, None, start_node_id, start_node_data, chat_record, child_node)
-
-    def get_node_cls_by_id(self, node_id, up_node_id_list=None,
-                           get_node_params=lambda node: node.properties.get('node_data')):
-        for node in self.flow.nodes:
-            if node.id == node_id:
-                node_instance = get_node(node.type, self.flow.workflow_mode)(node,
-                                                                             self.params, self, up_node_id_list,
-                                                                             get_node_params,
-                                                                             salt=self.get_index())
-                return node_instance
-        return None
-
-    def stream(self):
-        close_old_connections()
-        language = get_language()
-        self.run_chain_async(self.start_node, None, language)
-        return self.await_result(is_cleanup=False)
-
-    def get_index(self):
-        return self.loop_params.get('index')
-
-    def get_start_node(self):
-        start_node_list = [node for node in self.flow.nodes if
-                           ['loop-start-node'].__contains__(node.type)]
-        return start_node_list[0]
-
-    def get_reference_field(self, node_id: str, fields: List[str]):
-        """
-        @param node_id: node id
-        @param fields: fields
-        @return:
-        """
-        if node_id == 'global':
-            return self.parentWorkflowManage.get_reference_field(node_id, fields)
-        elif node_id == 'chat':
-            return self.parentWorkflowManage.get_reference_field(node_id, fields)
-        elif node_id == 'loop':
-            loop_context = self.get_loop_context()
-            return INode.get_field(loop_context, fields)
-        else:
-            node = self.get_node_by_id(node_id)
-            if node:
-                return node.get_reference_field(fields)
-            return self.parentWorkflowManage.get_reference_field(node_id, fields)
-
-    def get_workflow_content(self):
-        context = {
-            'global': self.context,
-            'chat': self.chat_context,
-            'loop': self.get_loop_context(),
-        }
-
-        for node in self.node_context:
-            context[node.id] = node.context
-        return context
-
-    def init_fields(self):
-        super().init_fields()
-        loop_field_list = []
-        loop_start_node = self.flow.get_node('loop-start-node')
-        loop_input_field_list = loop_start_node.properties.get('loop_input_field_list')
-        node_name = loop_start_node.properties.get('stepName')
-        node_id = loop_start_node.id
-        if loop_input_field_list is not None:
-            for f in loop_input_field_list:
-                loop_field_list.append(
-                    {'label': f.get('label'), 'value': f.get('field'), 'node_id': node_id, 'node_name': node_name})
-        self.loop_field_list = loop_field_list
-
-    def reset_prompt(self, prompt: str):
-        prompt = super().reset_prompt(prompt)
-        for field in self.loop_field_list:
-            chatLabel = f"loop.{field.get('value')}"
-            chatValue = f"context.get('loop').get('{field.get('value', '')}','')"
-            prompt = prompt.replace(chatLabel, chatValue)
-
-        prompt = self.parentWorkflowManage.reset_prompt(prompt)
-        return prompt
-
-    def generate_prompt(self, prompt: str):
-        """
-        Format and generate the prompt
-        @param prompt: prompt information
-        @return: the formatted prompt
-        """
-
-        context = {**self.get_workflow_content(), **self.parentWorkflowManage.get_workflow_content()}
-        prompt = self.reset_prompt(prompt)
-        prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
-        value = prompt_template.format(context=context)
-        return value
@@ -9,52 +9,34 @@
 from .ai_chat_step_node import *
 from .application_node import BaseApplicationNode
 from .condition_node import *
-from .data_source_local_node.impl.base_data_source_local_node import BaseDataSourceLocalNode
-from .data_source_web_node.impl.base_data_source_web_node import BaseDataSourceWebNode
 from .direct_reply_node import *
-from .document_extract_node import *
 from .form_node import *
-from .image_generate_step_node import *
-from .image_to_video_step_node import BaseImageToVideoNode
-from .image_understand_step_node import *
-from .intent_node import *
-from .knowledge_write_node.impl.base_knowledge_write_node import BaseKnowledgeWriteNode
-from .loop_break_node import BaseLoopBreakNode
-from .loop_continue_node import BaseLoopContinueNode
-from .loop_node import *
-from .loop_start_node import *
-from .mcp_node import BaseMcpNode
-from .parameter_extraction_node import BaseParameterExtractionNode
+from .function_lib_node import *
+from .function_node import *
 from .question_node import *
 from .reranker_node import *
-from .search_document_node import BaseSearchDocumentNode
-from .search_knowledge_node import *
+from .document_extract_node import *
+from .image_understand_step_node import *
+from .image_generate_step_node import *

+from .search_dataset_node import *
 from .speech_to_text_step_node import BaseSpeechToTextNode
 from .start_node import *
 from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode
-from .text_to_video_step_node.impl.base_text_to_video_node import BaseTextToVideoNode
-from .tool_lib_node import *
-from .tool_node import *
-from .variable_aggregation_node.impl.base_variable_aggregation_node import BaseVariableAggregationNode
 from .variable_assign_node import BaseVariableAssignNode
-from .variable_splitting_node import BaseVariableSplittingNode
-from .video_understand_step_node import BaseVideoUnderstandNode
-from .document_split_node import BaseDocumentSplitNode
+from .mcp_node import BaseMcpNode

-node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseSearchDocumentNode, BaseQuestionNode,
+node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode,
              BaseConditionNode, BaseReplyNode,
-             BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
+             BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
              BaseDocumentExtractNode,
              BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode,
-             BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode, BaseTextToVideoNode, BaseImageToVideoNode,
-             BaseVideoUnderstandNode,
-             BaseIntentNode, BaseLoopNode, BaseLoopStartStepNode,
-             BaseLoopContinueNode,
-             BaseLoopBreakNode, BaseVariableSplittingNode, BaseParameterExtractionNode, BaseVariableAggregationNode,
-             BaseDataSourceLocalNode, BaseDataSourceWebNode, BaseKnowledgeWriteNode, BaseDocumentSplitNode]
+             BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode]

-node_map = {n.type: {w: n for w in n.support} for n in node_list}


-def get_node(node_type, workflow_model):
-    return node_map.get(node_type).get(workflow_model)
+def get_node(node_type):
+    find_list = [node for node in node_list if node.type == node_type]
+    if len(find_list) > 0:
+        return find_list[0]
+    return None
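
Note on the hunk above: the minus side swaps the linear scan in get_node for a two-level registry keyed by node type and workflow mode (node_map = {n.type: {w: n for w in n.support} for n in node_list}), so dispatch becomes two dict lookups. A toy sketch of that shape; the node classes and mode enum here are illustrative stand-ins, not the real step nodes:

    from enum import Enum


    class Mode(Enum):
        APPLICATION = "application"
        KNOWLEDGE = "knowledge"


    class StartNode:
        type = "start-node"
        support = [Mode.APPLICATION, Mode.KNOWLEDGE]


    class ChatNode:
        type = "ai-chat-node"
        support = [Mode.APPLICATION]


    node_list = [StartNode, ChatNode]
    # type -> {supported mode -> node class}; each class declares which modes it supports.
    node_map = {n.type: {m: n for m in n.support} for n in node_list}


    def get_node(node_type, workflow_mode):
        return node_map.get(node_type, {}).get(workflow_mode)


    assert get_node("ai-chat-node", Mode.APPLICATION) is ChatNode
    assert get_node("ai-chat-node", Mode.KNOWLEDGE) is None
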
@@ -11,55 +11,41 @@ from typing import Type
 from django.utils.translation import gettext_lazy as _
 from rest_framework import serializers

-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage


 class ChatNodeSerializer(serializers.Serializer):
-    model_id = serializers.CharField(required=True, label=_("Model id"))
+    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
     system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
-                                   label=_("Role Setting"))
-    prompt = serializers.CharField(required=True, label=_("Prompt word"))
+                                   error_messages=ErrMessage.char(_("Role Setting")))
+    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
     # Number of multi-round conversations
-    dialogue_number = serializers.IntegerField(required=True, label=_("Number of multi-round conversations"))
+    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(
+        _("Number of multi-round conversations")))

     is_result = serializers.BooleanField(required=False,
-                                         label=_('Whether to return content'))
+                                         error_messages=ErrMessage.boolean(_('Whether to return content')))

     model_params_setting = serializers.DictField(required=False,
-                                                 label=_("Model parameter settings"))
+                                                 error_messages=ErrMessage.dict(_("Model parameter settings")))
     model_setting = serializers.DictField(required=False,
-                                          label='Model settings')
+                                          error_messages=ErrMessage.dict('Model settings'))
     dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
-                                          label=_("Context Type"))
-    mcp_enable = serializers.BooleanField(required=False, label=_("Whether to enable MCP"))
-    mcp_servers = serializers.JSONField(required=False, label=_("MCP Server"))
-    mcp_tool_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Tool ID"))
-    mcp_tool_ids = serializers.ListField(child=serializers.UUIDField(), required=False, allow_empty=True,
-                                         label=_("MCP Tool IDs"), )
-    mcp_source = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Source"))
-
-    tool_enable = serializers.BooleanField(required=False, default=False, label=_("Whether to enable tools"))
-    tool_ids = serializers.ListField(child=serializers.UUIDField(), required=False, allow_empty=True,
-                                     label=_("Tool IDs"), )
-    mcp_output_enable = serializers.BooleanField(required=False, default=True, label=_("Whether to enable MCP output"))
+                                          error_messages=ErrMessage.char(_("Context Type")))
+    mcp_enable = serializers.BooleanField(required=False,
+                                          error_messages=ErrMessage.boolean(_("Whether to enable MCP")))
+    mcp_servers = serializers.JSONField(required=False, error_messages=ErrMessage.list(_("MCP Server")))


 class IChatNode(INode):
     type = 'ai-chat-node'
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE_LOOP,
-               WorkflowMode.KNOWLEDGE]

     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return ChatNodeSerializer

     def _run(self):
-        if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__(
-                self.workflow_manage.flow.workflow_mode):
-            return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
-                                **{'history_chat_record': [], 'stream': True, 'chat_id': None, 'chat_record_id': None})
-        else:
-            return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

     def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
                 chat_record_id,
@@ -68,11 +54,5 @@ class IChatNode(INode):
                 model_setting=None,
                 mcp_enable=False,
                 mcp_servers=None,
-                mcp_tool_id=None,
-                mcp_tool_ids=None,
-                mcp_source=None,
-                tool_enable=False,
-                tool_ids=None,
-                mcp_output_enable=True,
                 **kwargs) -> NodeResult:
         pass
@@ -6,26 +6,39 @@
 @date:2024/6/4 14:30
 @desc:
 """
+import asyncio
 import json
-import os
 import re
 import time
 from functools import reduce
+from types import AsyncGeneratorType
 from typing import List, Dict

 from django.db.models import QuerySet
 from langchain.schema import HumanMessage, SystemMessage
-from langchain_core.messages import BaseMessage, AIMessage
+from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage
+from langchain_mcp_adapters.client import MultiServerMCPClient
+from langgraph.prebuilt import create_react_agent

 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
-from application.flow.tools import Reasoning, mcp_response_generator
-from common.utils.rsa_util import rsa_long_decrypt
-from common.utils.tool_code import ToolExecutor
-from maxkb.const import CONFIG
-from models_provider.models import Model
-from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
-from tools.models import Tool
+from application.flow.tools import Reasoning
+from setting.models import Model
+from setting.models_provider import get_model_credential
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+
+tool_message_template = """
+<details>
+<summary>
+<strong>Called MCP Tool: <em>%s</em></strong>
+</summary>
+
+```json
+%s
+```
+</details>
+
+"""


 def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
@ -90,6 +103,39 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
|
||||||
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
|
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _yield_mcp_response(chat_model, message_list, mcp_servers):
|
||||||
|
async with MultiServerMCPClient(json.loads(mcp_servers)) as client:
|
||||||
|
agent = create_react_agent(chat_model, client.get_tools())
|
||||||
|
response = agent.astream({"messages": message_list}, stream_mode='messages')
|
||||||
|
async for chunk in response:
|
||||||
|
if isinstance(chunk[0], ToolMessage):
|
||||||
|
content = tool_message_template % (chunk[0].name, chunk[0].content)
|
||||||
|
chunk[0].content = content
|
||||||
|
yield chunk[0]
|
||||||
|
if isinstance(chunk[0], AIMessageChunk):
|
||||||
|
yield chunk[0]
|
||||||
|
|
||||||
|
|
||||||
|
def mcp_response_generator(chat_model, message_list, mcp_servers):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = loop.run_until_complete(anext_async(async_gen))
|
||||||
|
yield chunk
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f'exception: {e}')
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def anext_async(agen):
|
||||||
|
return await agen.__anext__()
|
||||||
|
|
||||||
|
|
||||||
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
|
||||||
"""
|
"""
|
||||||
写入上下文数据
|
写入上下文数据
|
||||||
|
|
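The mcp_response_generator added above bridges an async generator into a plain synchronous generator by pumping __anext__() through a private event loop, one chunk per run_until_complete round trip. A minimal, self-contained sketch of the same pattern (the ticks/sync_ticks names are illustrative, not MaxKB code):

    import asyncio

    async def ticks(n):
        # Stand-in for the async MCP agent stream.
        for i in range(n):
            await asyncio.sleep(0)
            yield i

    def sync_ticks(n):
        # Own (and always close) a dedicated loop, as mcp_response_generator does.
        loop = asyncio.new_event_loop()
        try:
            agen = ticks(n)
            while True:
                try:
                    # Resume the coroutine until its next yield.
                    yield loop.run_until_complete(agen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            loop.close()

    print(list(sync_ticks(3)))  # [0, 1, 2]

Each chunk costs a full run_until_complete round trip, which is why the generator owns its own event loop rather than sharing one.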
@@ -106,11 +152,10 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
     reasoning_result = reasoning.get_reasoning_content(response)
     reasoning_result_end = reasoning.get_end_reasoning_content()
     content = reasoning_result.get('content') + reasoning_result_end.get('content')
-    meta = {**response.response_metadata, **response.additional_kwargs}
-    if 'reasoning_content' in meta:
-        reasoning_content = (meta.get('reasoning_content', '') or '')
+    if 'reasoning_content' in response.response_metadata:
+        reasoning_content = response.response_metadata.get('reasoning_content', '')
     else:
-        reasoning_content = (reasoning_result.get('reasoning_content') or '') + (reasoning_result_end.get('reasoning_content') or '')
+        reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
     _write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
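One side of the hunk above wraps every lookup in (value or ''). That guard matters because dict.get(key, default) only falls back when the key is absent; a key present with a None value still yields None and breaks string concatenation. A two-line illustration:

    d = {'reasoning_content': None}
    print(d.get('reasoning_content', ''))    # None - key exists, so the default is skipped
    print(d.get('reasoning_content') or '')  # '' - None is normalized before use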
@@ -152,12 +197,6 @@ class BaseChatNode(IChatNode):
                 model_setting=None,
                 mcp_enable=False,
                 mcp_servers=None,
-                mcp_tool_id=None,
-                mcp_tool_ids=None,
-                mcp_source=None,
-                tool_enable=False,
-                tool_ids=None,
-                mcp_output_enable=True,
                 **kwargs) -> NodeResult:
         if dialogue_type is None:
             dialogue_type = 'WORKFLOW'
@@ -168,9 +207,8 @@ class BaseChatNode(IChatNode):
             model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
                              'reasoning_content_start': '<think>'}
         self.context['model_setting'] = model_setting
-        workspace_id = self.workflow_manage.get_body().get('workspace_id')
-        chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
-                                                              **model_params_setting)
+        chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
+                                                         **model_params_setting)
         history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
                                                    self.runtime_node_id)
         self.context['history_message'] = history_message
@@ -181,98 +219,24 @@ class BaseChatNode(IChatNode):
         message_list = self.generate_message_list(system, prompt, history_message)
         self.context['message_list'] = message_list
 
-        # Handle the MCP request
-        mcp_result = self._handle_mcp_request(
-            mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids, mcp_output_enable,
-            chat_model, message_list, history_message, question
-        )
-        if mcp_result:
-            return mcp_result
+        if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers:
+            r = mcp_response_generator(chat_model, message_list, mcp_servers)
+            return NodeResult(
+                {'result': r, 'chat_model': chat_model, 'message_list': message_list,
+                 'history_message': history_message, 'question': question.content}, {},
+                _write_context=write_context_stream)
 
         if stream:
             r = chat_model.stream(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
-                               'history_message': [{'content': message.content, 'role': message.type} for message in
-                                                   (history_message if history_message is not None else [])],
-                               'question': question.content}, {},
+                               'history_message': history_message, 'question': question.content}, {},
                               _write_context=write_context_stream)
         else:
             r = chat_model.invoke(message_list)
             return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
-                               'history_message': [{'content': message.content, 'role': message.type} for message in
-                                                   (history_message if history_message is not None else [])],
-                               'question': question.content}, {},
+                               'history_message': history_message, 'question': question.content}, {},
                               _write_context=write_context)
 
-    def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids,
-                            mcp_output_enable, chat_model, message_list, history_message, question):
-        if not mcp_enable and not tool_enable:
-            return None
-
-        mcp_servers_config = {}
-
-        # For migrated data, mcp_source is None
-        if mcp_source is None:
-            mcp_source = 'custom'
-        if mcp_enable:
-            # Compatibility with legacy data
-            if not mcp_tool_ids:
-                mcp_tool_ids = []
-            if mcp_tool_id:
-                mcp_tool_ids = list(set(mcp_tool_ids + [mcp_tool_id]))
-            if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
-                mcp_servers_config = json.loads(mcp_servers)
-                mcp_servers_config = self.handle_variables(mcp_servers_config)
-            elif mcp_tool_ids:
-                mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values()
-                for mcp_tool in mcp_tools:
-                    if mcp_tool and mcp_tool['is_active']:
-                        mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
-                mcp_servers_config = self.handle_variables(mcp_servers_config)
-
-        if tool_enable:
-            if tool_ids and len(tool_ids) > 0:  # If there are tool IDs, convert them to MCP
-                self.context['tool_ids'] = tool_ids
-                for tool_id in tool_ids:
-                    tool = QuerySet(Tool).filter(id=tool_id).first()
-                    if not tool.is_active:
-                        continue
-                    executor = ToolExecutor()
-                    if tool.init_params is not None:
-                        params = json.loads(rsa_long_decrypt(tool.init_params))
-                    else:
-                        params = {}
-                    tool_config = executor.get_tool_mcp_config(tool.code, params)
-
-                    mcp_servers_config[str(tool.id)] = tool_config
-
-        if len(mcp_servers_config) > 0:
-            r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable)
-            return NodeResult(
-                {'result': r, 'chat_model': chat_model, 'message_list': message_list,
-                 'history_message': [{'content': message.content, 'role': message.type} for message in
-                                     (history_message if history_message is not None else [])],
-                 'question': question.content}, {},
-                _write_context=write_context_stream)
-
-        return None
-
-    def handle_variables(self, tool_params):
-        # Substitute variables inside the parameters
-        for k, v in tool_params.items():
-            if type(v) == str:
-                tool_params[k] = self.workflow_manage.generate_prompt(tool_params[k])
-            if type(v) == dict:
-                self.handle_variables(v)
-            if (type(v) == list) and (type(v[0]) == str):
-                tool_params[k] = self.get_reference_content(v)
-        return tool_params
-
-    def get_reference_content(self, fields: List[str]):
-        return str(self.workflow_manage.get_reference_field(
-            fields[0],
-            fields[1:]))
-
     @staticmethod
     def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
         start_index = len(history_chat_record) - dialogue_number
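The _handle_mcp_request/handle_variables pair removed above walks a nested MCP server config and substitutes workflow variables: string values go through the prompt renderer, dicts recurse, and a list whose first element is a string is resolved as a field reference. A standalone sketch of that traversal with stubbed resolvers (render and resolve_reference stand in for workflow_manage.generate_prompt and get_reference_field; the sketch also tightens the original's type(v) == checks and unguarded v[0] with isinstance and an emptiness check):

    def handle_variables(params, render, resolve_reference):
        for k, v in params.items():
            if isinstance(v, str):
                params[k] = render(v)                           # template plain strings
            elif isinstance(v, dict):
                handle_variables(v, render, resolve_reference)  # recurse into nested config
            elif isinstance(v, list) and v and isinstance(v[0], str):
                params[k] = resolve_reference(v[0], v[1:])      # ['node', 'field'] reference
        return params

    servers = {'demo': {'url': '{api_base}', 'headers': {'Authorization': '{token}'}}}
    values = {'api_base': 'http://localhost:8000', 'token': 'secret'}
    print(handle_variables(servers, lambda s: s.format(**values),
                           lambda head, rest: [head, *rest]))
    # {'demo': {'url': 'http://localhost:8000', 'headers': {'Authorization': 'secret'}}}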
@@ -310,7 +274,9 @@ class BaseChatNode(IChatNode):
             "index": index,
             'run_time': self.context.get('run_time'),
             'system': self.context.get('system'),
-            'history_message': self.context.get('history_message'),
+            'history_message': [{'content': message.content, 'role': message.type} for message in
+                                (self.context.get('history_message') if self.context.get(
+                                    'history_message') is not None else [])],
             'question': self.context.get('question'),
             'answer': self.context.get('answer'),
             'reasoning_content': self.context.get('reasoning_content'),
@@ -3,30 +3,29 @@ from typing import Type
 
 from rest_framework import serializers
 
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
 
 from django.utils.translation import gettext_lazy as _
 
 
 class ApplicationNodeSerializer(serializers.Serializer):
-    application_id = serializers.CharField(required=True, label=_("Application ID"))
+    application_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Application ID")))
     question_reference_address = serializers.ListField(required=True,
-                                                       label=_("User Questions"))
-    api_input_field_list = serializers.ListField(required=False, label=_("API Input Fields"))
+                                                       error_messages=ErrMessage.list(_("User Questions")))
+    api_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("API Input Fields")))
     user_input_field_list = serializers.ListField(required=False,
-                                                  label=_("User Input Fields"))
-    image_list = serializers.ListField(required=False, label=_("picture"))
-    document_list = serializers.ListField(required=False, label=_("document"))
-    audio_list = serializers.ListField(required=False, label=_("Audio"))
+                                                  error_messages=ErrMessage.uuid(_("User Input Fields")))
+    image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))
+    document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
+    audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Audio")))
     child_node = serializers.DictField(required=False, allow_null=True,
-                                       label=_("Child Nodes"))
-    node_data = serializers.DictField(required=False, allow_null=True, label=_("Form Data"))
+                                       error_messages=ErrMessage.dict(_("Child Nodes")))
+    node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
 
 
 class IApplicationNode(INode):
     type = 'application-node'
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP]
 
     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return ApplicationNodeSerializer
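The serializer churn in this hunk swaps label= for error_messages=. In DRF, error_messages is a dict keyed by error code ('required', 'blank', ...), which is presumably the shape the ErrMessage helpers build; the helper below is an assumed stand-in, not MaxKB's implementation:

    from rest_framework import serializers

    def char_messages(field_label):
        # Assumed shape of ErrMessage.char(...): one message per error code.
        return {
            'required': f'{field_label} is required',
            'blank': f'{field_label} may not be blank',
            'null': f'{field_label} may not be null',
        }

    class DemoSerializer(serializers.Serializer):
        application_id = serializers.CharField(required=True,
                                               error_messages=char_messages('Application ID'))

    s = DemoSerializer(data={})
    s.is_valid()
    print(s.errors)  # {'application_id': [ErrorDetail('Application ID is required', code='required')]}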
@@ -76,7 +75,7 @@ class IApplicationNode(INode):
             if 'file_id' not in audio:
                 raise ValueError(
                     _("Parameter value error: The uploaded audio lacks file_id, and the audio upload fails."))
-        return self.execute(**{**self.flow_params_serializer.data, **self.node_params_serializer.data},
+        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
                             app_document_list=app_document_list, app_image_list=app_image_list,
                             app_audio_list=app_audio_list,
                             message=str(question), **kwargs)
@@ -4,7 +4,7 @@ import re
 import time
 import uuid
 from typing import Dict, List
-from django.utils.translation import gettext as _
+
 from application.flow.common import Answer
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.application_node.i_application_node import IApplicationNode
@@ -55,7 +55,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
             # First convert the stream chunk into a string
             response_content = chunk.decode('utf-8')[6:]
             response_content = json.loads(response_content)
-            content = (response_content.get('content', '') or '')
+            content = response_content.get('content', '')
             runtime_node_id = response_content.get('runtime_node_id', '')
             chat_record_id = response_content.get('chat_record_id', '')
             child_node = response_content.get('child_node')
@@ -63,7 +63,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
             node_type = response_content.get('node_type')
             real_node_id = response_content.get('real_node_id')
             node_is_end = response_content.get('node_is_end', False)
-            _reasoning_content = (response_content.get('reasoning_content', '') or '')
+            _reasoning_content = response_content.get('reasoning_content', '')
             if node_type == 'form-node':
                 is_interrupt_exec = True
             answer += content
@@ -171,30 +171,16 @@ class BaseApplicationNode(IApplicationNode):
         if self.node_params.get('is_result', False):
             self.answer_text = details.get('answer')
 
-    def get_chat_asker(self, kwargs):
-        asker = kwargs.get('asker')
-        if asker:
-            if isinstance(asker, dict):
-                return asker
-            return {'username': asker}
-        return self.workflow_manage.work_flow_post_handler.chat_info.get_chat_user()
-
-    def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat,
-                chat_user_id,
-                chat_user_type,
+    def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
                 app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
                 **kwargs) -> NodeResult:
-        from chat.serializers.chat import ChatSerializers
-        if application_id == self.workflow_manage.get_body().get('application_id'):
-            raise Exception(_("The sub application cannot use the current node"))
+        from application.serializers.chat_message_serializers import ChatMessageSerializer
         # Generate the chat_id for the embedded application
         current_chat_id = string_to_uuid(chat_id + application_id)
         Chat.objects.get_or_create(id=current_chat_id, defaults={
             'application_id': application_id,
             'abstract': message[0:1024],
-            'chat_user_id': chat_user_id,
-            'chat_user_type': chat_user_type,
-            'asker': self.get_chat_asker(kwargs)
+            'client_id': client_id,
         })
         if app_document_list is None:
             app_document_list = []
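The node above derives the embedded application's conversation id with string_to_uuid(chat_id + application_id), so the same parent chat always maps to the same child Chat row and get_or_create can reuse it. A name-based UUID gives exactly that stable mapping; a sketch (string_to_uuid's real implementation is not shown in this diff, so uuid5 here is an assumption):

    import uuid

    def string_to_uuid_sketch(s: str) -> str:
        # Deterministic: the same input string always yields the same UUID.
        return str(uuid.uuid5(uuid.NAMESPACE_URL, s))

    a = string_to_uuid_sketch('chat-1' + 'app-9')
    b = string_to_uuid_sketch('chat-1' + 'app-9')
    assert a == b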
@@ -211,26 +197,22 @@ class BaseApplicationNode(IApplicationNode):
             child_node_value = child_node.get('child_node')
             application_node_dict = self.context.get('application_node_dict')
             reset_application_node_dict(application_node_dict, runtime_node_id, node_data)
-        response = ChatSerializers(data={
-            "chat_id": current_chat_id,
-            "chat_user_id": chat_user_id,
-            'chat_user_type': chat_user_type,
-            'application_id': application_id,
-            'debug': False
-        }).chat(instance=
-                {'message': message,
-                 're_chat': re_chat,
-                 'stream': stream,
-                 'document_list': app_document_list,
-                 'image_list': app_image_list,
-                 'audio_list': app_audio_list,
-                 'runtime_node_id': runtime_node_id,
-                 'chat_record_id': record_id,
-                 'child_node': child_node_value,
-                 'node_data': node_data,
-                 'form_data': kwargs}
-                )
 
+        response = ChatMessageSerializer(
+            data={'chat_id': current_chat_id, 'message': message,
+                  're_chat': re_chat,
+                  'stream': stream,
+                  'application_id': application_id,
+                  'client_id': client_id,
+                  'client_type': client_type,
+                  'document_list': app_document_list,
+                  'image_list': app_image_list,
+                  'audio_list': app_audio_list,
+                  'runtime_node_id': runtime_node_id,
+                  'chat_record_id': record_id,
+                  'child_node': child_node_value,
+                  'node_data': node_data,
+                  'form_data': kwargs}).chat()
         if response.status_code == 200:
             if stream:
                 content_generator = response.streaming_content
@@ -8,7 +8,6 @@
 """
 
 from .contain_compare import *
-from .end_with import EndWithCompare
 from .equal_compare import *
 from .ge_compare import *
 from .gt_compare import *
@@ -24,10 +23,8 @@ from .len_le_compare import *
 from .len_lt_compare import *
 from .lt_compare import *
 from .not_contain_compare import *
-from .start_with import StartWithCompare
 
 compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(),
                        LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare(),
                        IsNullCompare(),
-                       IsNotNullCompare(), NotContainCompare(), IsTrueCompare(), IsNotTrueCompare(), StartWithCompare(),
-                       EndWithCompare()]
+                       IsNotNullCompare(), NotContainCompare(), IsTrueCompare(), IsNotTrueCompare()]
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class ContainCompare(Compare):
@@ -20,7 +20,4 @@ class ContainCompare(Compare):
     def compare(self, source_value, compare, target_value):
         if isinstance(source_value, str):
             return str(target_value) in source_value
-        elif isinstance(source_value, list):
-            return any([str(item) == str(target_value) for item in source_value])
-        else:
-            return str(target_value) in str(source_value)
+        return any([str(item) == str(target_value) for item in source_value])
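All the comparers registered in compare_handle_list above share one small protocol: a support check that picks the handler for a given operator, and a compare that evaluates it, so condition branches just scan the list. A minimal sketch of that dispatch (the support signature is simplified here; the caller side is not shown in this diff):

    class Compare:
        def support(self, compare: str) -> bool:
            raise NotImplementedError

        def compare(self, source_value, compare, target_value) -> bool:
            raise NotImplementedError

    class ContainCompareSketch(Compare):
        def support(self, compare):
            return compare == 'contain'

        def compare(self, source_value, compare, target_value):
            if isinstance(source_value, str):
                return str(target_value) in source_value
            return any(str(item) == str(target_value) for item in source_value)

    handlers = [ContainCompareSketch()]

    def evaluate(source_value, compare, target_value):
        for handler in handlers:
            if handler.support(compare):
                return handler.compare(source_value, compare, target_value)
        return False

    print(evaluate(['a', 'b'], 'contain', 'b'))  # True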
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class EqualCompare(Compare):
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class GECompare(Compare):
@@ -21,8 +21,4 @@ class GECompare(Compare):
         try:
             return float(source_value) >= float(target_value)
         except Exception as e:
-            try:
-                return str(source_value) >= str(target_value)
-            except Exception as _:
-                pass
             return False
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class GTCompare(Compare):
@@ -21,8 +21,4 @@ class GTCompare(Compare):
         try:
             return float(source_value) > float(target_value)
         except Exception as e:
-            try:
-                return str(source_value) > str(target_value)
-            except Exception as _:
-                pass
             return False
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare import Compare
 
 
 class IsNotNullCompare(Compare):
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare import Compare
 
 
 class IsNotTrueCompare(Compare):
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare import Compare
 
 
 class IsNullCompare(Compare):
@@ -18,7 +18,4 @@ class IsNullCompare(Compare):
         return True
 
     def compare(self, source_value, compare, target_value):
-        try:
-            return source_value is None or len(source_value) == 0
-        except Exception as e:
-            return False
+        return source_value is None or len(source_value) == 0
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare import Compare
 
 
 class IsTrueCompare(Compare):
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class LECompare(Compare):
@@ -21,8 +21,4 @@ class LECompare(Compare):
         try:
             return float(source_value) <= float(target_value)
         except Exception as e:
-            try:
-                return str(source_value) <= str(target_value)
-            except Exception as _:
-                pass
             return False
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class LenEqualCompare(Compare):
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class LenGECompare(Compare):
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class LenGTCompare(Compare):
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class LenLECompare(Compare):
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class LenLTCompare(Compare):
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class LTCompare(Compare):
@@ -21,8 +21,4 @@ class LTCompare(Compare):
         try:
             return float(source_value) < float(target_value)
         except Exception as e:
-            try:
-                return str(source_value) < str(target_value)
-            except Exception as _:
-                pass
             return False
@@ -8,7 +8,7 @@
 """
 from typing import List
 
-from application.flow.compare import Compare
+from application.flow.step_node.condition_node.compare.compare import Compare
 
 
 class NotContainCompare(Compare):
@@ -20,7 +20,4 @@ class NotContainCompare(Compare):
     def compare(self, source_value, compare, target_value):
         if isinstance(source_value, str):
             return str(target_value) not in source_value
-        elif isinstance(self, list):
-            return not any([str(item) == str(target_value) for item in source_value])
-        else:
-            return str(target_value) not in str(source_value)
+        return not any([str(item) == str(target_value) for item in source_value])
@@ -11,20 +11,20 @@ from typing import Type
 from django.utils.translation import gettext_lazy as _
 from rest_framework import serializers
 
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode
+from common.util.field_message import ErrMessage
 
 
 class ConditionSerializer(serializers.Serializer):
-    compare = serializers.CharField(required=True, label=_("Comparator"))
-    value = serializers.CharField(required=True, label=_("value"))
-    field = serializers.ListField(required=True, label=_("Fields"))
+    compare = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Comparator")))
+    value = serializers.CharField(required=True, error_messages=ErrMessage.char(_("value")))
+    field = serializers.ListField(required=True, error_messages=ErrMessage.char(_("Fields")))
 
 
 class ConditionBranchSerializer(serializers.Serializer):
-    id = serializers.CharField(required=True, label=_("Branch id"))
-    type = serializers.CharField(required=True, label=_("Branch Type"))
-    condition = serializers.CharField(required=True, label=_("Condition or|and"))
+    id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch id")))
+    type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Branch Type")))
+    condition = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Condition or|and")))
     conditions = ConditionSerializer(many=True)
@@ -37,5 +37,3 @@ class IConditionNode(INode):
         return ConditionNodeParamsSerializer
 
     type = 'condition-node'
-
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP]
@@ -9,7 +9,7 @@
 from typing import List
 
 from application.flow.i_step_node import NodeResult
-from application.flow.compare import compare_handle_list
+from application.flow.step_node.condition_node.compare import compare_handle_list
 from application.flow.step_node.condition_node.i_condition_node import IConditionNode
 
 
@@ -1,8 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: __init__.py.py
-@date:2025/11/11 10:06
-@desc:
-"""
@@ -1,42 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: i_data_source_local_node.py
-@date:2025/11/11 10:06
-@desc:
-"""
-from abc import abstractmethod
-from typing import Type
-
-from django.utils.translation import gettext_lazy as _
-from rest_framework import serializers
-
-from application.flow.common import WorkflowMode
-from application.flow.i_step_node import INode, NodeResult
-
-
-class DataSourceLocalNodeParamsSerializer(serializers.Serializer):
-    file_type_list = serializers.ListField(child=serializers.CharField(label=('')), label='')
-    file_size_limit = serializers.IntegerField(required=True, label=_("Number of uploaded files"))
-    file_count_limit = serializers.IntegerField(required=True, label=_("Upload file size"))
-
-
-class IDataSourceLocalNode(INode):
-    type = 'data-source-local-node'
-
-    @staticmethod
-    @abstractmethod
-    def get_form_list(node):
-        pass
-
-    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
-        return DataSourceLocalNodeParamsSerializer
-
-    def _run(self):
-        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
-
-    def execute(self, file_type_list, file_size_limit, file_count_limit, **kwargs) -> NodeResult:
-        pass
-
-    support = [WorkflowMode.KNOWLEDGE]
@@ -1,8 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: __init__.py.py
-@date:2025/11/11 10:08
-@desc:
-"""
@@ -1,51 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:虎虎
-@file: base_data_source_local_node.py
-@date:2025/11/11 10:30
-@desc:
-"""
-from application.flow.i_step_node import NodeResult
-from application.flow.step_node.data_source_local_node.i_data_source_local_node import IDataSourceLocalNode
-from common import forms
-from common.forms import BaseForm
-
-
-class BaseDataSourceLocalNodeForm(BaseForm):
-    api_key = forms.PasswordInputField('API Key', required=True)
-
-
-class BaseDataSourceLocalNode(IDataSourceLocalNode):
-    def save_context(self, details, workflow_manage):
-        pass
-
-    @staticmethod
-    def get_form_list(node):
-        node_data = node.get('properties').get('node_data')
-        return [{
-            'field': 'file_list',
-            'input_type': 'LocalFileUpload',
-            'attrs': {
-                'file_count_limit': node_data.get('file_count_limit') or 10,
-                'file_size_limit': node_data.get('file_size_limit') or 100,
-                'file_type_list': node_data.get('file_type_list'),
-            },
-            'label': '',
-        }]
-
-    def execute(self, file_type_list, file_size_limit, file_count_limit, **kwargs) -> NodeResult:
-        return NodeResult({'file_list': self.workflow_manage.params.get('data_source', {}).get('file_list')},
-                          self.workflow_manage.params.get('knowledge_base') or {})
-
-    def get_details(self, index: int, **kwargs):
-        return {
-            'name': self.node.properties.get('stepName'),
-            "index": index,
-            'run_time': self.context.get('run_time'),
-            'type': self.node.type,
-            'file_list': self.context.get('file_list'),
-            'knowledge_base': self.workflow_params.get('knowledge_base'),
-            'status': self.status,
-            'err_message': self.err_message
-        }
@@ -1,28 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:niu
-@file: i_data_source_web_node.py
-@date:2025/11/12 13:47
-@desc:
-"""
-from abc import abstractmethod
-
-from application.flow.common import WorkflowMode
-from application.flow.i_step_node import INode, NodeResult
-
-
-class IDataSourceWebNode(INode):
-    type = 'data-source-web-node'
-    support = [WorkflowMode.KNOWLEDGE]
-
-    @staticmethod
-    @abstractmethod
-    def get_form_list(node):
-        pass
-
-    def _run(self):
-        return self.execute(**self.flow_params_serializer.data)
-
-    def execute(self, **kwargs) -> NodeResult:
-        pass
@@ -1,8 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:niu
-@file: __init__.py
-@date:2025/11/12 13:44
-@desc:
-"""
@@ -1,86 +0,0 @@
-# coding=utf-8
-"""
-@project: MaxKB
-@Author:niu
-@file: base_data_source_web_node.py
-@date:2025/11/12 13:47
-@desc:
-"""
-import traceback
-
-from django.utils.translation import gettext_lazy as _
-
-from application.flow.i_step_node import NodeResult
-from application.flow.step_node.data_source_web_node.i_data_source_web_node import IDataSourceWebNode
-from common import forms
-from common.forms import BaseForm
-from common.utils.fork import ForkManage, Fork, ChildLink
-from common.utils.logger import maxkb_logger
-
-
-class BaseDataSourceWebNodeForm(BaseForm):
-    source_url = forms.TextInputField(_('Web source url'), required=True, attrs={
-        'placeholder': _('Please enter the Web root address')})
-    selector = forms.TextInputField(_('Web knowledge selector'), required=False, attrs={
-        'placeholder': _('The default is body, you can enter .classname/#idname/tagname')})
-
-
-def get_collect_handler():
-    results = []
-
-    def handler(child_link: ChildLink, response: Fork.Response):
-        if response.status == 200:
-            try:
-                document_name = child_link.tag.text if child_link.tag is not None and len(
-                    child_link.tag.text.strip()) > 0 else child_link.url
-                results.append({
-                    "name": document_name.strip(),
-                    "content": response.content,
-                })
-            except Exception as e:
-                maxkb_logger.error(f'{str(e)}:{traceback.format_exc()}')
-
-    return handler, results
-
-
-class BaseDataSourceWebNode(IDataSourceWebNode):
-    def save_context(self, details, workflow_manage):
-        pass
-
-    @staticmethod
-    def get_form_list(node):
-        return BaseDataSourceWebNodeForm().to_form_list()
-
-    def execute(self, **kwargs) -> NodeResult:
-        BaseDataSourceWebNodeForm().valid_form(self.workflow_params.get("data_source"))
-
-        data_source = self.workflow_params.get("data_source")
-
-        node_id = data_source.get("node_id")
-        source_url = data_source.get("source_url")
-        selector = data_source.get("selector") or "body"
-
-        collect_handler, document_list = get_collect_handler()
-
-        try:
-            ForkManage(source_url, selector.split(" ") if selector is not None else []).fork(3, set(), collect_handler)
-
-            return NodeResult({'document_list': document_list, 'source_url': source_url, 'selector': selector},
-                              self.workflow_manage.params.get('knowledge_base') or {})
-
-        except Exception as e:
-            maxkb_logger.error(_('data source web node:{node_id} error{error}{traceback}').format(
-                knowledge_id=node_id, error=str(e), traceback=traceback.format_exc()))
-
-    def get_details(self, index: int, **kwargs):
-        return {
-            'name': self.node.properties.get('stepName'),
-            "index": index,
-            'run_time': self.context.get('run_time'),
-            'type': self.node.type,
-            'input_params': {"source_url": self.context.get("source_url"), "selector": self.context.get('selector')},
-            'output_params': self.context.get('document_list'),
-            'knowledge_base': self.workflow_params.get('knowledge_base'),
-            'status': self.status,
-            'err_message': self.err_message
-        }
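The deleted get_collect_handler above is a closure pairing a callback with the list it appends to, so ForkManage.fork(...) can push crawl results into plain Python state owned by the caller. The pattern in isolation (handler arguments simplified to plain values; ChildLink and Fork.Response belong to MaxKB's crawler):

    def get_collect_handler():
        results = []

        def handler(name, status, content):
            # Keep only successful fetches, mirroring the response.status == 200 check.
            if status == 200:
                results.append({'name': name.strip(), 'content': content})

        return handler, results

    handler, results = get_collect_handler()
    handler('Quick start ', 200, '<p>hello</p>')
    handler('Broken link', 404, '')
    print(results)  # [{'name': 'Quick start', 'content': '<p>hello</p>'}]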
@@ -10,20 +10,18 @@ from typing import Type
 
 from rest_framework import serializers
 
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode, NodeResult
 from common.exception.app_exception import AppApiException
+from common.util.field_message import ErrMessage
 from django.utils.translation import gettext_lazy as _
 
 
 class ReplyNodeParamsSerializer(serializers.Serializer):
-    reply_type = serializers.CharField(required=True, label=_("Response Type"))
-    fields = serializers.ListField(required=False, label=_("Reference Field"))
+    reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Response Type")))
+    fields = serializers.ListField(required=False, error_messages=ErrMessage.list(_("Reference Field")))
     content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
-                                    label=_("Direct answer content"))
-    is_result = serializers.BooleanField(required=False,
-                                         label=_('Whether to return content'))
+                                    error_messages=ErrMessage.char(_("Direct answer content")))
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
 
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
@@ -39,19 +37,12 @@ class ReplyNodeParamsSerializer(serializers.Serializer):
 
 class IReplyNode(INode):
     type = 'reply-node'
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE_LOOP,
-               WorkflowMode.KNOWLEDGE]
 
     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return ReplyNodeParamsSerializer
 
     def _run(self):
-        if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__(
-                self.workflow_manage.flow.workflow_mode):
-            return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
-                                **{'stream': True})
-        else:
-            return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
 
     def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
         pass
@@ -5,18 +5,17 @@ from typing import Type
 from django.utils.translation import gettext_lazy as _
 from rest_framework import serializers
 
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
 
 
 class DocumentExtractNodeSerializer(serializers.Serializer):
-    document_list = serializers.ListField(required=False, label=_("document"))
+    document_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("document")))
 
 
 class IDocumentExtractNode(INode):
     type = 'document-extract-node'
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE_LOOP,
-               WorkflowMode.KNOWLEDGE]
     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return DocumentExtractNodeSerializer
@@ -25,5 +24,5 @@ class IDocumentExtractNode(INode):
                                                           self.node_params_serializer.data.get('document_list')[1:])
         return self.execute(document=res, **self.flow_params_serializer.data)
 
-    def execute(self, document, chat_id=None, **kwargs) -> NodeResult:
+    def execute(self, document, chat_id, **kwargs) -> NodeResult:
         pass
@@ -1,64 +1,72 @@
 # coding=utf-8
-import ast
 import io
+import mimetypes
 
-import uuid_utils.compat as uuid
+from django.core.files.uploadedfile import InMemoryUploadedFile
 from django.db.models import QuerySet
 
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
-from knowledge.models import File, FileSourceType
-from knowledge.serializers.document import split_handles, parse_table_handle_list, FileBufferHandle
+from dataset.models import File
+from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle
+from dataset.serializers.file_serializers import FileSerializer
+
+
+def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
+    content_type, _ = mimetypes.guess_type(file_name)
+    if content_type is None:
+        # If the type cannot be detected, fall back to the default binary type
+        content_type = "application/octet-stream"
+    # Create an in-memory byte stream
+    file_stream = io.BytesIO(file_bytes)
+
+    # Get the file size
+    file_size = len(file_bytes)
+
+    # Create the InMemoryUploadedFile object
+    uploaded_file = InMemoryUploadedFile(
+        file=file_stream,
+        field_name=None,
+        name=file_name,
+        content_type=content_type,
+        size=file_size,
+        charset=None,
+    )
+    return uploaded_file
+
 
 splitter = '\n`-----------------------------------`\n'
 
 
 class BaseDocumentExtractNode(IDocumentExtractNode):
     def save_context(self, details, workflow_manage):
         self.context['content'] = details.get('content')
 
-    def execute(self, document, chat_id=None, **kwargs):
+    def execute(self, document, chat_id, **kwargs):
         get_buffer = FileBufferHandle().get_buffer
 
         self.context['document_list'] = document
         content = []
         if document is None or not isinstance(document, list):
-            return NodeResult({'content': '', 'document_list': []}, {})
+            return NodeResult({'content': ''}, {})
 
-        # Safely obtain the application
-        application_id = None
-        if (self.workflow_manage and
-                self.workflow_manage.work_flow_post_handler and
-                self.workflow_manage.work_flow_post_handler.chat_info):
-            application_id = self.workflow_manage.work_flow_post_handler.chat_info.application.id
-        knowledge_id = self.workflow_params.get('knowledge_id')
+        application = self.workflow_manage.work_flow_post_handler.chat_info.application
 
         # Save images embedded in doc files
        def save_image(image_list):
            for image in image_list:
                meta = {
-                    'debug': False if (application_id or knowledge_id) else True,
+                    'debug': False if application.id else True,
                     'chat_id': chat_id,
-                    'application_id': str(application_id) if application_id else None,
-                    'knowledge_id': str(knowledge_id) if knowledge_id else None,
+                    'application_id': str(application.id) if application.id else None,
                     'file_id': str(image.id)
                 }
-                file_bytes = image.meta.pop('content')
-                new_file = File(
-                    id=meta['file_id'],
-                    file_name=image.file_name,
-                    file_size=len(file_bytes),
-                    source_type=FileSourceType.APPLICATION.value if meta[
-                        'application_id'] else FileSourceType.KNOWLEDGE.value,
-                    source_id=meta['application_id'] if meta['application_id'] else meta['knowledge_id'],
-                    meta=meta
-                )
-                new_file.save(file_bytes)
-
-        document_list = []
+                file = bytes_to_uploaded_file(image.image, image.image_name)
+                FileSerializer(data={'file': file, 'meta': meta}).upload()
+
         for doc in document:
             file = QuerySet(File).filter(id=doc['file_id']).first()
-            buffer = io.BytesIO(file.get_bytes())
+            buffer = io.BytesIO(file.get_byte().tobytes())
            buffer.name = doc['name']  # this is the important line
 
             for split_handle in (parse_table_handle_list + split_handles):
@@ -67,10 +75,9 @@ class BaseDocumentExtractNode(IDocumentExtractNode):
                     buffer.seek(0)
                     file_content = split_handle.get_content(buffer, save_image)
                     content.append('### ' + doc['name'] + '\n' + file_content)
-                    document_list.append({'id': str(file.id), 'name': doc['name'], 'content': file_content})
                     break
 
-        return NodeResult({'content': splitter.join(content), 'document_list': document_list}, {})
+        return NodeResult({'content': splitter.join(content)}, {})
 
     def get_details(self, index: int, **kwargs):
         content = self.context.get('content', '').split(splitter)
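The bytes_to_uploaded_file helper added in the file above wraps raw bytes as a Django InMemoryUploadedFile so extracted images can be handed to FileSerializer(...).upload(). A usage sketch (requires Django; the PNG bytes here are dummy data):

    import io
    import mimetypes

    from django.core.files.uploadedfile import InMemoryUploadedFile

    def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
        content_type, _ = mimetypes.guess_type(file_name)
        if content_type is None:
            content_type = "application/octet-stream"  # fallback for unknown extensions
        return InMemoryUploadedFile(
            file=io.BytesIO(file_bytes),
            field_name=None,
            name=file_name,
            content_type=content_type,
            size=len(file_bytes),
            charset=None,
        )

    f = bytes_to_uploaded_file(b'\x89PNG fake bytes', 'diagram.png')
    print(f.name, f.content_type, f.size)  # diagram.png image/png 15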
@@ -1 +0,0 @@
-from .impl import *
@@ -1,90 +0,0 @@
-# coding=utf-8
-
-from typing import Type
-
-from django.utils.translation import gettext_lazy as _
-from rest_framework import serializers
-
-from application.flow.common import WorkflowMode
-from application.flow.i_step_node import INode, NodeResult
-
-
-class DocumentSplitNodeSerializer(serializers.Serializer):
-    document_list = serializers.ListField(required=False, label=_("document list"))
-    split_strategy = serializers.ChoiceField(
-        choices=['auto', 'custom', 'qa'], required=False, label=_("split strategy"), default='auto'
-    )
-    paragraph_title_relate_problem_type = serializers.ChoiceField(
-        choices=['custom', 'referencing'], required=False, label=_("paragraph title relate problem type"),
-        default='custom'
-    )
-    paragraph_title_relate_problem = serializers.BooleanField(
-        required=False, label=_("paragraph title relate problem"), default=False
-    )
-    paragraph_title_relate_problem_reference = serializers.ListField(
-        required=False, label=_("paragraph title relate problem reference"), child=serializers.CharField(), default=[]
-    )
-    document_name_relate_problem_type = serializers.ChoiceField(
-        choices=['custom', 'referencing'], required=False, label=_("document name relate problem type"),
-        default='custom'
-    )
-    document_name_relate_problem = serializers.BooleanField(
-        required=False, label=_("document name relate problem"), default=False
-    )
-    document_name_relate_problem_reference = serializers.ListField(
-        required=False, label=_("document name relate problem reference"), child=serializers.CharField(), default=[]
-    )
-    limit = serializers.IntegerField(required=False, label=_("limit"), default=4096)
-    limit_type = serializers.ChoiceField(
-        choices=['custom', 'referencing'], required=False, label=_("document name relate problem type"),
-        default='custom'
-    )
-    limit_reference = serializers.ListField(
-        required=False, label=_("limit reference"), child=serializers.CharField(), default=[]
-    )
-    chunk_size = serializers.IntegerField(required=False, label=_("chunk size"), default=256)
-    chunk_size_type = serializers.ChoiceField(
-        choices=['custom', 'referencing'], required=False, label=_("chunk size type"), default='custom'
-    )
-    chunk_size_reference = serializers.ListField(
-        required=False, label=_("chunk size reference"), child=serializers.CharField(), default=[]
-    )
-    patterns = serializers.ListField(
-        required=False, label=_("patterns"), child=serializers.CharField(), default=[]
-    )
-    patterns_type = serializers.ChoiceField(
-        choices=['custom', 'referencing'], required=False, label=_("patterns type"), default='custom'
-    )
-    patterns_reference = serializers.ListField(
-        required=False, label=_("patterns reference"), child=serializers.CharField(), default=[]
-    )
-    with_filter = serializers.BooleanField(
-        required=False, label=_("with filter"), default=False
-    )
-    with_filter_type = serializers.ChoiceField(
-        choices=['custom', 'referencing'], required=False, label=_("with filter type"), default='custom'
-    )
-    with_filter_reference = serializers.ListField(
-        required=False, label=_("with filter reference"), child=serializers.CharField(), default=[]
-    )
-
-
-class IDocumentSplitNode(INode):
-    type = 'document-split-node'
-    support = [
-        WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE_LOOP, WorkflowMode.KNOWLEDGE
-    ]
-
-    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
-        return DocumentSplitNodeSerializer
-
-    def _run(self):
-        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
-
-    def execute(self, document_list, knowledge_id, split_strategy, paragraph_title_relate_problem_type,
-                paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
-                document_name_relate_problem_type, document_name_relate_problem,
-                document_name_relate_problem_reference, limit, limit_type, limit_reference, chunk_size, chunk_size_type,
-                chunk_size_reference, patterns, patterns_type, patterns_reference, with_filter, with_filter_type,
-                with_filter_reference, **kwargs) -> NodeResult:
-        pass
@@ -1 +0,0 @@
-from .base_document_split_node import BaseDocumentSplitNode
@@ -1,178 +0,0 @@
-# coding=utf-8
-import io
-import mimetypes
-from typing import List
-
-from django.core.files.uploadedfile import InMemoryUploadedFile
-
-from application.flow.i_step_node import NodeResult
-from application.flow.step_node.document_split_node.i_document_split_node import IDocumentSplitNode
-from common.chunk import text_to_chunk
-from knowledge.serializers.document import default_split_handle, FileBufferHandle, md_qa_split_handle
-
-
-def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
-    if file_name.startswith("http"):
-        file_name = "file.txt"
-    content_type, _ = mimetypes.guess_type(file_name)
-    if content_type is None:
-        # Fall back to the default binary type when the type cannot be guessed
-        content_type = "application/octet-stream"
-    # Create an in-memory byte stream
-    file_stream = io.BytesIO(file_bytes)
-
-    # Get the file size
-    file_size = len(file_bytes)
-
-    # Build the InMemoryUploadedFile object
-    uploaded_file = InMemoryUploadedFile(
-        file=file_stream,
-        field_name=None,
-        name=file_name,
-        content_type=content_type,
-        size=file_size,
-        charset=None,
-    )
-    return uploaded_file
-
-
-class BaseDocumentSplitNode(IDocumentSplitNode):
-    def save_context(self, details, workflow_manage):
-        self.context['content'] = details.get('content')
-
-    def get_reference_content(self, fields: List[str]):
-        return self.workflow_manage.get_reference_field(fields[0], fields[1:])
-
-    def execute(self, document_list, knowledge_id, split_strategy, paragraph_title_relate_problem_type,
-                paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
-                document_name_relate_problem_type, document_name_relate_problem,
-                document_name_relate_problem_reference, limit, limit_type, limit_reference, chunk_size, chunk_size_type,
-                chunk_size_reference, patterns, patterns_type, patterns_reference, with_filter, with_filter_type,
-                with_filter_reference, **kwargs) -> NodeResult:
-        self.context['knowledge_id'] = knowledge_id
-        file_list = self.get_reference_content(document_list)
-
-        # Resolve parameters supplied by reference
-        if patterns_type == 'referencing':
-            patterns = self.get_reference_content(patterns_reference)
-        if limit_type == 'referencing':
-            limit = self.get_reference_content(limit_reference)
-        if chunk_size_type == 'referencing':
-            chunk_size = self.get_reference_content(chunk_size_reference)
-        if with_filter_type == 'referencing':
-            with_filter = self.get_reference_content(with_filter_reference)
-
-        paragraph_list = []
-        for doc in file_list:
-            get_buffer = FileBufferHandle().get_buffer
-
-            file_mem = bytes_to_uploaded_file(doc['content'].encode('utf-8'), doc['name'])
-            if split_strategy == 'qa':
-                result = md_qa_split_handle.handle(file_mem, get_buffer, self._save_image)
-            else:
-                result = default_split_handle.handle(file_mem, patterns, with_filter, limit, get_buffer,
-                                                     self._save_image)
-            # Normalize the result to a list
-            results = result if isinstance(result, list) else [result]
-
-            for item in results:
-                self._process_split_result(
-                    item, knowledge_id, doc.get('id'), doc.get('name'),
-                    split_strategy, paragraph_title_relate_problem_type,
-                    paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
-                    document_name_relate_problem_type, document_name_relate_problem,
-                    document_name_relate_problem_reference, chunk_size
-                )
-
-            paragraph_list += results
-
-        self.context['paragraph_list'] = paragraph_list
-        self.context['document_list'] = file_list
-        self.context['limit'] = limit
-        self.context['chunk_size'] = chunk_size
-        self.context['with_filter'] = with_filter
-        self.context['patterns'] = patterns
-        self.context['split_strategy'] = split_strategy
-
-        return NodeResult({'paragraph_list': paragraph_list}, {})
-
-    def _save_image(self, image_list):
-        pass
-
-    def _process_split_result(
-            self, item, knowledge_id, source_file_id, file_name,
-            split_strategy, paragraph_title_relate_problem_type,
-            paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
-            document_name_relate_problem_type, document_name_relate_problem,
-            document_name_relate_problem_reference, chunk_size
-    ):
-        """Post-process a document split result"""
-        item['meta'] = {
-            'knowledge_id': knowledge_id,
-            'source_file_id': source_file_id,
-            'source_url': file_name,
-        }
-        if item.get('name', 'file.txt') == 'file.txt':
-            item['name'] = file_name
-        item['source_file_id'] = source_file_id
-        item['paragraphs'] = item.pop('content', item.get('paragraphs', []))
-
-        for paragraph in item['paragraphs']:
-            paragraph['problem_list'] = self._generate_problem_list(
-                paragraph, file_name,
-                split_strategy, paragraph_title_relate_problem_type,
-                paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
-                document_name_relate_problem_type, document_name_relate_problem,
-                document_name_relate_problem_reference
-            )
-            paragraph['is_active'] = True
-            paragraph['chunks'] = text_to_chunk(paragraph['content'], chunk_size)
-
-    def _generate_problem_list(
-            self, paragraph, document_name, split_strategy, paragraph_title_relate_problem_type,
-            paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
-            document_name_relate_problem_type, document_name_relate_problem,
-            document_name_relate_problem_reference
-    ):
-        if paragraph_title_relate_problem_type == 'referencing':
-            paragraph_title_relate_problem = self.get_reference_content(paragraph_title_relate_problem_reference)
-        if document_name_relate_problem_type == 'referencing':
-            document_name_relate_problem = self.get_reference_content(document_name_relate_problem_reference)
-
-        problem_list = [
-            item for p in paragraph.get('problem_list', []) for item in p.get('content', '').split('<br>')
-            if item.strip()
-        ]
-
-        if split_strategy == 'auto':
-            if paragraph_title_relate_problem and paragraph.get('title'):
-                problem_list.append(paragraph.get('title'))
-            if document_name_relate_problem and document_name:
-                problem_list.append(document_name)
-        elif split_strategy == 'custom':
-            if paragraph_title_relate_problem and paragraph.get('title'):
-                problem_list.append(paragraph.get('title'))
-            if document_name_relate_problem and document_name:
-                problem_list.append(document_name)
-        elif split_strategy == 'qa':
-            if document_name_relate_problem and document_name:
-                problem_list.append(document_name)
-
-        return list(set(problem_list))
-
-    def get_details(self, index: int, **kwargs):
-        return {
-            'name': self.node.properties.get('stepName'),
-            "index": index,
-            'run_time': self.context.get('run_time'),
-            'type': self.node.type,
-            'status': self.status,
-            'err_message': self.err_message,
-            'paragraph_list': self.context.get('paragraph_list', []),
-            'limit': self.context.get('limit'),
-            'chunk_size': self.context.get('chunk_size'),
-            'with_filter': self.context.get('with_filter'),
-            'patterns': self.context.get('patterns'),
-            'split_strategy': self.context.get('split_strategy'),
-            'document_list': self.context.get('document_list', []),
-        }
@@ -10,22 +10,20 @@ from typing import Type
 
 from rest_framework import serializers
 
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
 from django.utils.translation import gettext_lazy as _
 
 
 class FormNodeParamsSerializer(serializers.Serializer):
-    form_field_list = serializers.ListField(required=True, label=_("Form Configuration"))
-    form_content_format = serializers.CharField(required=True, label=_('Form output content'))
-    form_data = serializers.DictField(required=False, allow_null=True, label=_("Form Data"))
+    form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_("Form Configuration")))
+    form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Form output content')))
+    form_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict(_("Form Data")))
 
 
 class IFormNode(INode):
     type = 'form-node'
     view_type = 'single_view'
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP]
 
     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return FormNodeParamsSerializer
@@ -16,29 +16,6 @@ from application.flow.common import Answer
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.form_node.i_form_node import IFormNode
 
-multi_select_list = [
-    'MultiSelect',
-    'MultiRow'
-]
-
-
-def get_default_option(option_list, _type, value_field):
-    try:
-        if option_list is not None and isinstance(option_list, list) and len(option_list) > 0:
-            default_value_list = [o.get(value_field) for o in option_list if o.get('default')]
-            if len(default_value_list) == 0:
-                return [option_list[0].get(
-                    value_field)] if multi_select_list.__contains__(_type) else option_list[0].get(
-                    value_field)
-            else:
-                if multi_select_list.__contains__(_type):
-                    return default_value_list
-                else:
-                    return default_value_list[0]
-    except Exception as _:
-        pass
-    return []
-
 
 def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
     if step_variable is not None:
@@ -51,13 +28,6 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
     node.context['run_time'] = time.time() - node.context['start_time']
 
 
-def generate_prompt(workflow_manage, _value):
-    try:
-        return workflow_manage.generate_prompt(_value)
-    except Exception as e:
-        return _value
-
-
 class BaseFormNode(IFormNode):
     def save_context(self, details, workflow_manage):
         form_data = details.get('form_data', None)
@@ -74,37 +44,6 @@ class BaseFormNode(IFormNode):
             for key in form_data:
                 self.context[key] = form_data[key]
 
-    def reset_field(self, field):
-        reset_field = ['field', 'label', 'default_value']
-        for f in reset_field:
-            _value = field[f]
-            if _value is None:
-                continue
-            if isinstance(_value, str):
-                field[f] = generate_prompt(self.workflow_manage, _value)
-            elif f == 'label':
-                _label_value = _value.get('label')
-                _value['label'] = generate_prompt(self.workflow_manage, _label_value)
-                tooltip = _value.get('attrs').get('tooltip')
-                if tooltip is not None:
-                    _value.get('attrs')['tooltip'] = generate_prompt(self.workflow_manage, tooltip)
-
-        if ['SingleSelect', 'MultiSelect', 'RadioCard', 'RadioRow', 'MultiRow'].__contains__(field.get('input_type')):
-            if field.get('assignment_method') == 'ref_variables':
-                option_list = self.workflow_manage.get_reference_field(field.get('option_list')[0],
-                                                                       field.get('option_list')[1:])
-                option_list = option_list if isinstance(option_list, list) else []
-                field['option_list'] = option_list
-                field['default_value'] = get_default_option(option_list, field.get('input_type'),
-                                                            field.get('value_field'))
-
-        if ['JsonInput'].__contains__(field.get('input_type')):
-            if field.get('default_value_assignment_method') == 'ref_variables':
-                field['default_value'] = self.workflow_manage.get_reference_field(field.get('default_value')[0],
-                                                                                  field.get('default_value')[1:])
-
-        return field
-
     def execute(self, form_field_list, form_content_format, form_data, **kwargs) -> NodeResult:
         if form_data is not None:
             self.context['is_submit'] = True
@@ -113,7 +52,6 @@ class BaseFormNode(IFormNode):
                 self.context[key] = form_data.get(key)
         else:
             self.context['is_submit'] = False
-        form_field_list = [self.reset_field(field) for field in form_field_list]
         form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
                         "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
                         "is_submit": self.context.get("is_submit", False)}
@@ -121,10 +59,7 @@ class BaseFormNode(IFormNode):
         context = self.workflow_manage.get_workflow_content()
         form_content_format = self.workflow_manage.reset_prompt(form_content_format)
         prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
-        value = prompt_template.format(form=form, context=context, runtime_node_id=self.runtime_node_id,
-                                       chat_record_id=self.flow_params_serializer.data.get("chat_record_id"),
-                                       form_field_list=form_field_list)
-
+        value = prompt_template.format(form=form, context=context)
         return NodeResult(
             {'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {},
             _write_context=write_context)
@@ -140,9 +75,7 @@ class BaseFormNode(IFormNode):
         context = self.workflow_manage.get_workflow_content()
         form_content_format = self.workflow_manage.reset_prompt(form_content_format)
         prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
-        value = prompt_template.format(form=form, context=context, runtime_node_id=self.runtime_node_id,
-                                       chat_record_id=self.flow_params_serializer.data.get("chat_record_id"),
-                                       form_field_list=form_field_list)
+        value = prompt_template.format(form=form, context=context)
         return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None,
                        self.runtime_node_id, '')]
@@ -157,9 +90,7 @@ class BaseFormNode(IFormNode):
         context = self.workflow_manage.get_workflow_content()
         form_content_format = self.workflow_manage.reset_prompt(form_content_format)
         prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
-        value = prompt_template.format(form=form, context=context, runtime_node_id=self.runtime_node_id,
-                                       chat_record_id=self.flow_params_serializer.data.get("chat_record_id"),
-                                       form_field_list=form_field_list)
+        value = prompt_template.format(form=form, context=context)
         return {
             'name': self.node.properties.get('stepName'),
             "index": index,
@@ -8,41 +8,35 @@
 """
 from typing import Type
 
-from django.db import connection
 from django.db.models import QuerySet
-from django.utils.translation import gettext_lazy as _
 from rest_framework import serializers
 
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode, NodeResult
 from common.field.common import ObjectField
-from tools.models.tool import Tool
+from common.util.field_message import ErrMessage
+from function_lib.models.function import FunctionLib
+from django.utils.translation import gettext_lazy as _
 
 
 class InputField(serializers.Serializer):
-    name = serializers.CharField(required=True, label=_('Variable Name'))
-    value = ObjectField(required=True, label=_("Variable Value"), model_type_list=[str, list])
+    name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
+    value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
 
 
 class FunctionLibNodeParamsSerializer(serializers.Serializer):
-    tool_lib_id = serializers.UUIDField(required=True, label=_('Library ID'))
+    function_lib_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Library ID')))
     input_field_list = InputField(required=True, many=True)
-    is_result = serializers.BooleanField(required=False,
-                                         label=_('Whether to return content'))
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
 
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
-        f_lib = QuerySet(Tool).filter(id=self.data.get('tool_lib_id')).first()
-        # Return the connection to the pool
-        connection.close()
+        f_lib = QuerySet(FunctionLib).filter(id=self.data.get('function_lib_id')).first()
         if f_lib is None:
             raise Exception(_('The function has been deleted'))
 
 
-class IToolLibNode(INode):
-    type = 'tool-lib-node'
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE,
-               WorkflowMode.KNOWLEDGE_LOOP]
+class IFunctionLibNode(INode):
+    type = 'function-lib-node'
 
     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return FunctionLibNodeParamsSerializer
@@ -50,5 +44,5 @@ class IToolLibNode(INode):
     def _run(self):
         return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
 
-    def execute(self, tool_lib_id, input_field_list, **kwargs) -> NodeResult:
+    def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
         pass
@@ -6,4 +6,4 @@
 @date:2024/8/8 17:48
 @desc:
 """
-from .base_tool_lib_node import BaseToolLibNodeNode
+from .base_function_lib_node import BaseFunctionLibNodeNode
@@ -0,0 +1,150 @@
+# coding=utf-8
+"""
+@project: MaxKB
+@Author:虎
+@file: base_function_lib_node.py
+@date:2024/8/8 17:49
+@desc:
+"""
+import json
+import time
+from typing import Dict
+
+from django.db.models import QuerySet
+from django.utils.translation import gettext as _
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.function_lib_node.i_function_lib_node import IFunctionLibNode
+from common.exception.app_exception import AppApiException
+from common.util.function_code import FunctionExecutor
+from common.util.rsa_util import rsa_long_decrypt
+from function_lib.models.function import FunctionLib
+from smartdoc.const import CONFIG
+
+function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
+
+
+def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
+    if step_variable is not None:
+        for key in step_variable:
+            node.context[key] = step_variable[key]
+        if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
+            result = str(step_variable['result']) + '\n'
+            yield result
+            node.answer_text = result
+    node.context['run_time'] = time.time() - node.context['start_time']
+
+
+def get_field_value(debug_field_list, name, is_required):
+    result = [field for field in debug_field_list if field.get('name') == name]
+    if len(result) > 0:
+        return result[-1]['value']
+    if is_required:
+        raise AppApiException(500, _('Field: {name} No value set').format(name=name))
+    return None
+
+
+def valid_reference_value(_type, value, name):
+    if _type == 'int':
+        instance_type = int | float
+    elif _type == 'float':
+        instance_type = float | int
+    elif _type == 'dict':
+        instance_type = dict
+    elif _type == 'array':
+        instance_type = list
+    elif _type == 'string':
+        instance_type = str
+    else:
+        raise Exception(_('Field: {name} Type: {_type} Value: {value} Unsupported types').format(name=name,
+                                                                                                 _type=_type))
+    if not isinstance(value, instance_type):
+        raise Exception(
+            _('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
+                                                                              value=value))
+
+
+def convert_value(name: str, value, _type, is_required, source, node):
+    if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
+        return None
+    if not is_required and source == 'reference' and (value is None or len(value) == 0):
+        return None
+    if source == 'reference':
+        value = node.workflow_manage.get_reference_field(
+            value[0],
+            value[1:])
+        valid_reference_value(_type, value, name)
+        if _type == 'int':
+            return int(value)
+        if _type == 'float':
+            return float(value)
+        return value
+    try:
+        if _type == 'int':
+            return int(value)
+        if _type == 'float':
+            return float(value)
+        if _type == 'dict':
+            v = json.loads(value)
+            if isinstance(v, dict):
+                return v
+            raise Exception(_('type error'))
+        if _type == 'array':
+            v = json.loads(value)
+            if isinstance(v, list):
+                return v
+            raise Exception(_('type error'))
+        return value
+    except Exception as e:
+        raise Exception(
+            _('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
+                                                                              value=value))
+
+
+def valid_function(function_lib, user_id):
+    if function_lib is None:
+        raise Exception(_('Function does not exist'))
+    if function_lib.permission_type == 'PRIVATE' and str(function_lib.user_id) != str(user_id):
+        raise Exception(_('No permission to use this function {name}').format(name=function_lib.name))
+    if not function_lib.is_active:
+        raise Exception(_('Function {name} is unavailable').format(name=function_lib.name))
+
+
+class BaseFunctionLibNodeNode(IFunctionLibNode):
+    def save_context(self, details, workflow_manage):
+        self.context['result'] = details.get('result')
+        if self.node_params.get('is_result'):
+            self.answer_text = str(details.get('result'))
+
+    def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
+        function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
+        valid_function(function_lib, self.flow_params_serializer.data.get('user_id'))
+        params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'),
+                                                   field.get('is_required'),
+                                                   field.get('source'), self)
+                  for field in
+                  [{'value': get_field_value(input_field_list, field.get('name'), field.get('is_required'),
+                                             ), **field}
+                   for field in
+                   function_lib.input_field_list]}
+
+        self.context['params'] = params
+        # Merge the initialization parameters
+        if function_lib.init_params is not None:
+            all_params = json.loads(rsa_long_decrypt(function_lib.init_params)) | params
+        else:
+            all_params = params
+        result = function_executor.exec_code(function_lib.code, all_params)
+        return NodeResult({'result': result}, {}, _write_context=write_context)
+
+    def get_details(self, index: int, **kwargs):
+        return {
+            'name': self.node.properties.get('stepName'),
+            "index": index,
+            "result": self.context.get('result'),
+            "params": self.context.get('params'),
+            'run_time': self.context.get('run_time'),
+            'type': self.node.type,
+            'status': self.status,
+            'err_message': self.err_message
+        }
@@ -10,28 +10,28 @@ import re
 from typing import Type
 
 from django.core import validators
-from django.utils.translation import gettext_lazy as _
 from rest_framework import serializers
-from rest_framework.utils.formatting import lazy_format
 
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode, NodeResult
 from common.exception.app_exception import AppApiException
 from common.field.common import ObjectField
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
+from rest_framework.utils.formatting import lazy_format
 
 
 class InputField(serializers.Serializer):
-    name = serializers.CharField(required=True, label=_('Variable Name'))
-    is_required = serializers.BooleanField(required=True, label=_("Is this field required"))
-    type = serializers.CharField(required=True, label=_("type"), validators=[
+    name = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Variable Name')))
+    is_required = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean(_("Is this field required")))
+    type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("type")), validators=[
        validators.RegexValidator(regex=re.compile("^string|int|dict|array|float$"),
                                   message=_("The field only supports string|int|dict|array|float"), code=500)
     ])
-    source = serializers.CharField(required=True, label=_("source"), validators=[
+    source = serializers.CharField(required=True, error_messages=ErrMessage.char(_("source")), validators=[
        validators.RegexValidator(regex=re.compile("^custom|reference$"),
                                   message=_("The field only supports custom|reference"), code=500)
     ])
-    value = ObjectField(required=True, label=_("Variable Value"), model_type_list=[str, list])
+    value = ObjectField(required=True, error_messages=ErrMessage.char(_("Variable Value")), model_type_list=[str, list])
 
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
@@ -43,18 +43,15 @@ class InputField(serializers.Serializer):
 
 class FunctionNodeParamsSerializer(serializers.Serializer):
     input_field_list = InputField(required=True, many=True)
-    code = serializers.CharField(required=True, label=_("function"))
-    is_result = serializers.BooleanField(required=False,
-                                         label=_('Whether to return content'))
+    code = serializers.CharField(required=True, error_messages=ErrMessage.char(_("function")))
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
 
     def is_valid(self, *, raise_exception=False):
         super().is_valid(raise_exception=True)
 
 
-class IToolNode(INode):
-    type = 'tool-node'
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE,
-               WorkflowMode.KNOWLEDGE_LOOP]
+class IFunctionNode(INode):
+    type = 'function-node'
 
     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return FunctionNodeParamsSerializer
@@ -6,4 +6,4 @@
 @date:2024/8/13 11:19
 @desc:
 """
-from .base_tool_node import BaseToolNodeNode
+from .base_function_node import BaseFunctionNodeNode
@@ -8,16 +8,16 @@
 """
 import json
 import time
 
 from typing import Dict
 
-from django.utils.translation import gettext as _
 
 from application.flow.i_step_node import NodeResult
-from application.flow.step_node.tool_node.i_tool_node import IToolNode
-from common.utils.tool_code import ToolExecutor
-from maxkb.const import CONFIG
+from application.flow.step_node.function_node.i_function_node import IFunctionNode
+from common.exception.app_exception import AppApiException
+from common.util.function_code import FunctionExecutor
+from smartdoc.const import CONFIG
 
-function_executor = ToolExecutor()
+function_executor = FunctionExecutor(CONFIG.get('SANDBOX'))
 
 
 def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
@@ -32,54 +32,36 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
 
 
 def valid_reference_value(_type, value, name):
-    try:
-        if _type == 'int':
-            instance_type = int | float
-        elif _type == 'float':
-            instance_type = float | int
-        elif _type == 'dict':
-            value = json.loads(value) if isinstance(value, str) else value
-            instance_type = dict
-        elif _type == 'array':
-            value = json.loads(value) if isinstance(value, str) else value
-            instance_type = list
-        elif _type == 'string':
-            instance_type = str
-        else:
-            raise Exception(_(
-                'Field: {name} Type: {_type} Value: {value} Unsupported types'
-            ).format(name=name, _type=_type))
-    except:
-        return value
+    if _type == 'int':
+        instance_type = int | float
+    elif _type == 'float':
+        instance_type = float | int
+    elif _type == 'dict':
+        instance_type = dict
+    elif _type == 'array':
+        instance_type = list
+    elif _type == 'string':
+        instance_type = str
+    else:
+        raise Exception(500, f'Field: {name} Type: {_type} Unsupported type')
     if not isinstance(value, instance_type):
-        raise Exception(_(
-            'Field: {name} Type: {_type} Value: {value} Type error'
-        ).format(name=name, _type=_type, value=value))
-    return value
+        raise Exception(f'Field: {name} Type: {_type} Value: {value} Type error')
 
 
 def convert_value(name: str, value, _type, is_required, source, node):
-    if not is_required and (value is None or ((isinstance(value, str) or isinstance(value, list)) and len(value) == 0)):
+    if not is_required and (value is None or (isinstance(value, str) and len(value) == 0)):
         return None
     if source == 'reference':
         value = node.workflow_manage.get_reference_field(
             value[0],
             value[1:])
-        if value is None:
-            if not is_required:
-                return None
-            else:
-                raise Exception(_(
-                    'Field: {name} Type: {_type} is required'
-                ).format(name=name, _type=_type))
-        value = valid_reference_value(_type, value, name)
+        valid_reference_value(_type, value, name)
         if _type == 'int':
             return int(value)
         if _type == 'float':
             return float(value)
         return value
     try:
-        value = node.workflow_manage.generate_prompt(value)
         if _type == 'int':
             return int(value)
         if _type == 'float':
@@ -88,20 +70,18 @@ def convert_value(name: str, value, _type, is_required, source, node):
             v = json.loads(value)
             if isinstance(v, dict):
                 return v
-            raise Exception(_('type error'))
+            raise Exception("type error")
         if _type == 'array':
             v = json.loads(value)
             if isinstance(v, list):
                 return v
-            raise Exception(_('type error'))
+            raise Exception("type error")
         return value
     except Exception as e:
-        raise Exception(
-            _('Field: {name} Type: {_type} Value: {value} Type error').format(name=name, _type=_type,
-                                                                              value=value))
+        raise Exception(f'Field: {name} Type: {_type} Value: {value} Type error')
 
 
-class BaseToolNodeNode(IToolNode):
+class BaseFunctionNodeNode(IFunctionNode):
     def save_context(self, details, workflow_manage):
         self.context['result'] = details.get('result')
         if self.node_params.get('is_result', False):
@@ -2,51 +2,43 @@
 
 from typing import Type
 
-from django.utils.translation import gettext_lazy as _
 from rest_framework import serializers
 
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
+from django.utils.translation import gettext_lazy as _
 
 
 class ImageGenerateNodeSerializer(serializers.Serializer):
-    model_id = serializers.CharField(required=True, label=_("Model id"))
+    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
 
-    prompt = serializers.CharField(required=True, label=_("Prompt word (positive)"))
+    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word (positive)")))
 
-    negative_prompt = serializers.CharField(required=False, label=_("Prompt word (negative)"),
+    negative_prompt = serializers.CharField(required=False, error_messages=ErrMessage.char(_("Prompt word (negative)")),
                                             allow_null=True, allow_blank=True, )
     # Number of rounds of dialogue history
     dialogue_number = serializers.IntegerField(required=False, default=0,
-                                               label=_("Number of multi-round conversations"))
+                                               error_messages=ErrMessage.integer(_("Number of multi-round conversations")))
 
     dialogue_type = serializers.CharField(required=False, default='NODE',
-                                          label=_("Conversation storage type"))
+                                          error_messages=ErrMessage.char(_("Conversation storage type")))
 
-    is_result = serializers.BooleanField(required=False,
-                                         label=_('Whether to return content'))
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))
 
     model_params_setting = serializers.JSONField(required=False, default=dict,
-                                                 label=_("Model parameter settings"))
+                                                 error_messages=ErrMessage.json(_("Model parameter settings")))
 
 
 class IImageGenerateNode(INode):
     type = 'image-generate-node'
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE,
-               WorkflowMode.KNOWLEDGE_LOOP]
 
     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return ImageGenerateNodeSerializer
 
     def _run(self):
-        if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__(
-                self.workflow_manage.flow.workflow_mode):
-            return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
-                                **{'history_chat_record': [], 'stream': True, 'chat_id': None, 'chat_record_id': None})
-        else:
-            return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
+        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
 
-    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record,
+    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                 model_params_setting,
                 chat_record_id,
                 **kwargs) -> NodeResult:
@@ -5,13 +5,11 @@ from typing import List
 import requests
 from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
 
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
-from common.utils.common import bytes_to_uploaded_file
-from knowledge.models import FileSourceType
-from oss.serializers.file import FileSerializer
-from models_provider.tools import get_model_instance_by_model_workspace_id
+from common.util.common import bytes_to_uploaded_file
+from dataset.serializers.file_serializers import FileSerializer
+from setting.models_provider.tools import get_model_instance_by_model_user_id
 
 
 class BaseImageGenerateNode(IImageGenerateNode):
@@ -21,13 +19,14 @@ class BaseImageGenerateNode(IImageGenerateNode):
         if self.node_params.get('is_result', False):
             self.answer_text = details.get('answer')
 
-    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record,
+    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                 model_params_setting,
                 chat_record_id,
                 **kwargs) -> NodeResult:
-        workspace_id = self.workflow_manage.get_body().get('workspace_id')
-        tti_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
-                                                             **model_params_setting)
+        print(model_params_setting)
+        application = self.workflow_manage.work_flow_post_handler.chat_info.application
+        tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
+                                                        **model_params_setting)
         history_message = self.get_history_message(history_chat_record, dialogue_number)
         self.context['history_message'] = history_message
         question = self.generate_prompt_question(prompt)
@@ -35,16 +34,19 @@ class BaseImageGenerateNode(IImageGenerateNode):
         message_list = self.generate_message_list(question, history_message)
         self.context['message_list'] = message_list
         self.context['dialogue_type'] = dialogue_type
-        self.context['negative_prompt'] = self.generate_prompt_question(negative_prompt)
+        print(message_list)
         image_urls = tti_model.generate_image(question, negative_prompt)
         # Save the generated images
         file_urls = []
         for image_url in image_urls:
             file_name = 'generated_image.png'
-            if isinstance(image_url, str) and image_url.startswith('http'):
-                image_url = requests.get(image_url).content
-            file = bytes_to_uploaded_file(image_url, file_name)
-            file_url = self.upload_file(file)
+            file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
+            meta = {
+                'debug': False if application.id else True,
+                'chat_id': chat_id,
+                'application_id': str(application.id) if application.id else None,
+            }
+            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
             file_urls.append(file_url)
         self.context['image_list'] = [{'file_id': path.split('/')[-1], 'url': path} for path in file_urls]
         answer = ' '.join([f"" for path in file_urls])
@@ -91,42 +93,6 @@ class BaseImageGenerateNode(IImageGenerateNode):
             question
         ]
 
-    def upload_file(self, file):
-        if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__(
-                self.workflow_manage.flow.workflow_mode):
-            return self.upload_knowledge_file(file)
-        return self.upload_application_file(file)
-
-    def upload_knowledge_file(self, file):
-        knowledge_id = self.workflow_params.get('knowledge_id')
-        meta = {
-            'debug': False,
-            'knowledge_id': knowledge_id,
-        }
-        file_url = FileSerializer(data={
-            'file': file,
-            'meta': meta,
-            'source_id': knowledge_id,
-            'source_type': FileSourceType.KNOWLEDGE.value
-        }).upload()
-        return file_url
-
-    def upload_application_file(self, file):
-        application = self.workflow_manage.work_flow_post_handler.chat_info.application
-        chat_id = self.workflow_params.get('chat_id')
-        meta = {
-            'debug': False if application.id else True,
-            'chat_id': chat_id,
-            'application_id': str(application.id) if application.id else None,
-        }
-        file_url = FileSerializer(data={
-            'file': file,
-            'meta': meta,
-            'source_id': meta['application_id'],
-            'source_type': FileSourceType.APPLICATION.value
-        }).upload()
-        return file_url
-
     @staticmethod
     def reset_message_list(message_list: List[BaseMessage], answer_text):
         result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
@@ -152,6 +118,5 @@ class BaseImageGenerateNode(IImageGenerateNode):
             'status': self.status,
             'err_message': self.err_message,
             'image_list': self.context.get('image_list'),
-            'dialogue_type': self.context.get('dialogue_type'),
-            'negative_prompt': self.context.get('negative_prompt'),
+            'dialogue_type': self.context.get('dialogue_type')
         }
@@ -1,3 +0,0 @@
-# coding=utf-8
-
-from .impl import *
@@ -1,71 +0,0 @@
# coding=utf-8

from typing import Type

from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from application.flow.common import WorkflowMode
from application.flow.i_step_node import INode, NodeResult


class ImageToVideoNodeSerializer(serializers.Serializer):
    model_id = serializers.CharField(required=True, label=_("Model id"))

    prompt = serializers.CharField(required=True, label=_("Prompt word (positive)"))

    negative_prompt = serializers.CharField(required=False, label=_("Prompt word (negative)"),
                                            allow_null=True, allow_blank=True, )
    # Number of multi-turn conversations
    dialogue_number = serializers.IntegerField(required=False, default=0,
                                               label=_("Number of multi-round conversations"))

    dialogue_type = serializers.CharField(required=False, default='NODE',
                                          label=_("Conversation storage type"))

    is_result = serializers.BooleanField(required=False,
                                         label=_('Whether to return content'))

    model_params_setting = serializers.JSONField(required=False, default=dict,
                                                 label=_("Model parameter settings"))

    first_frame_url = serializers.ListField(required=True, label=_("First frame url"))
    last_frame_url = serializers.ListField(required=False, label=_("Last frame url"))


class IImageToVideoNode(INode):
    type = 'image-to-video-node'
    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE,
               WorkflowMode.KNOWLEDGE_LOOP]

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return ImageToVideoNodeSerializer

    def _run(self):
        first_frame_url = self.workflow_manage.get_reference_field(
            self.node_params_serializer.data.get('first_frame_url')[0],
            self.node_params_serializer.data.get('first_frame_url')[1:])
        if not first_frame_url:
            raise ValueError(
                _("First frame url cannot be empty"))
        last_frame_url = None
        if self.node_params_serializer.data.get('last_frame_url') is not None and self.node_params_serializer.data.get(
                'last_frame_url') != []:
            last_frame_url = self.workflow_manage.get_reference_field(
                self.node_params_serializer.data.get('last_frame_url')[0],
                self.node_params_serializer.data.get('last_frame_url')[1:])
        node_params_data = {k: v for k, v in self.node_params_serializer.data.items()
                            if k not in ['first_frame_url', 'last_frame_url']}
        if self.workflow_manage.flow.workflow_mode in [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP]:
            return self.execute(first_frame_url=first_frame_url, last_frame_url=last_frame_url, **node_params_data,
                                **self.flow_params_serializer.data,
                                **{'history_chat_record': [], 'stream': True, 'chat_id': None, 'chat_record_id': None})
        else:
            return self.execute(first_frame_url=first_frame_url, last_frame_url=last_frame_url,
                                **node_params_data, **self.flow_params_serializer.data)

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record,
                model_params_setting,
                chat_record_id,
                first_frame_url, last_frame_url,
                **kwargs) -> NodeResult:
        pass
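For context, `get_reference_field` is always called here with the first element of the stored list as the source node and the remaining elements as the path into that node's output. The real WorkflowManage implementation is not part of this diff; the following editor's sketch only illustrates the resolution pattern, with a made-up context shape and node id.

# Editor's sketch: how a reference field such as first_frame_url could resolve.
# The context shape and 'start-node' id below are hypothetical.
def get_reference_field(context: dict, node_id: str, path: list):
    value = context.get(node_id, {})
    for key in path:
        value = value.get(key) if isinstance(value, dict) else None
    return value


context = {'start-node': {'image': [{'file_id': 'abc123'}]}}
# ['start-node', 'image'] -> the start node's 'image' output
print(get_reference_field(context, 'start-node', ['image']))  # [{'file_id': 'abc123'}]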
@@ -1,3 +0,0 @@
# coding=utf-8

from .base_image_to_video_node import BaseImageToVideoNode
@@ -1,184 +0,0 @@
# coding=utf-8
import base64
from functools import reduce
from typing import List

import requests
from django.db.models import QuerySet
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage

from application.flow.common import WorkflowMode
from application.flow.i_step_node import NodeResult
from application.flow.step_node.image_to_video_step_node.i_image_to_video_node import IImageToVideoNode
from common.utils.common import bytes_to_uploaded_file
from knowledge.models import FileSourceType, File
from oss.serializers.file import FileSerializer, mime_types
from models_provider.tools import get_model_instance_by_model_workspace_id
from django.utils.translation import gettext


class BaseImageToVideoNode(IImageToVideoNode):
    def save_context(self, details, workflow_manage):
        self.context['answer'] = details.get('answer')
        self.context['question'] = details.get('question')
        if self.node_params.get('is_result', False):
            self.answer_text = details.get('answer')

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record,
                model_params_setting,
                chat_record_id,
                first_frame_url, last_frame_url=None,
                **kwargs) -> NodeResult:
        workspace_id = self.workflow_manage.get_body().get('workspace_id')
        ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
                                                             **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question
        message_list = self.generate_message_list(question, history_message)
        self.context['message_list'] = message_list
        self.context['dialogue_type'] = dialogue_type
        self.context['negative_prompt'] = self.generate_prompt_question(negative_prompt)
        self.context['first_frame_url'] = first_frame_url
        self.context['last_frame_url'] = last_frame_url
        # The first/last frame can be either a URL or a file_id: a URL can be passed
        # to the model directly, while a file_id must be converted to base64 first.
        first_frame_url = self.get_file_base64(first_frame_url)
        last_frame_url = self.get_file_base64(last_frame_url)
        video_urls = ttv_model.generate_video(question, negative_prompt, first_frame_url, last_frame_url)
        # Save the generated video
        if video_urls is None or video_urls == '':
            return NodeResult({'answer': gettext('Failed to generate video')}, {})
        file_name = 'generated_video.mp4'
        if isinstance(video_urls, str) and video_urls.startswith('http'):
            video_urls = requests.get(video_urls).content
        file = bytes_to_uploaded_file(video_urls, file_name)
        file_url = self.upload_file(file)
        video_label = f'<video src="{file_url}" controls style="max-width: 100%; width: 100%; height: auto; max-height: 60vh;"></video>'
        video_list = [{'file_id': file_url.split('/')[-1], 'file_name': file_name, 'url': file_url}]
        return NodeResult({'answer': video_label, 'chat_model': ttv_model, 'message_list': message_list,
                           'video': video_list,
                           'history_message': history_message, 'question': question}, {})

    def get_file_base64(self, image_url):
        try:
            if isinstance(image_url, list):
                image_url = image_url[0].get('file_id') if 'file_id' in image_url[0] else image_url[0].get('url')
            if isinstance(image_url, str) and not image_url.startswith('http'):
                file = QuerySet(File).filter(id=image_url).first()
                file_bytes = file.get_bytes()
                # If the content_type were unknown, a library such as python-magic could detect it
                file_type = file.file_name.split(".")[-1].lower()
                content_type = mime_types.get(file_type, 'application/octet-stream')
                encoded_bytes = base64.b64encode(file_bytes)
                return f'data:{content_type};base64,{encoded_bytes.decode()}'
            return image_url
        except Exception as e:
            raise ValueError(
                gettext("Failed to obtain the image"))

    def upload_file(self, file):
        if self.workflow_manage.flow.workflow_mode in [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP]:
            return self.upload_knowledge_file(file)
        return self.upload_application_file(file)

    def upload_knowledge_file(self, file):
        knowledge_id = self.workflow_params.get('knowledge_id')
        meta = {
            'debug': False,
            'knowledge_id': knowledge_id
        }
        file_url = FileSerializer(data={
            'file': file,
            'meta': meta,
            'source_id': knowledge_id,
            'source_type': FileSourceType.KNOWLEDGE.value
        }).upload()
        return file_url

    def upload_application_file(self, file):
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        chat_id = self.workflow_params.get('chat_id')
        meta = {
            'debug': False if application.id else True,
            'chat_id': chat_id,
            'application_id': str(application.id) if application.id else None,
        }
        file_url = FileSerializer(data={
            'file': file,
            'meta': meta,
            'source_id': meta['application_id'],
            'source_type': FileSourceType.APPLICATION.value
        }).upload()
        return file_url

    def generate_history_ai_message(self, chat_record):
        for val in chat_record.details.values():
            if self.node.id == val['node_id'] and 'image_list' in val:
                if val['dialogue_type'] == 'WORKFLOW':
                    return chat_record.get_ai_message()
                image_list = val['image_list']
                return AIMessage(content=[
                    *[{'type': 'image_url', 'image_url': {'url': f'{file_url}'}} for file_url in image_list]
                ])
        return chat_record.get_ai_message()

    def get_history_message(self, history_chat_record, dialogue_number):
        start_index = len(history_chat_record) - dialogue_number
        history_message = reduce(lambda x, y: [*x, *y], [
            [self.generate_history_human_message(history_chat_record[index]),
             self.generate_history_ai_message(history_chat_record[index])]
            for index in
            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
        return history_message

    def generate_history_human_message(self, chat_record):
        for data in chat_record.details.values():
            if self.node.id == data['node_id'] and 'image_list' in data:
                image_list = data['image_list']
                if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                    return HumanMessage(content=chat_record.problem_text)
                return HumanMessage(content=data['question'])
        return HumanMessage(content=chat_record.problem_text)

    def generate_prompt_question(self, prompt):
        return self.workflow_manage.generate_prompt(prompt)

    def generate_message_list(self, question: str, history_message):
        return [
            *history_message,
            question
        ]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                  message
                  in
                  message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'history_message': [{'content': message.content, 'role': message.type} for message in
                                (self.context.get('history_message') if self.context.get(
                                    'history_message') is not None else [])],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            'first_frame_url': self.context.get('first_frame_url'),
            'last_frame_url': self.context.get('last_frame_url'),
            'dialogue_type': self.context.get('dialogue_type'),
            'negative_prompt': self.context.get('negative_prompt'),
        }
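The data-URL conversion in `get_file_base64` above maps a file extension to a MIME type via the project's `mime_types` dict and inlines the base64 payload. A minimal editor's sketch of the same conversion, substituting the stdlib mimetypes module for the project's map (the helper name is made up):

# Editor's sketch of the data-URL conversion done in get_file_base64 above,
# using stdlib mimetypes in place of the project's mime_types dict.
import base64
import mimetypes


def to_data_url(file_name: str, file_bytes: bytes) -> str:
    content_type = mimetypes.guess_type(file_name)[0] or 'application/octet-stream'
    return f'data:{content_type};base64,{base64.b64encode(file_bytes).decode()}'


# e.g. to_data_url('frame.png', png_bytes) -> 'data:image/png;base64,iVBORw...'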
@@ -4,34 +4,31 @@ from typing import Type
 from rest_framework import serializers
-from application.flow.common import WorkflowMode
 from application.flow.i_step_node import INode, NodeResult
+from common.util.field_message import ErrMessage
 from django.utils.translation import gettext_lazy as _


 class ImageUnderstandNodeSerializer(serializers.Serializer):
-    model_id = serializers.CharField(required=True, label=_("Model id"))
+    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Model id")))
     system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
-                                   label=_("Role Setting"))
-    prompt = serializers.CharField(required=True, label=_("Prompt word"))
+                                   error_messages=ErrMessage.char(_("Role Setting")))
+    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Prompt word")))
     # Number of multi-turn conversations
-    dialogue_number = serializers.IntegerField(required=True, label=_("Number of multi-round conversations"))
+    dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer(_("Number of multi-round conversations")))

-    dialogue_type = serializers.CharField(required=True, label=_("Conversation storage type"))
+    dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Conversation storage type")))

-    is_result = serializers.BooleanField(required=False,
-                                         label=_('Whether to return content'))
+    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(_('Whether to return content')))

-    image_list = serializers.ListField(required=False, label=_("picture"))
+    image_list = serializers.ListField(required=False, error_messages=ErrMessage.list(_("picture")))

     model_params_setting = serializers.JSONField(required=False, default=dict,
-                                                 label=_("Model parameter settings"))
+                                                 error_messages=ErrMessage.json(_("Model parameter settings")))


 class IImageUnderstandNode(INode):
     type = 'image-understand-node'
-    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP]

     def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
         return ImageUnderstandNodeSerializer
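This hunk swaps DRF's `label=` (a display name) for `error_messages=ErrMessage.char(...)`, a dict mapping validation error codes to messages. The `ErrMessage` helper itself is not shown in this diff; a minimal stand-in might look like the sketch below, though the exact keys and wording in MaxKB may differ.

# Editor's sketch: a hypothetical stand-in for the ErrMessage helper,
# producing the error_messages dict that DRF fields accept.
from django.utils.translation import gettext_lazy as _


class ErrMessage:
    @staticmethod
    def char(field):
        return {'required': _('%s is required') % field,
                'null': _('%s cannot be null') % field,
                'blank': _('%s cannot be blank') % field}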
@@ -1,8 +1,8 @@
 # coding=utf-8
 import base64
+import os
 import time
 from functools import reduce
-from imghdr import what
 from typing import List, Dict

 from django.db.models import QuerySet
@@ -10,8 +10,9 @@ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AI
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
-from knowledge.models import File
-from models_provider.tools import get_model_instance_by_model_workspace_id
+from dataset.models import File
+from setting.models_provider.tools import get_model_instance_by_model_user_id
+from imghdr import what


 def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
@@ -59,9 +60,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
 def file_id_to_base64(file_id: str):
     file = QuerySet(File).filter(id=file_id).first()
-    file_bytes = file.get_bytes()
+    file_bytes = file.get_byte()
     base64_image = base64.b64encode(file_bytes).decode("utf-8")
-    return [base64_image, what(None, file_bytes)]
+    return [base64_image, what(None, file_bytes.tobytes())]


 class BaseImageUnderstandNode(IImageUnderstandNode):
@@ -77,9 +78,10 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
                 image,
                 **kwargs) -> NodeResult:
         # Handle invalid parameters
-        workspace_id = self.workflow_manage.get_body().get('workspace_id')
-        image_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
-                                                               **model_params_setting)
+        if image is None or not isinstance(image, list):
+            image = []
+        print(model_params_setting)
+        image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
         # History messages in the execution details do not need the image content
         history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
         self.context['history_message'] = history_message
@@ -89,7 +91,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
         message_list = self.generate_message_list(image_model, system, prompt,
                                                   self.get_history_message(history_chat_record, dialogue_number), image)
         self.context['message_list'] = message_list
-        self.generate_context_image(image)
+        self.context['image_list'] = image
         self.context['dialogue_type'] = dialogue_type
         if stream:
             r = image_model.stream(message_list)
@@ -102,12 +104,6 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
                            'history_message': history_message, 'question': question.content}, {},
                           _write_context=write_context)

-    def generate_context_image(self, image):
-        if isinstance(image, str) and image.startswith('http'):
-            self.context['image_list'] = [{'url': image}]
-        elif image is not None and len(image) > 0:
-            self.context['image_list'] = image
-
     def get_history_message_for_details(self, history_chat_record, dialogue_number):
         start_index = len(history_chat_record) - dialogue_number
         history_message = reduce(lambda x, y: [*x, *y], [
@@ -131,18 +127,11 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
             image_list = data['image_list']
             if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                 return HumanMessage(content=chat_record.problem_text)
-            file_id_list = []
-            url_list = []
-            for image in image_list:
-                if 'file_id' in image:
-                    file_id_list.append(image.get('file_id'))
-                elif 'url' in image:
-                    url_list.append(image.get('url'))
+            file_id_list = [image.get('file_id') for image in image_list]
             return HumanMessage(content=[
                 {'type': 'text', 'text': data['question']},
-                *[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list],
-                *[{'type': 'image_url', 'image_url': {'url': url}} for url in url_list]
+                *[{'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}} for file_id in file_id_list]
             ])
         return HumanMessage(content=chat_record.problem_text)
@@ -162,58 +151,36 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
             image_list = data['image_list']
             if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                 return HumanMessage(content=chat_record.problem_text)
-            file_id_list = []
-            url_list = []
-            for image in image_list:
-                if 'file_id' in image:
-                    file_id_list.append(image.get('file_id'))
-                elif 'url' in image:
-                    url_list.append(image.get('url'))
-            image_base64_list = [file_id_to_base64(file_id) for file_id in file_id_list]
+            image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]

             return HumanMessage(
                 content=[
                     {'type': 'text', 'text': data['question']},
-                    *[{'type': 'image_url',
-                       'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
-                      base64_image in image_base64_list],
-                    *[{'type': 'image_url', 'image_url': url} for url in url_list]
+                    *[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
+                      base64_image in image_base64_list]
                 ])
         return HumanMessage(content=chat_record.problem_text)

     def generate_prompt_question(self, prompt):
         return HumanMessage(self.workflow_manage.generate_prompt(prompt))

-    def _process_images(self, image):
-        """
-        Convert the image data into a format the model can consume.
-        """
-        images = []
-        if isinstance(image, str) and image.startswith('http'):
-            images.append({'type': 'image_url', 'image_url': {'url': image}})
-        elif image is not None and len(image) > 0:
-            for img in image:
-                if 'file_id' in img:
-                    file_id = img['file_id']
-                    file = QuerySet(File).filter(id=file_id).first()
-                    image_bytes = file.get_bytes()
-                    base64_image = base64.b64encode(image_bytes).decode("utf-8")
-                    image_format = what(None, image_bytes)
-                    images.append(
-                        {'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
-                elif 'url' in img and img['url'].startswith('http'):
-                    images.append(
-                        {'type': 'image_url', 'image_url': {'url': img["url"]}})
-        return images
-
     def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
-        prompt_text = self.workflow_manage.generate_prompt(prompt)
-        images = self._process_images(image)
-
-        if images:
-            messages = [HumanMessage(content=[{'type': 'text', 'text': prompt_text}, *images])]
+        if image is not None and len(image) > 0:
+            # Handle multiple images
+            images = []
+            for img in image:
+                file_id = img['file_id']
+                file = QuerySet(File).filter(id=file_id).first()
+                image_bytes = file.get_byte()
+                base64_image = base64.b64encode(image_bytes).decode("utf-8")
+                image_format = what(None, image_bytes.tobytes())
+                images.append({'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
+            messages = [HumanMessage(
+                content=[
+                    {'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
+                    *images
+                ])]
        else:
-            messages = [HumanMessage(prompt_text)]
+            messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]

         if system is not None and len(system) > 0:
             return [
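The `file_id_to_base64` change above sniffs the image format by handing the raw bytes directly to `imghdr.what` with `file=None`; the added `.tobytes()` call suggests `get_byte()` returns a memoryview rather than bytes (an assumption, since the File model is not shown in this diff). A short editor's sketch of that call:

# Editor's sketch: imghdr.what(file, h) ignores `file` when raw bytes are
# passed as `h`, so the format can be detected from in-memory data.
# Note that imghdr was deprecated in Python 3.11 and removed in 3.13.
from imghdr import what

png_header = b'\x89PNG\r\n\x1a\n' + b'\x00' * 16
print(what(None, png_header))  # 'png'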
@@ -1,6 +0,0 @@
# coding=utf-8




from .impl import *
@@ -1,55 +0,0 @@
# coding=utf-8

from typing import Type

from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from application.flow.common import WorkflowMode
from application.flow.i_step_node import INode, NodeResult


class IntentBranchSerializer(serializers.Serializer):
    id = serializers.CharField(required=True, label=_("Branch id"))
    content = serializers.CharField(required=True, label=_("content"))
    isOther = serializers.BooleanField(required=True, label=_("Branch Type"))


class IntentNodeSerializer(serializers.Serializer):
    model_id = serializers.CharField(required=True, label=_("Model id"))
    content_list = serializers.ListField(required=True, label=_("Text content"))
    dialogue_number = serializers.IntegerField(required=True,
                                               label=_("Number of multi-round conversations"))
    model_params_setting = serializers.DictField(required=False,
                                                 label=_("Model parameter settings"))
    branch = IntentBranchSerializer(many=True)


class IIntentNode(INode):
    type = 'intent-node'
    support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE,
               WorkflowMode.KNOWLEDGE_LOOP]

    def save_context(self, details, workflow_manage):
        pass

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return IntentNodeSerializer

    def _run(self):
        question = self.workflow_manage.get_reference_field(
            self.node_params_serializer.data.get('content_list')[0],
            self.node_params_serializer.data.get('content_list')[1:],
        )
        if self.workflow_manage.flow.workflow_mode in [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP]:
            return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
                                **{'history_chat_record': [], 'stream': True, 'chat_id': None, 'chat_record_id': None,
                                   'user_input': str(question)})
        else:
            return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data,
                                user_input=str(question))

    def execute(self, model_id, dialogue_number, history_chat_record, user_input, branch,
                model_params_setting=None, **kwargs) -> NodeResult:
        pass
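For reference, `branch` is validated by `IntentBranchSerializer(many=True)`, so it is a list of dicts with `id`, `content`, and `isOther` keys. A hypothetical payload (the ids and labels below are made up) could look like this:

# Editor's sketch: a hypothetical `branch` payload for the intent node.
branch = [
    {'id': 'b1', 'content': 'Product questions', 'isOther': False},
    {'id': 'b2', 'content': 'Billing questions', 'isOther': False},
    {'id': 'b0', 'content': 'Other', 'isOther': True},
]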
@@ -1,3 +0,0 @@


from .base_intent_node import BaseIntentNode
@@ -1,261 +0,0 @@
# coding=utf-8
import json
import re
import time
from typing import List, Dict, Any
from functools import reduce

from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage

from application.flow.i_step_node import INode, NodeResult
from application.flow.step_node.intent_node.i_intent_node import IIntentNode
from models_provider.models import Model
from models_provider.tools import get_model_instance_by_model_workspace_id, get_model_credential
from .prompt_template import PROMPT_TEMPLATE


def get_default_model_params_setting(model_id):
    model = QuerySet(Model).filter(id=model_id).first()
    credential = get_model_credential(model.provider, model.model_type, model.model_name)
    model_params_setting = credential.get_model_params_setting_form(
        model.model_name).get_default_form_data()
    return model_params_setting


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
    chat_model = node_variable.get('chat_model')
    message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
    answer_tokens = chat_model.get_num_tokens(answer)

    node.context['message_tokens'] = message_tokens
    node.context['answer_tokens'] = answer_tokens
    node.context['answer'] = answer
    node.context['history_message'] = node_variable['history_message']
    node.context['user_input'] = node_variable['user_input']
    node.context['branch_id'] = node_variable.get('branch_id')
    node.context['reason'] = node_variable.get('reason')
    node.context['category'] = node_variable.get('category')
    node.context['run_time'] = time.time() - node.context['start_time']


def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    response = node_variable.get('result')
    answer = response.content
    _write_context(node_variable, workflow_variable, node, workflow, answer)


class BaseIntentNode(IIntentNode):

    def save_context(self, details, workflow_manage):
        self.context['branch_id'] = details.get('branch_id')
        self.context['category'] = details.get('category')

    def execute(self, model_id, dialogue_number, history_chat_record, user_input, branch,
                model_params_setting=None, **kwargs) -> NodeResult:
        # Fall back to the model's default parameter settings
        if model_params_setting is None:
            model_params_setting = get_default_model_params_setting(model_id)

        # Get the model instance
        workspace_id = self.workflow_manage.get_body().get('workspace_id')
        chat_model = get_model_instance_by_model_workspace_id(
            model_id, workspace_id, **model_params_setting
        )

        # Get the conversation history
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message

        # Save the question to the context
        self.context['user_input'] = user_input

        # Build the classification prompt
        prompt = self.build_classification_prompt(user_input, branch)

        # Generate the message list
        system = self.build_system_prompt()
        message_list = self.generate_message_list(system, prompt, history_message)
        self.context['message_list'] = message_list

        # Invoke the model for classification
        try:
            r = chat_model.invoke(message_list)
            classification_result = r.content.strip()
            # Parse the classification result to find the matching branch
            matched_branch = self.parse_classification_result(classification_result, branch)

            # Return the result
            return NodeResult({
                'result': r,
                'chat_model': chat_model,
                'message_list': message_list,
                'history_message': history_message,
                'user_input': user_input,
                'branch_id': matched_branch['id'],
                'reason': self.parse_result_reason(r.content),
                'category': matched_branch.get('content', matched_branch['id'])
            }, {}, _write_context=write_context)

        except Exception as e:
            # Error handling: fall back to the "other" branch
            other_branch = self.find_other_branch(branch)
            if other_branch:
                return NodeResult({
                    'branch_id': other_branch['id'],
                    'category': other_branch.get('content', other_branch['id']),
                    'error': str(e)
                }, {})
            else:
                raise Exception(f"error: {str(e)}")

    @staticmethod
    def get_history_message(history_chat_record, dialogue_number):
        """Collect the history messages for the last `dialogue_number` rounds."""
        start_index = len(history_chat_record) - dialogue_number
        history_message = reduce(lambda x, y: [*x, *y], [
            [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
            for index in
            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])

        for message in history_message:
            if isinstance(message.content, str):
                message.content = re.sub(r'<form_rander>[\d\D]*?</form_rander>', '', message.content)
        return history_message

    def build_system_prompt(self) -> str:
        """Build the system prompt."""
        return ("You are a professional intent recognition assistant. Based on the user's input and the "
                "intent options, accurately identify the user's real intent.")

    def build_classification_prompt(self, user_input: str, branch: List[Dict]) -> str:
        """Build the classification prompt."""
        classification_list = []

        other_branch = self.find_other_branch(branch)
        # Add the "other" branch first, with id 0
        if other_branch:
            classification_list.append({
                "classificationId": 0,
                "content": other_branch.get('content')
            })
        # Add the normal branches
        classification_id = 1
        for b in branch:
            if not b.get('isOther'):
                classification_list.append({
                    "classificationId": classification_id,
                    "content": b['content']
                })
                classification_id += 1

        return PROMPT_TEMPLATE.format(
            classification_list=classification_list,
            user_input=user_input
        )

    def generate_message_list(self, system: str, prompt: str, history_message):
        """Generate the message list."""
        if system is None or len(system) == 0:
            return [*history_message, HumanMessage(self.workflow_manage.generate_prompt(prompt))]
        else:
            return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
                    HumanMessage(self.workflow_manage.generate_prompt(prompt))]

    def parse_classification_result(self, result: str, branch: List[Dict]) -> Dict[str, Any]:
        """Parse the classification result."""
        other_branch = self.find_other_branch(branch)
        normal_intents = [
            b
            for b in branch
            if not b.get('isOther')
        ]

        def get_branch_by_id(category_id: int):
            if category_id == 0:
                return other_branch
            elif 1 <= category_id <= len(normal_intents):
                return normal_intents[category_id - 1]
            return None

        try:
            result_json = json.loads(result)
            classification_id = result_json.get('classificationId')
            # An id of 0 maps to the "other" branch
            matched_branch = get_branch_by_id(classification_id)
            if matched_branch:
                return matched_branch

        except Exception as e:
            # JSON parsing failed; fall back to regex extraction
            numbers = re.findall(r'"classificationId":\s*(\d+)', result)
            if numbers:
                classification_id = int(numbers[0])

                matched_branch = get_branch_by_id(classification_id)
                if matched_branch:
                    return matched_branch

        # If everything fails, fall back to "other"
        return other_branch or (normal_intents[0] if normal_intents else {'id': 'unknown', 'content': 'unknown'})

    def parse_result_reason(self, result: str):
        """Parse the reason behind the classification."""
        try:
            result_json = json.loads(result)
            return result_json.get('reason', '')
        except Exception as e:
            reason_patterns = [
                r'"reason":\s*"([^"]*)"',  # standard format
                r'"reason":\s*"([^"]*)',  # missing closing quote
                r'"reason":\s*([^,}\n]*)',  # content without surrounding quotes
            ]
            for pattern in reason_patterns:
                match = re.search(pattern, result, re.DOTALL)
                if match:
                    reason = match.group(1).strip()
                    # Clean up possible trailing characters
                    reason = re.sub(r'["\s]*$', '', reason)
                    return reason

            return ''

    def find_other_branch(self, branch: List[Dict]) -> Dict[str, Any] | None:
        """Find the "other" branch."""
        for b in branch:
            if b.get('isOther'):
                return b
        return None

    def get_details(self, index: int, **kwargs):
        """Get the node execution details."""
        return {
            'name': self.node.properties.get('stepName'),
            'index': index,
            'run_time': self.context.get('run_time'),
            'system': self.context.get('system'),
            'history_message': [
                {'content': message.content, 'role': message.type}
                for message in (self.context.get('history_message') or [])
            ],
            'user_input': self.context.get('user_input'),
            'answer': self.context.get('answer'),
            'branch_id': self.context.get('branch_id'),
            'category': self.context.get('category'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message
        }
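The JSON-then-regex fallback in `parse_classification_result` above is worth seeing in isolation: strict parsing is tried first, and a regex recovers the id when the model wraps its answer in extra text. An editor's sketch of the same two-step extraction as a standalone function:

# Editor's sketch of the JSON-then-regex fallback used in
# parse_classification_result above.
import json
import re


def extract_classification_id(result: str):
    try:
        return json.loads(result).get('classificationId')
    except Exception:
        numbers = re.findall(r'"classificationId":\s*(\d+)', result)
        return int(numbers[0]) if numbers else None


print(extract_classification_id('{"classificationId": 2, "reason": "billing"}'))  # 2
print(extract_classification_id('```json\n{"classificationId": 1}\n```'))         # 1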
@@ -1,32 +0,0 @@
PROMPT_TEMPLATE = """
# Role
You are an intent classification expert, skilled at judging which classification the user's input belongs to.

## Skills
Skill 1: Clearly determine which of the following intention classifications the user's input belongs to.
Intention classification list:
{classification_list}

Note:
- Please determine the match only between the user's input content and the Intention classification list content, without judging or categorizing the match with the classification ID.
- **When classifying, you must give higher weight to the context and intent continuity shown in the historical conversation. Do not rely solely on the literal meaning of the current input; instead, prioritize the most consistent classification with the previous dialogue flow.**

## User Input
{user_input}

## Reply requirements
- The answer must be returned in JSON format.
- Strictly ensure that the output is in a valid JSON format.
- Do not add prefix ```json or suffix ```
- The answer needs to include the following fields such as:
{{
    "classificationId": 0,
    "reason": ""
}}

## Limit
- Do not reply in plain text."""
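Note that the doubled braces in the template keep the JSON example literal, so `str.format` substitutes only the two named fields. A short editor's sketch of how `build_classification_prompt` fills it in (the classification contents and user input below are hypothetical):

# Editor's sketch: rendering the template with made-up inputs.
prompt = PROMPT_TEMPLATE.format(
    classification_list=[{'classificationId': 0, 'content': 'Other'},
                         {'classificationId': 1, 'content': 'Billing questions'}],
    user_input='How do I get an invoice?'
)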
@@ -1,8 +0,0 @@
# coding=utf-8
"""
@project: MaxKB
@Author:niu
@file: __init__.py.py
@date:2025/11/13 11:17
@desc:
"""
Some files were not shown because too many files have changed in this diff.