feat: 增加全文检索和混合检索方式

This commit is contained in:
shaohuzhang1 2024-04-22 11:21:24 +08:00 committed by GitHub
parent 8fe1a147ff
commit c89ae29429
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 580 additions and 140 deletions

View File

@ -6,15 +6,16 @@
@date2024/1/9 18:10
@desc: 检索知识库
"""
import re
from abc import abstractmethod
from typing import List, Type
from django.core import validators
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from common.util.field_message import ErrMessage
from dataset.models import Paragraph
class ISearchDatasetStep(IBaseChatPipelineStep):
@ -38,6 +39,10 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
# similarity threshold, between 0 and 1
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
                                    # fixed label: original said "引用分段数" (top_n's
                                    # label), copy-pasted from the field above
                                    error_messages=ErrMessage.float("相似度"))
# Retrieval mode: one of embedding | keywords | blend.
# The alternation must be grouped — "^embedding|keywords|blend$" anchors only
# the first/last branch, so inputs like "embeddingX" or "xkeywordsx" validated.
search_mode = serializers.CharField(required=True, validators=[
    validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                              message="检索模式只支持embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("检索模式"))
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer
@ -50,6 +55,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
@abstractmethod
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
**kwargs) -> List[ParagraphPipelineModel]:
"""
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
@ -60,6 +66,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
:param exclude_document_id_list: 需要排除的文档id
:param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题
:param search_mode 检索模式
:return: 段落列表
"""
pass

View File

@ -17,6 +17,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Paragraph
from embedding.models import SearchMode
from smartdoc.conf import PROJECT_DIR
@ -24,13 +25,14 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
**kwargs) -> List[ParagraphPipelineModel]:
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
embedding_model = EmbeddingModel.get_embedding_model()
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity)
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
if embedding_list is None:
return []
paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)

View File

@ -19,7 +19,7 @@ from users.models import User
def get_dataset_setting_dict():
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000}
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding'}
def get_model_setting_dict():

View File

@ -8,12 +8,13 @@
"""
import hashlib
import os
import re
import uuid
from functools import reduce
from typing import Dict
from django.contrib.postgres.fields import ArrayField
from django.core import cache
from django.core import cache, validators
from django.core import signing
from django.db import transaction, models
from django.db.models import QuerySet
@ -32,6 +33,7 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document
from dataset.serializers.common_serializers import list_paragraph
from embedding.models import SearchMode
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@ -77,6 +79,10 @@ class DatasetSettingSerializer(serializers.Serializer):
error_messages=ErrMessage.float("相识度"))
max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=10000,
error_messages=ErrMessage.integer("最多引用字符数"))
# Retrieval mode: one of embedding | keywords | blend.
# Group the alternation — "^embedding|keywords|blend$" anchors only the
# first/last branch, letting invalid values such as "embeddingX" through.
search_mode = serializers.CharField(required=True, validators=[
    validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                              message="检索模式只支持embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("检索模式"))
class ModelSettingSerializer(serializers.Serializer):
@ -291,6 +297,10 @@ class ApplicationSerializer(serializers.Serializer):
error_messages=ErrMessage.integer("topN"))
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.float("相关度"))
# Retrieval mode: one of embedding | keywords | blend.
# Group the alternation — "^embedding|keywords|blend$" anchors only the
# first/last branch; the original error message was copy-pasted from auth code.
search_mode = serializers.CharField(required=True, validators=[
    validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                              message="检索模式只支持embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("检索模式"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -312,6 +322,7 @@ class ApplicationSerializer(serializers.Serializer):
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model())
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])

View File

@ -77,6 +77,8 @@ class ChatInfo:
'model_id': self.application.model.id if self.application.model is not None else None,
'problem_optimization': self.application.problem_optimization,
'stream': True,
'search_mode': self.application.dataset_setting.get(
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding'
}
@ -184,9 +186,9 @@ class ChatMessageSerializer(serializers.Serializer):
pipeline_manage_builder.append_step(BaseResetProblemStep)
# 构建流水线管理器
pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
.append_step(BaseGenerateHumanMessageStep)
.append_step(BaseChatStep)
.build())
.append_step(BaseGenerateHumanMessageStep)
.append_step(BaseChatStep)
.build())
exclude_paragraph_id_list = []
# 相同问题是否需要排除已经查询到的段落
if re_chat:

View File

