feat: 文档支持设置命中处理方式(#56)

This commit is contained in:
shaohuzhang1 2024-04-25 10:44:14 +08:00 committed by GitHub
parent 6980704cce
commit 25aa4cd2b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 210 additions and 32 deletions

View File

@ -18,7 +18,8 @@ from dataset.models import Paragraph
class ParagraphPipelineModel:
def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str):
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
hit_handling_method: str):
self.id = _id
self.document_id = document_id
self.dataset_id = dataset_id
@ -30,6 +31,7 @@ class ParagraphPipelineModel:
self.similarity = similarity
self.dataset_name = dataset_name
self.document_name = document_name
self.hit_handling_method = hit_handling_method
def to_dict(self):
return {
@ -53,6 +55,7 @@ class ParagraphPipelineModel:
self.comprehensive_score = None
self.document_name = None
self.dataset_name = None
self.hit_handling_method = None
def add_paragraph(self, paragraph):
if isinstance(paragraph, Paragraph):
@ -76,6 +79,10 @@ class ParagraphPipelineModel:
self.document_name = document_name
return self
def add_hit_handling_method(self, hit_handling_method):
self.hit_handling_method = hit_handling_method
return self
def add_comprehensive_score(self, comprehensive_score: float):
self.comprehensive_score = comprehensive_score
return self
@ -91,7 +98,7 @@ class ParagraphPipelineModel:
self.paragraph.get('status'),
self.paragraph.get('is_active'),
self.comprehensive_score, self.similarity, self.dataset_name,
self.document_name)
self.document_name, self.hit_handling_method)
class IBaseChatPipelineStep:

View File

