From 4da8b1b0d86ad9d480baea0a73b60287c3767964 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Wed, 8 May 2024 17:31:56 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E7=9B=B4=E6=8E=A5=E5=9B=9E=E7=AD=94?=
=?UTF-8?q?=E6=94=AF=E6=8C=81=E8=AE=BE=E7=BD=AE=E7=9B=B8=E4=BC=BC=E5=BA=A6?=
=?UTF-8?q?=E5=80=BC(#371)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../chat_pipeline/I_base_chat_pipeline.py | 10 +-
.../step/chat_step/impl/base_chat_step.py | 4 +-
.../impl/base_search_dataset_step.py | 6 +-
...list_dataset_paragraph_by_paragraph_id.sql | 3 +-
...004_document_directly_return_similarity.py | 18 ++++
apps/dataset/models/data_set.py | 1 +
.../serializers/document_serializers.py | 18 +++-
apps/dataset/swagger_api/document_api.py | 3 +-
.../component/BatchEditDocumentDialog.vue | 96 -------------------
.../component/ImportDocumentDialog.vue | 80 ++++++++++++----
ui/src/views/document/index.vue | 10 +-
11 files changed, 120 insertions(+), 129 deletions(-)
create mode 100644 apps/dataset/migrations/0004_document_directly_return_similarity.py
delete mode 100644 ui/src/views/document/component/BatchEditDocumentDialog.vue
diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py
index 91effa82c..4c894ddbd 100644
--- a/apps/application/chat_pipeline/I_base_chat_pipeline.py
+++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py
@@ -19,7 +19,7 @@ class ParagraphPipelineModel:
def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
- hit_handling_method: str):
+ hit_handling_method: str, directly_return_similarity: float):
self.id = _id
self.document_id = document_id
self.dataset_id = dataset_id
@@ -32,6 +32,7 @@ class ParagraphPipelineModel:
self.dataset_name = dataset_name
self.document_name = document_name
self.hit_handling_method = hit_handling_method
+ self.directly_return_similarity = directly_return_similarity
def to_dict(self):
return {
@@ -56,6 +57,7 @@ class ParagraphPipelineModel:
self.document_name = None
self.dataset_name = None
self.hit_handling_method = None
+ self.directly_return_similarity = 0.9
def add_paragraph(self, paragraph):
if isinstance(paragraph, Paragraph):
@@ -83,6 +85,10 @@ class ParagraphPipelineModel:
self.hit_handling_method = hit_handling_method
return self
+ def add_directly_return_similarity(self, directly_return_similarity):
+ self.directly_return_similarity = directly_return_similarity
+ return self
+
def add_comprehensive_score(self, comprehensive_score: float):
self.comprehensive_score = comprehensive_score
return self
@@ -98,7 +104,7 @@ class ParagraphPipelineModel:
self.paragraph.get('status'),
self.paragraph.get('is_active'),
self.comprehensive_score, self.similarity, self.dataset_name,
- self.document_name, self.hit_handling_method)
+ self.document_name, self.hit_handling_method, self.directly_return_similarity)
class IBaseChatPipelineStep:
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index 276d01947..8d7b9d35e 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -138,8 +138,8 @@ class BaseChatStep(IChatStep):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
- for paragraph in paragraph_list if
- paragraph.hit_handling_method == 'directly_return']
+ for paragraph in paragraph_list if (
+ paragraph.hit_handling_method == 'directly_return' and paragraph.similarity >= paragraph.directly_return_similarity)]
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
return iter(directly_return_chunk_list), False
elif len(paragraph_list) == 0 and no_references_setting.get(
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
index bfc1118e8..3dd9f8300 100644
--- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
+++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
@@ -52,6 +52,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
.add_dataset_name(paragraph.get('dataset_name'))
.add_document_name(paragraph.get('document_name'))
.add_hit_handling_method(paragraph.get('hit_handling_method'))
+ .add_directly_return_similarity(paragraph.get('directly_return_similarity'))
.build())
@staticmethod
@@ -81,7 +82,10 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
vector.delete_by_paragraph_id(paragraph_id)
# 如果存在直接返回的则取直接返回段落
hit_handling_method_paragraph = [paragraph for paragraph in paragraph_list if
- paragraph.get('hit_handling_method') == 'directly_return']
+ (paragraph.get(
+ 'hit_handling_method') == 'directly_return' and BaseSearchDatasetStep.get_similarity(
+ paragraph, embedding_list) >= paragraph.get(
+ 'directly_return_similarity'))]
if len(hit_handling_method_paragraph) > 0:
# 找到评分最高的
return [sorted(hit_handling_method_paragraph,
diff --git a/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql b/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql
index 813d4f090..2bacd53e1 100644
--- a/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql
+++ b/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql
@@ -2,7 +2,8 @@ SELECT
paragraph.*,
dataset."name" AS "dataset_name",
"document"."name" AS "document_name",
- "document"."hit_handling_method" AS "hit_handling_method"
+ "document"."hit_handling_method" AS "hit_handling_method",
+ "document"."directly_return_similarity" as "directly_return_similarity"
FROM
paragraph paragraph
LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id
diff --git a/apps/dataset/migrations/0004_document_directly_return_similarity.py b/apps/dataset/migrations/0004_document_directly_return_similarity.py
new file mode 100644
index 000000000..cddf38ca3
--- /dev/null
+++ b/apps/dataset/migrations/0004_document_directly_return_similarity.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.1.13 on 2024-05-08 16:43
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('dataset', '0003_document_hit_handling_method'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='document',
+ name='directly_return_similarity',
+ field=models.FloatField(default=0.9, verbose_name='直接回答相似度'),
+ ),
+ ]
diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py
index ab1cfa6f1..d0f56a017 100644
--- a/apps/dataset/models/data_set.py
+++ b/apps/dataset/models/data_set.py
@@ -66,6 +66,7 @@ class Document(AppModelMixin):
hit_handling_method = models.CharField(verbose_name='命中处理方式', max_length=20,
choices=HitHandlingMethod.choices,
default=HitHandlingMethod.optimization)
+ directly_return_similarity = models.FloatField(verbose_name='直接回答相似度', default=0.9)
meta = models.JSONField(verbose_name="元数据", default=dict)
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index bef62fb9e..f07c5f4e4 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -50,7 +50,13 @@ class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
code=500)
], error_messages=ErrMessage.char("命中处理方式"))
- is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char(
+ directly_return_similarity = serializers.FloatField(required=False,
+ max_value=2,
+ min_value=0,
+ error_messages=ErrMessage.float(
+ "直接返回分数"))
+
+ is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(
"文档是否可用"))
@staticmethod
@@ -371,7 +377,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
_document = QuerySet(Document).get(id=self.data.get("document_id"))
if with_valid:
DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
- update_keys = ['name', 'is_active', 'hit_handling_method', 'meta']
+ update_keys = ['name', 'is_active', 'hit_handling_method', 'directly_return_similarity', 'meta']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
_document.__setattr__(update_key, instance.get(update_key))
@@ -444,6 +450,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式",
description="ai优化:optimization,直接返回:directly_return"),
+ 'directly_return_similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="直接返回分数",
+ default=0.9),
'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title="文档元数据",
description="文档元数据->web:{source_url:xxx,selector:'xxx'},base:{}"),
}
@@ -731,7 +739,11 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
self.is_valid(raise_exception=True)
document_id_list = instance.get("id_list")
hit_handling_method = instance.get('hit_handling_method')
- QuerySet(Document).filter(id__in=document_id_list).update(hit_handling_method=hit_handling_method)
+ directly_return_similarity = instance.get('directly_return_similarity')
+ update_dict = {'hit_handling_method': hit_handling_method}
+ if directly_return_similarity is not None:
+ update_dict['directly_return_similarity'] = directly_return_similarity
+ QuerySet(Document).filter(id__in=document_id_list).update(**update_dict)
class FileBufferHandle:
diff --git a/apps/dataset/swagger_api/document_api.py b/apps/dataset/swagger_api/document_api.py
index 1463a61c2..637a7e509 100644
--- a/apps/dataset/swagger_api/document_api.py
+++ b/apps/dataset/swagger_api/document_api.py
@@ -22,6 +22,7 @@ class DocumentApi(ApiMixin):
title="主键id列表",
description="主键id列表"),
'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式",
- description="directly_return|optimization")
+ description="directly_return|optimization"),
+ 'directly_return_similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="直接返回相似度")
}
)
diff --git a/ui/src/views/document/component/BatchEditDocumentDialog.vue b/ui/src/views/document/component/BatchEditDocumentDialog.vue
deleted file mode 100644
index 55c0de0a5..000000000
--- a/ui/src/views/document/component/BatchEditDocumentDialog.vue
+++ /dev/null
@@ -1,96 +0,0 @@
-
-
-
-
-
-
-
-
-
- {{ value }}
-
-
-
-
-
-
-
-
-
-
-
diff --git a/ui/src/views/document/component/ImportDocumentDialog.vue b/ui/src/views/document/component/ImportDocumentDialog.vue
index f0ddc29eb..e0ab41f1e 100644
--- a/ui/src/views/document/component/ImportDocumentDialog.vue
+++ b/ui/src/views/document/component/ImportDocumentDialog.vue
@@ -43,11 +43,28 @@
-
+
- {{ value }}
+ {{ value }}
+
+ 相似度高于
+ 直接返回分段内容
+
@@ -78,11 +95,17 @@ const isImport = ref(false)
const form = ref({
source_url: '',
selector: '',
- hit_handling_method: ''
+ hit_handling_method: 'optimization',
+ directly_return_similarity: 0.9
})
+
+// 文档设置
const documentId = ref('')
const documentType = ref('') //文档类型:1: web文档;0:普通文档
+// 批量设置
+const documentList = ref>([])
+
const rules = reactive({
source_url: [{ required: true, message: '请输入文档地址', trigger: 'blur' }]
})
@@ -94,20 +117,30 @@ watch(dialogVisible, (bool) => {
form.value = {
source_url: '',
selector: '',
- hit_handling_method: ''
+ hit_handling_method: 'optimization',
+ directly_return_similarity: 0.9
}
isImport.value = false
documentType.value = ''
+ documentList.value = []
}
})
-const open = (row: any) => {
+const open = (row: any, list: Array) => {
if (row) {
documentType.value = row.type
documentId.value = row.id
- form.value = { hit_handling_method: row.hit_handling_method, ...row.meta }
+ form.value = {
+ hit_handling_method: row.hit_handling_method,
+ directly_return_similarity: row.directly_return_similarity,
+ ...row.meta
+ }
isImport.value = false
+ } else if (list) {
+ // 批量设置
+ documentList.value = list
} else {
+ // 导入
isImport.value = true
}
dialogVisible.value = true
@@ -128,18 +161,33 @@ const submit = async (formEl: FormInstance | undefined) => {
dialogVisible.value = false
})
} else {
- const obj = {
- hit_handling_method: form.value.hit_handling_method,
- meta: {
- source_url: form.value.source_url,
- selector: form.value.selector
+ if (documentId.value) {
+ const obj = {
+ hit_handling_method: form.value.hit_handling_method,
+ directly_return_similarity: form.value.directly_return_similarity,
+ meta: {
+ source_url: form.value.source_url,
+ selector: form.value.selector
+ }
}
+ documentApi.putDocument(id, documentId.value, obj, loading).then((res) => {
+ MsgSuccess('设置成功')
+ emit('refresh')
+ dialogVisible.value = false
+ })
+ } else if (documentList.value.length > 0) {
+ // 批量设置
+ const obj = {
+ hit_handling_method: form.value.hit_handling_method,
+ directly_return_similarity: form.value.directly_return_similarity,
+ id_list: documentList.value
+ }
+ documentApi.batchEditHitHandling(id, obj, loading).then((res: any) => {
+ MsgSuccess('设置成功')
+ emit('refresh')
+ dialogVisible.value = false
+ })
}
- documentApi.putDocument(id, documentId.value, obj, loading).then((res) => {
- MsgSuccess('设置成功')
- emit('refresh')
- dialogVisible.value = false
- })
}
}
})
diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue
index 9b7cffdd1..402d8bdf5 100644
--- a/ui/src/views/document/index.vue
+++ b/ui/src/views/document/index.vue
@@ -215,10 +215,6 @@
-
@@ -232,7 +228,6 @@ import documentApi from '@/api/document'
import ImportDocumentDialog from './component/ImportDocumentDialog.vue'
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
import SelectDatasetDialog from './component/SelectDatasetDialog.vue'
-import BatchEditDocumentDialog from './component/BatchEditDocumentDialog.vue'
import { numberFormat } from '@/utils/utils'
import { datetimeFormat } from '@/utils/time'
import { hitHandlingMethod } from './utils'
@@ -265,7 +260,7 @@ onBeforeRouteLeave((to: any, from: any) => {
})
const beforePagination = computed(() => common.paginationConfig[storeKey])
const beforeSearch = computed(() => common.search[storeKey])
-const batchEditDocumentDialogRef = ref>()
+
const SyncWebDialogRef = ref()
const loading = ref(false)
let interval: any
@@ -326,8 +321,9 @@ const handleSelectionChange = (val: any[]) => {
}
function openBatchEditDocument() {
+ title.value = '设置'
const arr: string[] = multipleSelection.value.map((v) => v.id)
- batchEditDocumentDialogRef?.value?.open(arr)
+ ImportDocumentDialogRef.value.open(null, arr)
}
/**