@ -161,6 +161,8 @@ class ApplicationApi(ApiMixin):
default=0.6),
'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
description="最多引用字符数", default=3000),
'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title='检索模式',
description="embedding|keywords|blend", default='embedding'),
}
)

View File

@ -343,7 +343,8 @@ class Application(APIView):
ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id,
"query_text": request.query_params.get("query_text"),
"top_number": request.query_params.get("top_number"),
'similarity': request.query_params.get('similarity')}).hit_test(
'similarity': request.query_params.get('similarity'),
'search_mode': request.query_params.get('search_mode')}).hit_test(
))
class Operate(APIView):

View File

@ -33,6 +33,13 @@ class CommonApi:
default=0.6,
required=True,
description='相关性'),
openapi.Parameter(name='search_mode',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
default="embedding",
required=True,
description='检索模式embedding|keywords|blend'
)
]
@staticmethod

View File

@ -0,0 +1,107 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file ts_vecto_util.py
@date2024/4/16 15:26
@desc:
"""
import re
import uuid
from typing import List
import jieba
from jieba import analyse
from common.util.split_model import group_by
# Pool of single-character placeholder tokens (chr 38..83). Each is registered
# with jieba in its '#x#' wrapper form so the placeholders survive segmentation.
jieba_word_list_cache = [chr(code) for code in range(38, 84)]
for placeholder in jieba_word_list_cache:
    jieba.add_word('#' + placeholder + '#')
# r"(?i)\b(?:https?|ftp|tcp|file)://[^\s]+\b",
# Substrings matching these patterns must NOT be segmented by jieba
# r'"([^"]*)"'
# Fixes vs. original: dots in the version pattern are escaped (a bare '.'
# matched any character, so "v1x2y3" was treated as a version), and the
# TLD class is [A-Za-z] — '[A-Z|a-z]' also matched a literal '|'.
word_pattern_list = [r"v\d+\.\d+\.\d+",
                     r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"]

# Characters/tokens dropped from ts_vector and query output
remove_chars = '\n , :\'<>@#¥%……&*!@#$%^&*() /"./-'
def get_word_list(text: str):
    """Collect all substrings of *text* matching the protected patterns
    (version strings, e-mail addresses).

    Matches containing ':' are split on it, because ':' is the separator
    used by the ts_vector output format and cannot appear inside a word.
    """
    collected = []
    for pattern in word_pattern_list:
        for match in re.findall(pattern, text):
            pieces = match if isinstance(match, tuple) else [match]
            for piece in pieces:
                if ':' in piece:
                    collected.extend(piece.split(':'))
                else:
                    collected.append(piece)
    return collected
def replace_word(word_dict, text: str):
    """Replace every protected word in *text* with its '#key#' placeholder.

    The lookarounds skip words already wrapped in '#'. The stored word is
    passed through re.escape: the protected words are e-mails/versions that
    contain regex metacharacters ('+', '.'), so interpolating them raw into
    the pattern made the substitution silently miss (or mis-match).
    """
    for key in word_dict:
        text = re.sub('(?<!#)' + re.escape(word_dict[key]) + '(?!#)', key, text)
    return text
def get_word_key(text: str, use_word_list):
    """Pick a placeholder character absent from *text* and not yet used.

    Falls back to a uuid1-based token (registered with jieba so it survives
    segmentation) when the single-character pool is exhausted.
    """
    for candidate in jieba_word_list_cache:
        if candidate not in text and candidate not in use_word_list:
            return candidate
    fallback = str(uuid.uuid1())
    jieba.add_word(fallback)
    return fallback
def to_word_dict(word_list: List, text: str):
    """Assign each protected word a unique '#key#' placeholder.

    Keeps the set of *bare* keys handed to get_word_key: the original passed
    set(word_dict), whose elements are the wrapped '#x#' strings, so the
    membership check in get_word_key never matched and two words could be
    assigned the same placeholder (the first mapping being overwritten).
    """
    word_dict = {}
    used = set()
    for word in word_list:
        key = get_word_key(text, used)
        used.add(key)
        word_dict['#' + key + '#'] = word
    return word_dict
def get_key_by_word_dict(key, word_dict):
    """Map a placeholder back to its original word; unknown keys pass through."""
    original = word_dict.get(key)
    return key if original is None else original
def to_ts_vector(text: str):
    """Build a PostgreSQL ts_vector-style string ("word:pos1,pos2 ...") for *text*.

    Protected substrings (versions, e-mails) are swapped for placeholders
    before jieba search-mode segmentation and mapped back afterwards, so
    they survive as single tokens.
    """
    # substrings that must not be segmented
    word_list = get_word_list(text)
    # placeholder -> original word mapping
    word_dict = to_word_dict(word_list, text)
    # protect them in the text
    text = replace_word(word_dict, text)
    # segment; item = (token, start_offset, end_offset)
    result = jieba.tokenize(text, mode='search')
    result_ = [{'word': get_key_by_word_dict(item[0], word_dict), 'index': item[1]} for item in result]
    result_group = group_by(result_, lambda r: r['word'])
    # Positions are 1-based, capped at 20 per word.
    # Filter fix: the original tested `len(key.strip()) >= 0`, which is always
    # true; `> 0` is the evident intent (drop whitespace-only tokens).
    return " ".join(
        [f"{key.lower()}:{','.join([str(item['index'] + 1) for item in result_group[key]][:20])}"
         for key in result_group
         if key not in remove_chars and len(key.strip()) > 0])
def to_query(text: str):
    """Build the websearch_to_tsquery input for *text*: its top-5 jieba
    keywords, with protected substrings (versions, e-mails) kept verbatim."""
    word_list = get_word_list(text)
    word_dict = to_word_dict(word_list, text)
    text = replace_word(word_dict, text)
    extract_tags = analyse.extract_tags(text, topK=5, withWeight=True,
                                        allowPOS=('ns', 'n', 'vn', 'v', 'eng'))
    kept = [word for word, _score in extract_tags if word not in remove_chars]
    result = " ".join(get_key_by_word_dict(word, word_dict) for word in kept)
    # NOTE(review): this removes the matched words themselves from jieba's
    # dictionary, not the uuid fallback keys added in get_word_key — confirm
    # this is the intended cleanup.
    for word in word_list:
        jieba.del_word(word)
    return result

View File

@ -37,6 +37,7 @@ from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from embedding.models import SearchMode
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR
@ -457,6 +458,10 @@ class DataSetSerializers(serializers.ModelSerializer):
error_messages=ErrMessage.char("响应Top"))
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.char("相似度"))
# Retrieval mode: one of embedding | keywords | blend.
# Group the alternation — "^embedding|keywords|blend$" anchors only the
# first/last branch, so invalid values such as "embeddingX" passed validation.
search_mode = serializers.CharField(required=True, validators=[
    validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                              message="检索模式只支持embedding|keywords|blend", code=500)
], error_messages=ErrMessage.char("检索模式"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
@ -474,6 +479,7 @@ class DataSetSerializers(serializers.ModelSerializer):
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model())
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])

View File

@ -111,7 +111,8 @@ class Dataset(APIView):
DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id,
"query_text": request.query_params.get("query_text"),
"top_number": request.query_params.get("top_number"),
'similarity': request.query_params.get('similarity')}).hit_test(
'similarity': request.query_params.get('similarity'),
'search_mode': request.query_params.get('search_mode')}).hit_test(
))
class Operate(APIView):

View File

@ -0,0 +1,54 @@
# Generated by Django 4.1.13 on 2024-04-16 11:43
import django.contrib.postgres.search
from django.db import migrations
from common.util.common import sub_array
from common.util.ts_vecto_util import to_ts_vector
from dataset.models import Status
from embedding.models import Embedding
def update_embedding_search_vector(embedding, paragraph_list):
    """Return an Embedding carrying the recomputed search_vector for the
    paragraph referenced by *embedding* (empty when the paragraph is gone).

    *embedding* is a values() dict with keys 'id' and 'paragraph'.
    """
    matched = next((p for p in paragraph_list if p.id == embedding.get('paragraph')), None)
    if matched is None:
        return Embedding(id=embedding.get('id'), search_vector="")
    return Embedding(id=embedding.get('id'),
                     search_vector=to_ts_vector(matched.title + matched.content))
def save_keywords(apps, schema_editor):
    """Backfill Embedding.search_vector for every document's paragraphs.

    Marks each document as Status.embedding while processing, then
    bulk-updates the recomputed vectors in chunks of 20. Chunk failures are
    printed and skipped so one bad batch does not abort the whole migration.

    Fix: the original reused the name `document` for both the historical
    model and the loop variable (likewise shadow-prone `embedding`); distinct
    names avoid clobbering the model bindings.
    """
    document_model = apps.get_model("dataset", "Document")
    embedding_model = apps.get_model("embedding", "Embedding")
    paragraph_model = apps.get_model('dataset', 'Paragraph')
    db_alias = schema_editor.connection.alias
    for doc in document_model.objects.using(db_alias).all():
        doc.status = Status.embedding
        doc.save()
        paragraph_list = paragraph_model.objects.using(db_alias).filter(document=doc).all()
        embedding_rows = embedding_model.objects.using(db_alias).filter(document=doc).values(
            'id', 'search_vector', 'paragraph')
        # NOTE(review): update_embedding_search_vector builds imported Embedding
        # instances rather than the historical model — confirm that is intended.
        updates = [update_embedding_search_vector(row, paragraph_list) for row in embedding_rows]
        for chunk in sub_array(updates, 20):
            try:
                embedding_model.objects.using(db_alias).bulk_update(chunk, ['search_vector'])
            except Exception as e:
                # best-effort: keep migrating the remaining chunks
                print(e)
class Migration(migrations.Migration):
    """Add Embedding.search_vector (full-text index column) and backfill it."""

    dependencies = [
        ('embedding', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='embedding',
            name='search_vector',
            field=django.contrib.postgres.search.SearchVectorField(default='', verbose_name='分词'),
        ),
        # Reverse is a no-op: AddField's own reverse drops the column, so the
        # backfill needs no undo. Without a reverse callable the migration was
        # irreversible, blocking `migrate` back past this point.
        migrations.RunPython(save_keywords, migrations.RunPython.noop),
    ]

View File

@ -10,6 +10,7 @@ from django.db import models
from common.field.vector_field import VectorField
from dataset.models.data_set import Document, Paragraph, DataSet
from django.contrib.postgres.search import SearchVectorField
class SourceType(models.TextChoices):
@ -19,6 +20,12 @@ class SourceType(models.TextChoices):
TITLE = 2, '标题'
class SearchMode(models.TextChoices):
embedding = 'embedding'
keywords = 'keywords'
blend = 'blend'
class Embedding(models.Model):
id = models.CharField(max_length=128, primary_key=True, verbose_name="主键id")
@ -37,6 +44,8 @@ class Embedding(models.Model):
embedding = VectorField(verbose_name="向量")
search_vector = SearchVectorField(verbose_name="分词", default="")
meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta:

View File

@ -0,0 +1,26 @@
-- Hybrid (blend) search: score = cosine similarity (1 - <=> distance) plus
-- ts_rank_cd full-text rank (normalization flag 32 scales rank to rank/(rank+1)).
-- Placeholders, in order: query embedding, websearch query text,
-- similarity threshold, row limit. ${embedding_query} is substituted by
-- generate_sql_by_query_dict with the filtered embedding queryset.
SELECT
	paragraph_id,
	comprehensive_score,
	comprehensive_score AS similarity
FROM
	(
	SELECT DISTINCT ON
		-- keep only the best-scoring row per paragraph
		( "paragraph_id" ) ( similarity ),* ,
		similarity AS comprehensive_score
	FROM
		(
		SELECT
			*,
			(( 1 - ( embedding.embedding <=> %s ) )+ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 )) AS similarity
		FROM
			embedding ${embedding_query}
		) TEMP
	ORDER BY
		paragraph_id,
		similarity DESC
	) DISTINCT_TEMP
WHERE
	comprehensive_score >%s
ORDER BY
	comprehensive_score DESC
	LIMIT %s

View File

@ -0,0 +1,17 @@
-- Keywords (full-text) search: score is ts_rank_cd over search_vector against
-- websearch_to_tsquery('simple', ...) with normalization flag 32 (rank/(rank+1)).
-- Placeholders, in order: query text, similarity threshold, row limit.
-- ${keywords_query} is substituted by generate_sql_by_query_dict.
SELECT
	paragraph_id,
	comprehensive_score,
	comprehensive_score as similarity
FROM
	(
	SELECT DISTINCT ON
		-- keep only the best-scoring row per paragraph
		("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
	FROM
		( SELECT *,ts_rank_cd(embedding.search_vector,websearch_to_tsquery('simple',%s),32) AS similarity FROM embedding ${keywords_query}) TEMP
	ORDER BY
		paragraph_id,
		similarity DESC
	) DISTINCT_TEMP
WHERE comprehensive_score>%s
ORDER BY comprehensive_score DESC
LIMIT %s

View File

@ -14,7 +14,7 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
from common.config.embedding_config import EmbeddingModel
from common.util.common import sub_array
from embedding.models import SourceType
from embedding.models import SourceType, SearchMode
lock = threading.Lock()
@ -113,13 +113,16 @@ class BaseVectorStore(ABC):
return result[0]
@abstractmethod
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
def query(self, query_text:str,query_embedding: List[float], dataset_id_list: list[str],
exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
search_mode: SearchMode):
pass
@abstractmethod
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float,
search_mode: SearchMode,
embedding: HuggingFaceEmbeddings):
pass

View File

@ -9,6 +9,7 @@
import json
import os
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List
from django.db.models import QuerySet
@ -18,7 +19,8 @@ from common.config.embedding_config import EmbeddingModel
from common.db.search import generate_sql_by_query_dict
from common.db.sql_execute import select_list
from common.util.file_util import get_file_content
from embedding.models import Embedding, SourceType
from common.util.ts_vecto_util import to_ts_vector, to_query
from embedding.models import Embedding, SourceType, SearchMode
from embedding.vector.base_vector import BaseVectorStore
from smartdoc.conf import PROJECT_DIR
@ -57,7 +59,8 @@ class PGVector(BaseVectorStore):
paragraph_id=paragraph_id,
source_id=source_id,
embedding=text_embedding,
source_type=source_type)
source_type=source_type,
search_vector=to_ts_vector(text))
embedding.save()
return True
@ -71,13 +74,15 @@ class PGVector(BaseVectorStore):
is_active=text_list[index].get('is_active', True),
source_id=text_list[index].get('source_id'),
source_type=text_list[index].get('source_type'),
embedding=embeddings[index]) for index in
embedding=embeddings[index],
search_vector=to_ts_vector(text_list[index]['text'])) for index in
range(0, len(text_list))]
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float,
search_mode: SearchMode,
embedding: HuggingFaceEmbeddings):
if dataset_id_list is None or len(dataset_id_list) == 0:
return []
@ -87,17 +92,14 @@ class PGVector(BaseVectorStore):
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
query_set = query_set.exclude(**exclude_dict)
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
'hit_test.sql')),
with_table_name=True)
embedding_model = select_list(exec_sql,
[json.dumps(embedding_query), *exec_params, similarity, top_number])
return embedding_model
for search_handle in search_handle_list:
if search_handle.support(search_mode):
return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode)
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str],
exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
search_mode: SearchMode):
exclude_dict = {}
if dataset_id_list is None or len(dataset_id_list) == 0:
return []
@ -107,14 +109,9 @@ class PGVector(BaseVectorStore):
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
query_set = query_set.exclude(**exclude_dict)
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
'embedding_search.sql')),
with_table_name=True)
embedding_model = select_list(exec_sql,
[json.dumps(query_embedding), *exec_params, similarity, top_n])
return embedding_model
for search_handle in search_handle_list:
if search_handle.support(search_mode):
return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode)
def update_by_source_id(self, source_id: str, instance: Dict):
QuerySet(Embedding).filter(source_id=source_id).update(**instance)
@ -141,3 +138,81 @@ class PGVector(BaseVectorStore):
def delete_by_paragraph_id(self, paragraph_id: str):
QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
class ISearch(ABC):
    """Strategy interface for the retrieval modes (embedding/keywords/blend)."""

    @abstractmethod
    def support(self, search_mode: SearchMode):
        # Return True when this strategy implements *search_mode*.
        pass

    @abstractmethod
    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        # Execute the search over the filtered queryset and return result rows
        # (paragraph_id / similarity / comprehensive_score — see apps/embedding/sql).
        pass
class EmbeddingSearch(ISearch):
    """Vector-similarity retrieval backed by embedding_search.sql."""

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.embedding.value

    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        sql_path = os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'embedding_search.sql')
        exec_sql, exec_params = generate_sql_by_query_dict(
            {'embedding_query': query_set},
            select_string=get_file_content(sql_path),
            with_table_name=True)
        # parameter order must match the %s placeholders in embedding_search.sql
        return select_list(exec_sql,
                           [json.dumps(query_embedding), *exec_params, similarity, top_number])
class KeywordsSearch(ISearch):
    """Full-text retrieval via websearch_to_tsquery, backed by keywords_search.sql."""

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.keywords.value

    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        sql_path = os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'keywords_search.sql')
        exec_sql, exec_params = generate_sql_by_query_dict(
            {'keywords_query': query_set},
            select_string=get_file_content(sql_path),
            with_table_name=True)
        # to_query turns the raw question into a tsquery-compatible keyword string
        return select_list(exec_sql,
                           [to_query(query_text), *exec_params, similarity, top_number])
class BlendSearch(ISearch):
    """Hybrid retrieval (vector distance + full-text rank), backed by blend_search.sql."""

    def support(self, search_mode: SearchMode):
        return search_mode.value == SearchMode.blend.value

    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        sql_path = os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'blend_search.sql')
        exec_sql, exec_params = generate_sql_by_query_dict(
            {'embedding_query': query_set},
            select_string=get_file_content(sql_path),
            with_table_name=True)
        # blend_search.sql consumes both the embedding and the keyword query text
        return select_list(exec_sql,
                           [json.dumps(query_embedding), to_query(query_text), *exec_params,
                            similarity, top_number])
# Ordered registry of search strategies; PGVector.query/hit_test pick the
# first handler whose support() matches the requested SearchMode.
search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]

View File

@ -145,74 +145,12 @@
<div class="flex-between">
<span>关联知识库</span>
<div>
<el-popover :visible="popoverVisible" :width="214" trigger="click">
<template #reference>
<el-button type="primary" link @click="datasetSettingChange('open')">
<AppIcon iconName="app-operation" class="mr-4"></AppIcon>参数设置
</el-button>
</template>
<div class="dataset_setting">
<div class="form-item mb-16">
<div class="title flex align-center mb-8">
<span style="margin-right: 4px">相似度高于</span>
<el-tooltip
effect="dark"
content="相似度越高相关性越强。"
placement="right"
>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
<div @click.stop>
<el-input-number
v-model="dataset_setting.similarity"
:min="0"
:max="1"
:precision="3"
:step="0.1"
controls-position="right"
style="width: 180px"
/>
</div>
</div>
<div class="form-item mb-16">
<div class="title mb-8">引用分段数 TOP</div>
<div @click.stop>
<el-input-number
v-model="dataset_setting.top_n"
:min="1"
:max="10"
controls-position="right"
style="width: 180px"
/>
</div>
</div>
<div class="form-item mb-16">
<div class="title mb-8">最多引用字符数</div>
<div class="flex align-center">
<el-slider
v-model="dataset_setting.max_paragraph_char_number"
show-input
:show-input-controls="false"
:min="500"
:max="10000"
style="width: 180px"
class="custom-slider"
/>
</div>
</div>
</div>
<div class="text-right">
<el-button @click="popoverVisible = false">取消</el-button>
<el-button type="primary" @click="datasetSettingChange('close')"
>确认</el-button
>
</div>
</el-popover>
<el-button type="primary" link @click="openDatasetDialog"
><el-icon class="mr-4"><Plus /></el-icon></el-button
>
<el-button type="primary" link @click="openParamSettingDialog">
<AppIcon iconName="app-operation" class="mr-4"></AppIcon>参数设置
</el-button>
<el-button type="primary" link @click="openDatasetDialog">
<el-icon class="mr-4"><Plus /></el-icon>
</el-button>
</div>
</div>
</template>
@ -221,13 +159,6 @@
>关联的知识库展示在这里</el-text
>
<el-row :gutter="12" v-else>
<!-- <el-col :xs="24" :sm="24" :md="12" :lg="12" :xl="12" class="mb-8">
<CardAdd
title="关联知识库"
@click="openDatasetDialog"
style="min-height: 50px; font-size: 14px"
/>
</el-col> -->
<el-col
:xs="24"
:sm="24"
@ -311,6 +242,7 @@
</el-col>
</el-row>
<ParamSettingDialog ref="ParamSettingDialogRef" @refresh="refreshParam" />
<AddDatasetDialog
ref="AddDatasetDialogRef"
@addData="addDataset"
@ -318,6 +250,8 @@
@refresh="refresh"
:loading="datasetLoading"
/>
<!-- 添加模版 -->
<CreateModelDialog
ref="createModelRef"
@submit="getModel"
@ -329,7 +263,8 @@
<script setup lang="ts">
import { reactive, ref, watch, onMounted } from 'vue'
import { useRouter, useRoute } from 'vue-router'
import { groupBy, cloneDeep } from 'lodash'
import { groupBy } from 'lodash'
import ParamSettingDialog from './components/ParamSettingDialog.vue'
import AddDatasetDialog from './components/AddDatasetDialog.vue'
import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue'
import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue'
@ -363,6 +298,8 @@ const defaultPrompt = `已知信息:
问题
{question}
`
const ParamSettingDialogRef = ref<InstanceType<typeof ParamSettingDialog>>()
const createModelRef = ref<InstanceType<typeof CreateModelDialog>>()
const selectProviderRef = ref<InstanceType<typeof SelectProviderDialog>>()
@ -384,7 +321,8 @@ const applicationForm = ref<ApplicationFormType>({
dataset_setting: {
top_n: 3,
similarity: 0.6,
max_paragraph_char_number: 5000
max_paragraph_char_number: 5000,
search_mode: 'embedding'
},
model_setting: {
prompt: defaultPrompt
@ -392,8 +330,6 @@ const applicationForm = ref<ApplicationFormType>({
problem_optimization: false
})
const popoverVisible = ref(false)
const rules = reactive<FormRules<ApplicationFormType>>({
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
model_id: [
@ -408,17 +344,6 @@ const rules = reactive<FormRules<ApplicationFormType>>({
const modelOptions = ref<any>(null)
const providerOptions = ref<Array<Provider>>([])
const datasetList = ref([])
const dataset_setting = ref<any>({})
function datasetSettingChange(val: string) {
if (val === 'open') {
popoverVisible.value = true
dataset_setting.value = cloneDeep(applicationForm.value.dataset_setting)
} else if (val === 'close') {
popoverVisible.value = false
applicationForm.value.dataset_setting = cloneDeep(dataset_setting.value)
}
}
const submit = async (formEl: FormInstance | undefined) => {
if (!formEl) return
@ -438,6 +363,14 @@ const submit = async (formEl: FormInstance | undefined) => {
})
}
const openParamSettingDialog = () => {
ParamSettingDialogRef.value?.open(applicationForm.value.dataset_setting)
}
function refreshParam(data: any) {
applicationForm.value.dataset_setting = data
}
const openCreateModel = (provider?: Provider) => {
if (provider && provider.provider) {
createModelRef.value?.open(provider)
@ -560,13 +493,4 @@ onMounted(() => {
.prologue-md-editor {
height: 150px;
}
.dataset_setting {
color: var(--el-text-color-regular);
font-weight: 400;
}
.custom-slider {
:deep(.el-input-number.is-without-controls .el-input__wrapper) {
padding: 0 !important;
}
}
</style>

View File

@ -0,0 +1,152 @@
<template>
<el-dialog title="参数设置" class="param-dialog" v-model="dialogVisible" style="width: 550px">
<div class="dialog-max-height">
<el-scrollbar>
<div class="p-16">
<el-form label-position="top" ref="paramFormRef" :model="form">
<el-form-item label="检索模式">
<el-radio-group v-model="form.search_mode" class="card__radio">
<el-card
shadow="never"
class="mb-16"
:class="form.search_mode === 'embedding' ? 'active' : ''"
>
<el-radio value="embedding" size="large">
<p class="mb-4">向量检索</p>
<el-text type="info">通过向量距离计算与用户问题最相似的文本分段</el-text>
</el-radio>
</el-card>
<el-card
shadow="never"
class="mb-16"
:class="form.search_mode === 'keywords' ? 'active' : ''"
>
<el-radio value="keywords" size="large">
<p class="mb-4">全文检索</p>
<el-text type="info">通过关键词检索返回包含关键词最多的文本分段</el-text>
</el-radio>
</el-card>
<el-card
shadow="never"
class="mb-16"
:class="form.search_mode === 'blend' ? 'active' : ''"
>
<el-radio value="blend" size="large">
<p class="mb-4">混合检索</p>
<el-text type="info"
>同时执行全文检索和向量检索再进行重排序从两类查询结果中选择匹配用户问题的最佳结果</el-text
>
</el-radio>
</el-card>
</el-radio-group>
</el-form-item>
<el-form-item>
<template #label>
<div class="flex align-center">
<span class="mr-4">相似度高于</span>
<el-tooltip effect="dark" content="相似度越高相关性越强。" placement="right">
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-input-number
v-model="form.similarity"
:min="0"
:max="1"
:precision="3"
:step="0.1"
controls-position="right"
/>
</el-form-item>
<el-form-item label="引用分段数 TOP">
<el-input-number v-model="form.top_n" :min="1" :max="10" controls-position="right" />
</el-form-item>
<el-form-item label="最多引用字符数">
<el-slider
v-model="form.max_paragraph_char_number"
show-input
:show-input-controls="false"
:min="500"
:max="10000"
class="custom-slider"
/>
</el-form-item>
</el-form>
</div>
</el-scrollbar>
</div>
<template #footer>
<span class="dialog-footer">
<el-button @click.prevent="dialogVisible = false"> 取消 </el-button>
<el-button type="primary" @click="submit(paramFormRef)" :loading="loading">
保存
</el-button>
</span>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref, watch } from 'vue'
import type { FormInstance, FormRules } from 'element-plus'
const emit = defineEmits(['refresh'])
const paramFormRef = ref()
const form = ref<any>({
search_mode: 'embedding',
top_n: 3,
similarity: 0.6,
max_paragraph_char_number: 5000
})
const dialogVisible = ref<boolean>(false)
const loading = ref(false)
watch(dialogVisible, (bool) => {
if (!bool) {
form.value = {
search_mode: 'embedding',
top_n: 3,
similarity: 0.6,
max_paragraph_char_number: 5000
}
}
})
const open = (data: any) => {
form.value = { ...form.value, ...data }
dialogVisible.value = true
}
const submit = async (formEl: FormInstance | undefined) => {
if (!formEl) return
await formEl.validate((valid, fields) => {
if (valid) {
emit('refresh', form.value)
dialogVisible.value = false
}
})
}
defineExpose({ open })
</script>
<style lang="scss" scope>
.param-dialog {
padding: 8px;
.el-dialog__header {
padding: 16px 16px 0 16px;
}
.el-dialog__body {
padding: 0 !important;
}
.dialog-max-height {
height: calc(100vh - 260px);
}
.custom-slider {
.el-input-number.is-without-controls .el-input__wrapper {
padding: 0 !important;
}
}
}
</style>

View File

@ -78,11 +78,47 @@
<ParagraphDialog ref="ParagraphDialogRef" :title="title" @refresh="refresh" />
</LayoutContainer>
<div class="hit-test__operate p-24 pt-0">
<el-popover :visible="popoverVisible" placement="right-end" :width="180" trigger="click">
<el-popover :visible="popoverVisible" placement="right-end" :width="500" trigger="click">
<template #reference>
<el-button icon="Setting" class="mb-8" @click="settingChange('open')">参数设置</el-button>
</template>
<div class="mb-16">
<div class="title mb-8">检索模式</div>
<el-radio-group v-model="cloneForm.search_mode" class="card__radio">
<el-card
shadow="never"
class="mb-16"
:class="cloneForm.search_mode === 'embedding' ? 'active' : ''"
>
<el-radio value="embedding" size="large">
<p class="mb-4">向量检索</p>
<el-text type="info">通过向量距离计算与用户问题最相似的文本分段</el-text>
</el-radio>
</el-card>
<el-card
shadow="never"
class="mb-16"
:class="cloneForm.search_mode === 'keywords' ? 'active' : ''"
>
<el-radio value="keywords" size="large">
<p class="mb-4">全文检索</p>
<el-text type="info">通过关键词检索返回包含关键词最多的文本分段</el-text>
</el-radio>
</el-card>
<el-card
shadow="never"
class="mb-16"
:class="cloneForm.search_mode === 'blend' ? 'active' : ''"
>
<el-radio value="blend" size="large">
<p class="mb-4">混合检索</p>
<el-text type="info"
>同时执行全文检索和向量检索再进行重排序从两类查询结果中选择匹配用户问题的最佳结果</el-text
>
</el-radio>
</el-card>
</el-radio-group>
</div>
<div class="mb-16">
<div class="title mb-8">相似度高于</div>
<el-input-number
@ -92,7 +128,6 @@
:precision="3"
:step="0.1"
controls-position="right"
style="width: 145px"
/>
</div>
<div class="mb-16">
@ -103,7 +138,6 @@
:min="1"
:max="10"
controls-position="right"
style="width: 145px"
/>
</div>
<div class="text-right">
@ -161,7 +195,8 @@ const title = ref('')
const inputValue = ref('')
const formInline = ref({
similarity: 0.6,
top_number: 5
top_number: 5,
search_mode: 'embedding'
})
//
@ -213,8 +248,7 @@ function sendChatHandle(event: any) {
function getHitTestList() {
const obj = {
query_text: inputValue.value,
similarity: formInline.value.similarity,
top_number: formInline.value.top_number
...formInline.value
}
if (isDataset.value) {
datasetApi.getDatasetHitTest(id, obj, loading).then((res) => {