@ -146,8 +146,17 @@ class BaseChatStep(IChatStep):
'status') == 'designated_answer':
chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.stream(message_list)
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(

View File

@ -35,8 +35,9 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
if embedding_list is None:
return []
paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
return [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
paragraph_list = self.list_paragraph(embedding_list, vector)
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
return result
@staticmethod
def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel:
@ -50,10 +51,21 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
.add_comprehensive_score(find_embedding.get('comprehensive_score'))
.add_dataset_name(paragraph.get('dataset_name'))
.add_document_name(paragraph.get('document_name'))
.add_hit_handling_method(paragraph.get('hit_handling_method'))
.build())
@staticmethod
def list_paragraph(paragraph_id_list: List, vector):
def get_similarity(paragraph, embedding_list: List):
filter_embedding_list = [embedding for embedding in embedding_list if
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
if filter_embedding_list is not None and len(filter_embedding_list) > 0:
find_embedding = filter_embedding_list[-1]
return find_embedding.get('comprehensive_score')
return 0
@staticmethod
def list_paragraph(embedding_list: List, vector):
paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
if paragraph_id_list is None or len(paragraph_id_list) == 0:
return []
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
@ -67,6 +79,13 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
for paragraph_id in paragraph_id_list:
if not exist_paragraph_list.__contains__(paragraph_id):
vector.delete_by_paragraph_id(paragraph_id)
# 如果存在直接返回的则取直接返回段落
hit_handling_method_paragraph = [paragraph for paragraph in paragraph_list if
paragraph.get('hit_handling_method') == 'directly_return']
if len(hit_handling_method_paragraph) > 0:
# 找到评分最高的
return [sorted(hit_handling_method_paragraph,
key=lambda p: BaseSearchDatasetStep.get_similarity(p, embedding_list))[-1]]
return paragraph_list
def get_details(self, manage, **kwargs):

View File

@ -196,9 +196,11 @@ class ChatMessageSerializer(serializers.Serializer):
exclude_paragraph_id_list = []
# 相同问题是否需要排除已经查询到的段落
if re_chat:
paragraph_id_list = flat_map([row.paragraph_id_list for row in
filter(lambda chat_record: chat_record == message,
chat_info.chat_record_list)])
paragraph_id_list = flat_map(
[[paragraph.get('id') for paragraph in chat_record.details['search_step']['paragraph_list']] for
chat_record in chat_info.chat_record_list if
chat_record.problem_text == message and 'search_step' in chat_record.details and 'paragraph_list' in
chat_record.details['search_step']])
exclude_paragraph_id_list = list(set(paragraph_id_list))
# 构建运行参数
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,

View File

@ -1,7 +1,8 @@
SELECT
paragraph.*,
dataset."name" AS "dataset_name",
"document"."name" AS "document_name"
"document"."name" AS "document_name",
"document"."hit_handling_method" AS "hit_handling_method"
FROM
paragraph paragraph
LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.13 on 2024-04-24 15:36
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('dataset', '0002_image'),
]
operations = [
migrations.AddField(
model_name='document',
name='hit_handling_method',
field=models.CharField(choices=[('optimization', '模型优化'), ('directly_return', '直接返回')], default='optimization', max_length=20, verbose_name='命中处理方式'),
),
]

View File

@ -27,6 +27,11 @@ class Type(models.TextChoices):
web = 1, 'web站点类型'
class HitHandlingMethod(models.TextChoices):
optimization = 'optimization', '模型优化'
directly_return = 'directly_return', '直接返回'
class DataSet(AppModelMixin):
"""
数据集表
@ -58,6 +63,9 @@ class Document(AppModelMixin):
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
default=Type.base)
hit_handling_method = models.CharField(verbose_name='命中处理方式', max_length=20,
choices=HitHandlingMethod.choices,
default=HitHandlingMethod.optimization)
meta = models.JSONField(verbose_name="元数据", default=dict)

View File

@ -8,11 +8,13 @@
"""
import logging
import os
import re
import traceback
import uuid
from functools import reduce
from typing import List, Dict
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
@ -42,6 +44,12 @@ class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
name = serializers.CharField(required=False, max_length=128, min_length=1,
error_messages=ErrMessage.char(
"文档名称"))
hit_handling_method = serializers.CharField(required=False, validators=[
validators.RegexValidator(regex=re.compile("^optimization|directly_return$"),
message="类型只支持optimization|directly_return",
code=500)
], error_messages=ErrMessage.char("命中处理方式"))
is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char(
"文档是否可用"))
@ -116,12 +124,15 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
min_length=1,
error_messages=ErrMessage.char(
"文档名称"))
hit_handling_method = serializers.CharField(required=False, error_messages=ErrMessage.char("命中处理方式"))
def get_query_set(self):
query_set = QuerySet(model=Document)
query_set = query_set.filter(**{'dataset_id': self.data.get("dataset_id")})
if 'name' in self.data and self.data.get('name') is not None:
query_set = query_set.filter(**{'name__icontains': self.data.get('name')})
if 'hit_handling_method' in self.data and self.data.get('hit_handling_method') is not None:
query_set = query_set.filter(**{'hit_handling_method': self.data.get('hit_handling_method')})
query_set = query_set.order_by('-create_time')
return query_set
@ -143,7 +154,11 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='文档名称')]
description='文档名称'),
openapi.Parameter(name='hit_handling_method', in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='文档命中处理方式')]
@staticmethod
def get_response_body_api():
@ -252,7 +267,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
_document = QuerySet(Document).get(id=self.data.get("document_id"))
if with_valid:
DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
update_keys = ['name', 'is_active', 'meta']
update_keys = ['name', 'is_active', 'hit_handling_method', 'meta']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
_document.__setattr__(update_key, instance.get(update_key))
@ -320,6 +335,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式",
description="ai优化:optimization,直接返回:directly_return"),
'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title="文档元数据",
description="文档元数据->web:{source_url:xxx,selector:'xxx'},base:{}"),
}

View File

@ -189,8 +189,8 @@ h4 {
padding-bottom: 0;
}
.float-right{
float:right;
.float-right {
float: right;
}
.flex {
@ -217,6 +217,10 @@ h4 {
align-items: baseline;
}
.justify-center {
justify-content: center;
}
.text-left {
text-align: left;
}
@ -565,4 +569,4 @@ h4 {
.title {
color: var(--app-text-color);
}
}
}

View File

@ -24,6 +24,10 @@
background-color: var(--el-button-bg-color);
border-color: var(--el-button-border-color);
}
&.is-link:focus {
background: none;
border: none;
}
}
.el-button--large {
font-size: 16px;
@ -137,10 +141,16 @@
color: var(--app-text-color);
font-weight: 400;
padding: 5px 11px;
&:not(.is-disabled):focus {
&:not(.is-disabled):focus,
&:not(.is-active):focus {
background-color: var(--app-text-color-light-1);
color: var(--app-text-color);
}
&.is-active,
&.is-active:hover {
color: var(--el-menu-active-color);
background: var(--el-color-primary-light-9);
}
}
.el-tag {

View File

@ -21,15 +21,34 @@
type="textarea"
/>
</el-form-item>
<el-form-item v-else label="文档地址" prop="source_url">
<el-form-item v-else-if="documentType === '1'" label="文档地址" prop="source_url">
<el-input v-model="form.source_url" placeholder="请输入文档地址" />
</el-form-item>
<el-form-item label="选择器">
<el-form-item label="选择器" v-if="documentType === '1'">
<el-input
v-model="form.selector"
placeholder="默认为 body可输入 .classname/#idname/tagname"
/>
</el-form-item>
<el-form-item v-if="!isImport">
<template #label>
<div class="flex align-center">
<span class="mr-4">命中处理方式</span>
<el-tooltip
effect="dark"
content="用户提问时,命中文档下的分段时按照设置的方式进行处理。"
placement="right"
>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-radio-group v-model="form.hit_handling_method">
<template v-for="(value, key) of hitHandlingMethod" :key="key">
<el-radio :value="key">{{ value }}</el-radio>
</template>
</el-radio-group>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
@ -45,6 +64,7 @@ import { useRoute } from 'vue-router'
import type { FormInstance, FormRules } from 'element-plus'
import documentApi from '@/api/document'
import { MsgSuccess } from '@/utils/message'
import { hitHandlingMethod } from '../utils'
const route = useRoute()
const {
@ -57,9 +77,11 @@ const loading = ref<boolean>(false)
const isImport = ref<boolean>(false)
const form = ref<any>({
source_url: '',
selector: ''
selector: '',
hit_handling_method: ''
})
const documentId = ref('')
const documentType = ref<string | number>('') //1: web0:
const rules = reactive({
source_url: [{ required: true, message: '请输入文档地址', trigger: 'blur' }]
@ -71,16 +93,19 @@ watch(dialogVisible, (bool) => {
if (!bool) {
form.value = {
source_url: '',
selector: ''
selector: '',
hit_handling_method: ''
}
isImport.value = false
documentType.value = ''
}
})
const open = (row: any) => {
if (row) {
documentType.value = row.type
documentId.value = row.id
form.value = row.meta
form.value = { hit_handling_method: row.hit_handling_method, ...row.meta }
isImport.value = false
} else {
isImport.value = true
@ -104,7 +129,11 @@ const submit = async (formEl: FormInstance | undefined) => {
})
} else {
const obj = {
meta: form.value
hit_handling_method: form.value.hit_handling_method,
meta: {
source_url: form.value.source_url,
selector: form.value.selector
}
}
documentApi.putDocument(id, documentId.value, obj, loading).then((res) => {
MsgSuccess('设置成功')

View File

@ -90,6 +90,39 @@
</div>
</template>
</el-table-column>
<el-table-column width="150">
<template #header>
<div>
<span>命中处理方式</span>
<el-dropdown trigger="click" @command="dropdownHandle">
<el-button style="margin-top: 1px" link :type="filterMethod ? 'primary' : ''">
<el-icon><Filter /></el-icon>
</el-button>
<template #dropdown>
<el-dropdown-menu style="width: 100px">
<el-dropdown-item
:class="filterMethod ? '' : 'is-active'"
command=""
class="justify-center"
>全部</el-dropdown-item
>
<template v-for="(value, key) of hitHandlingMethod" :key="key">
<el-dropdown-item
:class="filterMethod === key ? 'is-active' : ''"
class="justify-center"
:command="key"
>{{ value }}</el-dropdown-item
>
</template>
</el-dropdown-menu>
</template>
</el-dropdown>
</div>
</template>
<template #default="{ row }">
{{ hitHandlingMethod[row.hit_handling_method] }}
</template>
</el-table-column>
<el-table-column prop="create_time" label="创建时间" width="170">
<template #default="{ row }">
{{ datetimeFormat(row.create_time) }}
@ -100,7 +133,7 @@
{{ datetimeFormat(row.update_time) }}
</template>
</el-table-column>
<el-table-column label="操作" align="left">
<el-table-column label="操作" align="left" width="110">
<template #default="{ row }">
<div v-if="datasetDetail.type === '0'">
<span v-if="row.status === '2'" class="mr-4">
@ -110,6 +143,13 @@
</el-button>
</el-tooltip>
</span>
<span class="mr-4">
<el-tooltip effect="dark" content="设置" placement="top">
<el-button type="primary" text @click.stop="settingDoc(row)">
<el-icon><Setting /></el-icon>
</el-button>
</el-tooltip>
</span>
<span>
<el-tooltip effect="dark" content="删除" placement="top">
<el-button type="primary" text @click.stop="deleteDocument(row)">
@ -165,6 +205,7 @@ import ImportDocumentDialog from './component/ImportDocumentDialog.vue'
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
import { numberFormat } from '@/utils/utils'
import { datetimeFormat } from '@/utils/time'
import { hitHandlingMethod } from './utils'
import { MsgSuccess, MsgConfirm, MsgError } from '@/utils/message'
import useStore from '@/stores'
const router = useRouter()
@ -186,7 +227,10 @@ onBeforeRouteLeave((to: any, from: any) => {
common.savePage(storeKey, null)
common.saveCondition(storeKey, null)
} else {
common.saveCondition(storeKey, filterText.value)
common.saveCondition(storeKey, {
filterText: filterText.value,
filterMethod: filterMethod.value
})
}
})
const beforePagination = computed(() => common.paginationConfig[storeKey])
@ -196,6 +240,7 @@ const SyncWebDialogRef = ref()
const loading = ref(false)
let interval: any
const filterText = ref('')
const filterMethod = ref<string | number>('')
const documentData = ref<any[]>([])
const currentMouseId = ref(null)
const datasetDetail = ref<any>({})
@ -211,6 +256,11 @@ const multipleTableRef = ref<InstanceType<typeof ElTable>>()
const multipleSelection = ref<any[]>([])
const title = ref('')
function dropdownHandle(val: string) {
filterMethod.value = val
getList()
}
function syncDataset() {
SyncWebDialogRef.value.open(id)
}
@ -378,13 +428,12 @@ function handleSizeChange() {
}
function getList(bool?: boolean) {
const param = {
...(filterText.value && { name: filterText.value }),
...(filterMethod.value && { hit_handling_method: filterMethod.value })
}
documentApi
.getDocument(
id as string,
paginationConfig.value,
filterText.value && { name: filterText.value },
bool ? undefined : loading
)
.getDocument(id as string, paginationConfig.value, param, bool ? undefined : loading)
.then((res) => {
documentData.value = res.data.records
paginationConfig.value.total = res.data.total
@ -408,7 +457,8 @@ onMounted(() => {
paginationConfig.value = beforePagination.value
}
if (beforeSearch.value) {
filterText.value = beforeSearch.value
filterText.value = beforeSearch.value['filterText']
filterMethod.value = beforeSearch.value['filterMethod']
}
getList()
//

View File

@ -0,0 +1,4 @@
export const hitHandlingMethod: any = {
optimization: '模型优化',
directly_return: '直接回答'
}