From c89ae294295f130acb143bf52173bf8175750cab Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Mon, 22 Apr 2024 11:21:24 +0800
Subject: [PATCH] feat: add full-text search and blend (hybrid) search modes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../i_search_dataset_step.py                  |   9 +-
 .../impl/base_search_dataset_step.py          |   6 +-
 apps/application/models/application.py        |   2 +-
 .../serializers/application_serializers.py    |  13 +-
 .../serializers/chat_message_serializers.py   |   8 +-
 .../swagger_api/application_api.py            |   2 +
 apps/application/views/application_views.py   |   3 +-
 apps/common/swagger_api/common_api.py         |   7 +
 apps/common/util/ts_vecto_util.py             | 107 ++++++++++++
 .../serializers/dataset_serializers.py        |   6 +
 apps/dataset/views/dataset.py                 |   3 +-
 .../0002_embedding_search_vector.py           |  54 +++++++
 apps/embedding/models/embedding.py            |   9 ++
 apps/embedding/sql/blend_search.sql           |  26 +++
 apps/embedding/sql/keywords_search.sql        |  17 ++
 apps/embedding/vector/base_vector.py          |   9 +-
 apps/embedding/vector/pg_vector.py            | 117 +++++++++++---
 ui/src/views/application/CreateAndSetting.vue | 122 +++-----------
 .../components/ParamSettingDialog.vue         | 152 ++++++++++++++++++
 ui/src/views/hit-test/index.vue               |  48 +++++-
 20 files changed, 580 insertions(+), 140 deletions(-)
 create mode 100644 apps/common/util/ts_vecto_util.py
 create mode 100644 apps/embedding/migrations/0002_embedding_search_vector.py
 create mode 100644 apps/embedding/sql/blend_search.sql
 create mode 100644 apps/embedding/sql/keywords_search.sql
 create mode 100644 ui/src/views/application/components/ParamSettingDialog.vue

diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
index f4e9296af..549abfaf2 100644
--- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
+++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
@@ -6,15 +6,16 @@
     @date:2024/1/9 18:10
     @desc: retrieve from the knowledge base
 """
+import re
 from abc import abstractmethod
 from typing import List, Type

+from django.core import validators
 from rest_framework import serializers

 from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from common.util.field_message import ErrMessage
-from dataset.models import Paragraph


 class ISearchDatasetStep(IBaseChatPipelineStep):
@@ -38,6 +39,10 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
         # similarity, between 0 and 1
         similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
                                             error_messages=ErrMessage.float("相似度"))
+        search_mode = serializers.CharField(required=True, validators=[
+            validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
+                                      message="检索模式只支持embedding|keywords|blend", code=500)
+        ], error_messages=ErrMessage.char("检索模式"))

     def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
         return self.InstanceSerializer
@@ -50,6 +55,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
     @abstractmethod
     def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                 exclude_paragraph_id_list: list[str], top_n: int, similarity: float,
                 padding_problem_text: str = None,
+                search_mode: str = None,
                 **kwargs) -> List[ParagraphPipelineModel]:
         """
         Note on the user question vs. the padded question: if a padded (rewritten)
         question exists it is used for retrieval, otherwise the original user
         question is used.
@@ -60,6 +66,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
         :param exclude_document_id_list: document ids to exclude
         :param exclude_paragraph_id_list: paragraph ids to exclude
         :param padding_problem_text: padded (rewritten) question
+        :param search_mode: search mode
         :return: list of paragraphs
         """
         pass
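The search_mode validator above depends on the alternation being grouped before it is anchored: `|` binds looser than `^`/`$`, so an ungrouped `^embedding|keywords|blend$` would also accept strings such as `embedding_v2`. A minimal standalone check of the grouped pattern (plain `re`, no Django required):

    import re

    # Grouped and anchored: matches exactly one of the three supported modes.
    pattern = re.compile("^(embedding|keywords|blend)$")

    assert pattern.match("blend")
    assert pattern.match("keywords")
    assert not pattern.match("embedding_v2")  # trailing text is rejected
    assert not pattern.match("my_blend")      # leading text is rejected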
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
index 1781d4f3d..dcd375ce4 100644
--- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
+++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
@@ -17,6 +17,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
 from common.db.search import native_search
 from common.util.file_util import get_file_content
 from dataset.models import Paragraph
+from embedding.models import SearchMode
 from smartdoc.conf import PROJECT_DIR


@@ -24,13 +25,14 @@ class BaseSearchDatasetStep(ISearchDatasetStep):

     def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                 exclude_paragraph_id_list: list[str], top_n: int, similarity: float,
                 padding_problem_text: str = None,
+                search_mode: str = None,
                 **kwargs) -> List[ParagraphPipelineModel]:
         exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
         embedding_model = EmbeddingModel.get_embedding_model()
         embedding_value = embedding_model.embed_query(exec_problem_text)
         vector = VectorStore.get_embedding_vector()
-        embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list,
-                                      exclude_paragraph_id_list, True, top_n, similarity)
+        embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
+                                      exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
         if embedding_list is None:
             return []
         paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)

diff --git a/apps/application/models/application.py b/apps/application/models/application.py
index e20e92385..e7c99e404 100644
--- a/apps/application/models/application.py
+++ b/apps/application/models/application.py
@@ -19,7 +19,7 @@ from users.models import User


 def get_dataset_setting_dict():
-    return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000}
+    return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding'}


 def get_model_setting_dict():
diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py
index 3b20a1b39..43d706d59 100644
--- a/apps/application/serializers/application_serializers.py
+++ b/apps/application/serializers/application_serializers.py
@@ -8,12 +8,13 @@
 """
 import hashlib
 import os
+import re
 import uuid
 from functools import reduce
 from typing import Dict

 from django.contrib.postgres.fields import ArrayField
-from django.core import cache
+from django.core import cache, validators
 from django.core import signing
 from django.db import transaction, models
 from django.db.models import QuerySet
@@ -32,6 +33,7 @@ from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from dataset.models import DataSet, Document
 from dataset.serializers.common_serializers import list_paragraph
+from embedding.models import SearchMode
 from setting.models import AuthOperate
 from setting.models.model_management import Model
 from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@@ -77,6 +79,10 @@ class DatasetSettingSerializer(serializers.Serializer):
                                         error_messages=ErrMessage.float("相似度"))
     max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=10000,
                                                          error_messages=ErrMessage.integer("最多引用字符数"))
+    search_mode = serializers.CharField(required=True, validators=[
+        validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
+                                  message="检索模式只支持embedding|keywords|blend", code=500)
+    ], error_messages=ErrMessage.char("检索模式"))


 class ModelSettingSerializer(serializers.Serializer):
@@ -291,6 +297,10 @@ class ApplicationSerializer(serializers.Serializer):
                                               error_messages=ErrMessage.integer("topN"))
         similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
                                             error_messages=ErrMessage.float("相关度"))
+        search_mode = serializers.CharField(required=True, validators=[
+            validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
+                                      message="检索模式只支持embedding|keywords|blend", code=500)
+        ], error_messages=ErrMessage.char("检索模式"))

         def is_valid(self, *, raise_exception=False):
             super().is_valid(raise_exception=True)
@@ -312,6 +322,7 @@ class ApplicationSerializer(serializers.Serializer):
             hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
                                        self.data.get('top_number'),
                                        self.data.get('similarity'),
+                                       SearchMode(self.data.get('search_mode')),
                                        EmbeddingModel.get_embedding_model())
             hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
             p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])

diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index dbbffb2a2..daf850e76 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -77,6 +77,8 @@ class ChatInfo:
             'model_id': self.application.model.id if self.application.model is not None else None,
             'problem_optimization': self.application.problem_optimization,
             'stream': True,
+            'search_mode': self.application.dataset_setting.get(
+                'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding'
         }


@@ -184,9 +186,9 @@ class ChatMessageSerializer(serializers.Serializer):
             pipeline_manage_builder.append_step(BaseResetProblemStep)
         # build the pipeline manager
         pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
-                            .append_step(BaseGenerateHumanMessageStep)
-                            .append_step(BaseChatStep)
-                            .build())
+                           .append_step(BaseGenerateHumanMessageStep)
+                           .append_step(BaseChatStep)
+                           .build())
         exclude_paragraph_id_list = []
         # whether paragraphs already retrieved for the same question should be excluded
         if re_chat:

diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py
index 38d747bb1..7af1c437a 100644
--- a/apps/application/swagger_api/application_api.py
+++ b/apps/application/swagger_api/application_api.py
@@ -161,6 +161,8 @@ class ApplicationApi(ApiMixin):
                                              default=0.6),
                 'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
                                                             description="最多引用字符数", default=3000),
+                'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title='检索模式',
+                                              description="embedding|keywords|blend", default='embedding'),
             }
         )
diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py
index 12e47f854..7953220b4 100644
--- a/apps/application/views/application_views.py
+++ b/apps/application/views/application_views.py
@@ -343,7 +343,8 @@ class Application(APIView):
                 ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id,
                                                     "query_text": request.query_params.get("query_text"),
                                                     "top_number": request.query_params.get("top_number"),
-                                                    'similarity': request.query_params.get('similarity')}).hit_test(
+                                                    'similarity': request.query_params.get('similarity'),
+                                                    'search_mode': request.query_params.get('search_mode')}).hit_test(
                 ))

     class Operate(APIView):

diff --git a/apps/common/swagger_api/common_api.py b/apps/common/swagger_api/common_api.py
index 71a876ae0..c3d8be6ca 100644
--- a/apps/common/swagger_api/common_api.py
+++ b/apps/common/swagger_api/common_api.py
@@ -33,6 +33,13 @@ class CommonApi:
                                            default=0.6,
                                            required=True,
                                            description='相关性'),
+                         openapi.Parameter(name='search_mode',
+                                           in_=openapi.IN_QUERY,
+                                           type=openapi.TYPE_STRING,
+                                           default="embedding",
+                                           required=True,
+                                           description='检索模式embedding|keywords|blend'
+                                           )
                          ]

     @staticmethod

diff --git a/apps/common/util/ts_vecto_util.py b/apps/common/util/ts_vecto_util.py
new file mode 100644
index 000000000..b5d4de3fd
--- /dev/null
+++ b/apps/common/util/ts_vecto_util.py
@@ -0,0 +1,107 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: ts_vecto_util.py
+    @date:2024/4/16 15:26
+    @desc:
+"""
+import re
+import uuid
+from typing import List
+
+import jieba
+from jieba import analyse
+
+from common.util.split_model import group_by
+
+jieba_word_list_cache = [chr(item) for item in range(38, 84)]
+
+for jieba_word in jieba_word_list_cache:
+    jieba.add_word('#' + jieba_word + '#')
+# r"(?i)\b(?:https?|ftp|tcp|file)://[^\s]+\b",
+# patterns whose matches must not be segmented
+# r'"([^"]*)"'
+word_pattern_list = [r"v\d+.\d+.\d+",
+                     r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"]
+
+remove_chars = '\n , :\'<>!@#¥%……&*()!@#$%^&*(): ;,/"./-'
+
+
+def get_word_list(text: str):
+    result = []
+    for pattern in word_pattern_list:
+        word_list = re.findall(pattern, text)
+        for child_list in word_list:
+            for word in child_list if isinstance(child_list, tuple) else [child_list]:
+                # ':' cannot appear in a lexeme, so split on it
+                if word.__contains__(':'):
+                    item_list = word.split(":")
+                    for w in item_list:
+                        result.append(w)
+                else:
+                    result.append(word)
+    return result
+
+
+def to_word_dict(word_list: List[str], text: str):
+    # map each protected word to a placeholder token such as '#&#'
+    word_dict = {}
+    for index, word in enumerate(word_list):
+        word_dict['#' + jieba_word_list_cache[index % len(jieba_word_list_cache)] + '#'] = word
+    return word_dict
+
+
+def replace_word(word_dict, text: str):
+    # swap the protected words for their placeholder tokens before segmentation
+    for key in word_dict:
+        text = re.sub(re.escape(word_dict[key]), key, text)
+    return text
+
+
+def get_key_by_word_dict(key, word_dict):
+    # map a placeholder token back to its original word; other tokens pass through
+    word = word_dict.get(key)
+    return word if word is not None else key
+
+
+def to_ts_vector(text: str):
+    # words that must not be segmented
+    word_list = get_word_list(text)
+    # placeholder-token mapping
+    word_dict = to_word_dict(word_list, text)
+    # substitute placeholders into the text
+    text = replace_word(word_dict, text)
+    filter_word_list = [word for word in jieba.lcut(text) if not remove_chars.__contains__(word)]
+    result = " ".join([f"{get_key_by_word_dict(word, word_dict).lower()}:{index + 1}"
+                       for index, word in enumerate(filter_word_list) if len(word.strip()) >= 0])
+    # drop the temporary dictionary entries
+    for word in word_list:
+        jieba.del_word(word)
+    return result
+
+
+def to_query(text: str):
+    # words that must not be segmented
+    word_list = get_word_list(text)
+    # placeholder-token mapping
+    word_dict = to_word_dict(word_list, text)
+    # substitute placeholders into the text
+    text = replace_word(word_dict, text)
+    extract_tags = analyse.extract_tags(text, topK=5, withWeight=True, allowPOS=('ns', 'n', 'vn', 'v', 'eng'))
+    result = " ".join([get_key_by_word_dict(word, word_dict) for word, score in extract_tags if
+                       not remove_chars.__contains__(word)])
+    # drop the temporary dictionary entries
+    for word in word_list:
+        jieba.del_word(word)
+    return result
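A hedged sketch of how these two helpers pair up at the call sites in pg_vector.py: to_ts_vector produces the value stored in the search_vector column, while to_query builds the websearch_to_tsquery input (the exact tokens depend on the jieba dictionary in use, so any concrete output is illustrative, not guaranteed):

    from common.util.ts_vecto_util import to_ts_vector, to_query

    # Index side: computed once per paragraph and stored in Embedding.search_vector.
    vector_text = to_ts_vector("MaxKB v1.1.0 支持全文检索")

    # Query side: the top extracted keywords joined by spaces, passed as the
    # websearch_to_tsquery('simple', %s) parameter in keywords/blend search.
    query_text = to_query("如何开启全文检索")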
diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py
index 6f03eca2a..5d03def4a 100644
--- a/apps/dataset/serializers/dataset_serializers.py
+++ b/apps/dataset/serializers/dataset_serializers.py
@@ -37,6 +37,7 @@ from common.util.split_model import get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
 from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
+from embedding.models import SearchMode
 from setting.models import AuthOperate
 from smartdoc.conf import PROJECT_DIR

@@ -457,6 +458,10 @@ class DataSetSerializers(serializers.ModelSerializer):
                                               error_messages=ErrMessage.char("响应Top"))
         similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
                                             error_messages=ErrMessage.char("相似度"))
+        search_mode = serializers.CharField(required=True, validators=[
+            validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
+                                      message="检索模式只支持embedding|keywords|blend", code=500)
+        ], error_messages=ErrMessage.char("检索模式"))

         def is_valid(self, *, raise_exception=True):
             super().is_valid(raise_exception=True)
@@ -474,6 +479,7 @@ class DataSetSerializers(serializers.ModelSerializer):
             hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
                                        self.data.get('top_number'),
                                        self.data.get('similarity'),
+                                       SearchMode(self.data.get('search_mode')),
                                        EmbeddingModel.get_embedding_model())
             hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
             p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])

diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py
index 96106ff86..d3720977b 100644
--- a/apps/dataset/views/dataset.py
+++ b/apps/dataset/views/dataset.py
@@ -111,7 +111,8 @@ class Dataset(APIView):
             DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id,
                                              "query_text": request.query_params.get("query_text"),
                                              "top_number": request.query_params.get("top_number"),
-                                             'similarity': request.query_params.get('similarity')}).hit_test(
+                                             'similarity': request.query_params.get('similarity'),
+                                             'search_mode': request.query_params.get('search_mode')}).hit_test(
             ))

     class Operate(APIView):

diff --git a/apps/embedding/migrations/0002_embedding_search_vector.py b/apps/embedding/migrations/0002_embedding_search_vector.py
new file mode 100644
index 000000000..e0cbe733b
--- /dev/null
+++ b/apps/embedding/migrations/0002_embedding_search_vector.py
@@ -0,0 +1,54 @@
+# Generated by Django 4.1.13 on 2024-04-16 11:43
+
+import django.contrib.postgres.search
+from django.db import migrations
+
+from common.util.common import sub_array
+from common.util.ts_vecto_util import to_ts_vector
+from dataset.models import Status
+from embedding.models import Embedding
+
+
+def update_embedding_search_vector(embedding, paragraph_list):
+    paragraphs = [paragraph for paragraph in paragraph_list if paragraph.id == embedding.get('paragraph')]
+    if len(paragraphs) > 0:
+        content = paragraphs[0].title + paragraphs[0].content
+        return Embedding(id=embedding.get('id'), search_vector=to_ts_vector(content))
+    return Embedding(id=embedding.get('id'), search_vector="")
+
+
+def save_keywords(apps, schema_editor):
+    document = apps.get_model("dataset", "Document")
+    embedding = apps.get_model("embedding", "Embedding")
+    paragraph = apps.get_model('dataset', 'Paragraph')
+    db_alias = schema_editor.connection.alias
+    document_list = document.objects.using(db_alias).all()
+    for document in document_list:
+        document.status = Status.embedding
+        document.save()
+        paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all()
+        embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector',
+                                                                                            'paragraph')
+        embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding
+                                 in embedding_list]
+        child_array = sub_array(embedding_update_list, 20)
+        for c in child_array:
+            try:
+                embedding.objects.using(db_alias).bulk_update(c, ['search_vector'])
+            except Exception as e:
+                print(e)
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ('embedding', '0001_initial'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='embedding',
+            name='search_vector',
+            field=django.contrib.postgres.search.SearchVectorField(default='', verbose_name='分词'),
+        ),
+        migrations.RunPython(save_keywords)
+    ]
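save_keywords batches its writes: sub_array splits the pending Embedding rows into chunks of 20 so each bulk_update statement stays small. A sketch of the chunking contract this assumes (the real helper lives in common.util.common; this stand-in only illustrates the expected behaviour):

    def sub_array(array, item_num=10):
        # [1, 2, 3, 4, 5] with item_num=2 -> [[1, 2], [3, 4], [5]]
        return [array[i:i + item_num] for i in range(0, len(array), item_num)]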
diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py
index f7cc6bf31..24c78f41f 100644
--- a/apps/embedding/models/embedding.py
+++ b/apps/embedding/models/embedding.py
@@ -10,6 +10,7 @@ from django.db import models

 from common.field.vector_field import VectorField
 from dataset.models.data_set import Document, Paragraph, DataSet
+from django.contrib.postgres.search import SearchVectorField


 class SourceType(models.TextChoices):
@@ -19,6 +20,12 @@ class SourceType(models.TextChoices):
     TITLE = 2, '标题'


+class SearchMode(models.TextChoices):
+    embedding = 'embedding'
+    keywords = 'keywords'
+    blend = 'blend'
+
+
 class Embedding(models.Model):
     id = models.CharField(max_length=128, primary_key=True, verbose_name="主键id")
@@ -37,6 +44,8 @@ class Embedding(models.Model):

     embedding = VectorField(verbose_name="向量")

+    search_vector = SearchVectorField(verbose_name="分词", default="")
+
     meta = models.JSONField(verbose_name="元数据", default=dict)

     class Meta:

diff --git a/apps/embedding/sql/blend_search.sql b/apps/embedding/sql/blend_search.sql
new file mode 100644
index 000000000..afb1f0040
--- /dev/null
+++ b/apps/embedding/sql/blend_search.sql
@@ -0,0 +1,26 @@
+SELECT
+    paragraph_id,
+    comprehensive_score,
+    comprehensive_score AS similarity
+FROM
+    (
+        SELECT DISTINCT ON
+            ( "paragraph_id" ) ( similarity ), *,
+            similarity AS comprehensive_score
+        FROM
+            (
+                SELECT
+                    *,
+                    (( 1 - ( embedding.embedding <=> %s ) ) + ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 )) AS similarity
+                FROM
+                    embedding ${embedding_query}
+            ) TEMP
+        ORDER BY
+            paragraph_id,
+            similarity DESC
+    ) DISTINCT_TEMP
+WHERE
+    comprehensive_score > %s
+ORDER BY
+    comprehensive_score DESC
+    LIMIT %s
\ No newline at end of file

diff --git a/apps/embedding/sql/keywords_search.sql b/apps/embedding/sql/keywords_search.sql
new file mode 100644
index 000000000..a27d0a694
--- /dev/null
+++ b/apps/embedding/sql/keywords_search.sql
@@ -0,0 +1,17 @@
+SELECT
+    paragraph_id,
+    comprehensive_score,
+    comprehensive_score AS similarity
+FROM
+    (
+        SELECT DISTINCT ON
+            ( "paragraph_id" ) ( similarity ), *, similarity AS comprehensive_score
+        FROM
+            ( SELECT *, ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 ) AS similarity FROM embedding ${keywords_query} ) TEMP
+        ORDER BY
+            paragraph_id,
+            similarity DESC
+    ) DISTINCT_TEMP
+WHERE comprehensive_score > %s
+ORDER BY comprehensive_score DESC
+LIMIT %s
\ No newline at end of file
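Both templates are filled positionally, so the Python side must pass parameters in the order the %s placeholders appear. For blend_search.sql that order is: the query embedding (for the `<=>` distance), the keyword string (for websearch_to_tsquery), whatever parameters ${embedding_query} expands to, then the similarity threshold and the LIMIT — which is exactly the ordering BlendSearch uses in pg_vector.py below:

    select_list(exec_sql,
                [json.dumps(query_embedding), to_query(query_text),
                 *exec_params, similarity, top_number])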
diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py
index c031e1143..496150dda 100644
--- a/apps/embedding/vector/base_vector.py
+++ b/apps/embedding/vector/base_vector.py
@@ -14,7 +14,7 @@ from langchain_community.embeddings import HuggingFaceEmbeddings

 from common.config.embedding_config import EmbeddingModel
 from common.util.common import sub_array
-from embedding.models import SourceType
+from embedding.models import SourceType, SearchMode

 lock = threading.Lock()
@@ -113,13 +113,16 @@ class BaseVectorStore(ABC):
         return result[0]

     @abstractmethod
-    def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
-              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
+    def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str],
+              exclude_document_id_list: list[str],
+              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
+              search_mode: SearchMode):
         pass

     @abstractmethod
     def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
                  similarity: float,
+                 search_mode: SearchMode,
                  embedding: HuggingFaceEmbeddings):
         pass
diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py
index 6f9d25b2a..d5e5d125e 100644
--- a/apps/embedding/vector/pg_vector.py
+++ b/apps/embedding/vector/pg_vector.py
@@ -9,6 +9,7 @@
 import json
 import os
 import uuid
+from abc import ABC, abstractmethod
 from typing import Dict, List

 from django.db.models import QuerySet
@@ -18,7 +19,8 @@ from common.config.embedding_config import EmbeddingModel
 from common.db.search import generate_sql_by_query_dict
 from common.db.sql_execute import select_list
 from common.util.file_util import get_file_content
-from embedding.models import Embedding, SourceType
+from common.util.ts_vecto_util import to_ts_vector, to_query
+from embedding.models import Embedding, SourceType, SearchMode
 from embedding.vector.base_vector import BaseVectorStore
 from smartdoc.conf import PROJECT_DIR
@@ -57,7 +59,8 @@ class PGVector(BaseVectorStore):
                               paragraph_id=paragraph_id,
                               source_id=source_id,
                               embedding=text_embedding,
-                              source_type=source_type)
+                              source_type=source_type,
+                              search_vector=to_ts_vector(text))
         embedding.save()
         return True

@@ -71,13 +74,15 @@ class PGVector(BaseVectorStore):
                                      is_active=text_list[index].get('is_active', True),
                                      source_id=text_list[index].get('source_id'),
                                      source_type=text_list[index].get('source_type'),
-                                     embedding=embeddings[index]) for index in
+                                     embedding=embeddings[index],
+                                     search_vector=to_ts_vector(text_list[index]['text'])) for index in
                           range(0, len(text_list))]
         QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
         return True

     def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
                  similarity: float,
+                 search_mode: SearchMode,
                  embedding: HuggingFaceEmbeddings):
         if dataset_id_list is None or len(dataset_id_list) == 0:
             return []
@@ -87,17 +92,14 @@ class PGVector(BaseVectorStore):
         if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
             exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
         query_set = query_set.exclude(**exclude_dict)
-        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
-                                                           select_string=get_file_content(
-                                                               os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
-                                                                            'hit_test.sql')),
-                                                           with_table_name=True)
-        embedding_model = select_list(exec_sql,
-                                      [json.dumps(embedding_query), *exec_params, similarity, top_number])
-        return embedding_model
+        for search_handle in search_handle_list:
+            if search_handle.support(search_mode):
+                return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity,
+                                            search_mode)

-    def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
-              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
+    def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str],
+              exclude_document_id_list: list[str],
+              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
+              search_mode: SearchMode):
         exclude_dict = {}
         if dataset_id_list is None or len(dataset_id_list) == 0:
             return []
@@ -107,14 +109,9 @@ class PGVector(BaseVectorStore):
         if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
             exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
         query_set = query_set.exclude(**exclude_dict)
-        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
-                                                           select_string=get_file_content(
-                                                               os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
-                                                                            'embedding_search.sql')),
-                                                           with_table_name=True)
-        embedding_model = select_list(exec_sql,
-                                      [json.dumps(query_embedding), *exec_params, similarity, top_n])
-        return embedding_model
+        for search_handle in search_handle_list:
+            if search_handle.support(search_mode):
+                return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode)

     def update_by_source_id(self, source_id: str, instance: Dict):
         QuerySet(Embedding).filter(source_id=source_id).update(**instance)
@@ -141,3 +138,81 @@ class PGVector(BaseVectorStore):

     def delete_by_paragraph_id(self, paragraph_id: str):
         QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
+
+
+class ISearch(ABC):
+    @abstractmethod
+    def support(self, search_mode: SearchMode):
+        pass
+
+    @abstractmethod
+    def handle(self, query_set, query_text, query_embedding, top_number: int,
+               similarity: float, search_mode: SearchMode):
+        pass
+
+
+class EmbeddingSearch(ISearch):
+    def handle(self,
+               query_set,
+               query_text,
+               query_embedding,
+               top_number: int,
+               similarity: float,
+               search_mode: SearchMode):
+        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
+                                                           select_string=get_file_content(
+                                                               os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
+                                                                            'embedding_search.sql')),
+                                                           with_table_name=True)
+        embedding_model = select_list(exec_sql,
+                                      [json.dumps(query_embedding), *exec_params, similarity, top_number])
+        return embedding_model
+
+    def support(self, search_mode: SearchMode):
+        return search_mode.value == SearchMode.embedding.value
+
+
+class KeywordsSearch(ISearch):
+    def handle(self,
+               query_set,
+               query_text,
+               query_embedding,
+               top_number: int,
+               similarity: float,
+               search_mode: SearchMode):
+        exec_sql, exec_params = generate_sql_by_query_dict({'keywords_query': query_set},
+                                                           select_string=get_file_content(
+                                                               os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
+                                                                            'keywords_search.sql')),
+                                                           with_table_name=True)
+        embedding_model = select_list(exec_sql,
+                                      [to_query(query_text), *exec_params, similarity, top_number])
+        return embedding_model
+
+    def support(self, search_mode: SearchMode):
+        return search_mode.value == SearchMode.keywords.value
+
+
+class BlendSearch(ISearch):
+    def handle(self,
+               query_set,
+               query_text,
+               query_embedding,
+               top_number: int,
+               similarity: float,
+               search_mode: SearchMode):
+        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
+                                                           select_string=get_file_content(
+                                                               os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
+                                                                            'blend_search.sql')),
+                                                           with_table_name=True)
+        embedding_model = select_list(exec_sql,
+                                      [json.dumps(query_embedding), to_query(query_text), *exec_params, similarity,
+                                       top_number])
+        return embedding_model
+
+    def support(self, search_mode: SearchMode):
+        return search_mode.value == SearchMode.blend.value
+
+
+search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]
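Dispatch over search_handle_list is a first-match scan, so adding a retrieval strategy only requires an ISearch implementation and a registered instance. A hedged sketch with a hypothetical rerank mode (neither RerankSearch nor a 'rerank' member on SearchMode exists in this patch; both are illustrative only):

    class RerankSearch(ISearch):
        def support(self, search_mode: SearchMode):
            # Hypothetical: assumes SearchMode grew a 'rerank' value.
            return search_mode.value == 'rerank'

        def handle(self, query_set, query_text, query_embedding, top_number: int,
                   similarity: float, search_mode: SearchMode):
            # Reuse the blend SQL, then re-order with a custom scorer if desired.
            return BlendSearch().handle(query_set, query_text, query_embedding,
                                        top_number, similarity, search_mode)

    search_handle_list.append(RerankSearch())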
diff --git a/ui/src/views/application/CreateAndSetting.vue b/ui/src/views/application/CreateAndSetting.vue
index 7ae4c7ff7..cb09daab3 100644
--- a/ui/src/views/application/CreateAndSetting.vue
+++ b/ui/src/views/application/CreateAndSetting.vue
@@ -145,74 +145,12 @@
[template hunk; markup not fully recoverable from this copy of the patch: the inline "关联知识库" settings popover — the "相似度高于" slider, the "引用分段数 TOP" input, the "最多引用字符数" input, and the "取消"/"确认" buttons — is removed, and a "参数设置" button (opening the new ParamSettingDialog) is added alongside the existing "添加" button]
@@ -221,13 +159,6 @@
[template hunk; markup not fully recoverable: the remaining popover markup near the "关联的知识库展示在这里" placeholder is dropped and the new ParamSettingDialog component is registered in the template with a ref and a refresh handler]

 import { reactive, ref, watch, onMounted } from 'vue'
 import { useRouter, useRoute } from 'vue-router'
-import { groupBy, cloneDeep } from 'lodash'
+import { groupBy } from 'lodash'
+import ParamSettingDialog from './components/ParamSettingDialog.vue'
 import AddDatasetDialog from './components/AddDatasetDialog.vue'
 import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue'
 import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue'
@@ -363,6 +298,8 @@ const defaultPrompt = `已知信息:
 问题:
 {question}
 `
+
+const ParamSettingDialogRef = ref<InstanceType<typeof ParamSettingDialog>>()
 const createModelRef = ref<InstanceType<typeof CreateModelDialog>>()
 const selectProviderRef = ref<InstanceType<typeof SelectProviderDialog>>()
@@ -384,7 +321,8 @@ const applicationForm = ref({
   dataset_setting: {
     top_n: 3,
     similarity: 0.6,
-    max_paragraph_char_number: 5000
+    max_paragraph_char_number: 5000,
+    search_mode: 'embedding'
   },
   model_setting: {
     prompt: defaultPrompt
@@ -392,8 +330,6 @@
   problem_optimization: false
 })

-const popoverVisible = ref(false)
-
 const rules = reactive<FormRules>({
   name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
   model_id: [
@@ -408,17 +344,6 @@
 const modelOptions = ref(null)
 const providerOptions = ref<Array<Provider>>([])
 const datasetList = ref([])
-const dataset_setting = ref({})
-
-function datasetSettingChange(val: string) {
-  if (val === 'open') {
-    popoverVisible.value = true
-    dataset_setting.value = cloneDeep(applicationForm.value.dataset_setting)
-  } else if (val === 'close') {
-    popoverVisible.value = false
-    applicationForm.value.dataset_setting = cloneDeep(dataset_setting.value)
-  }
-}

 const submit = async (formEl: FormInstance | undefined) => {
   if (!formEl) return
@@ -438,6 +363,14 @@ const submit = async (formEl: FormInstance | undefined) => {
   })
 }

+const openParamSettingDialog = () => {
+  ParamSettingDialogRef.value?.open(applicationForm.value.dataset_setting)
+}
+
+function refreshParam(data: any) {
+  applicationForm.value.dataset_setting = data
+}
+
 const openCreateModel = (provider?: Provider) => {
   if (provider && provider.provider) {
     createModelRef.value?.open(provider)
@@ -560,13 +493,4 @@ onMounted(() => {
 .prologue-md-editor {
   height: 150px;
 }
-.dataset_setting {
-  color: var(--el-text-color-regular);
-  font-weight: 400;
-}
-.custom-slider {
-  :deep(.el-input-number.is-without-controls .el-input__wrapper) {
-    padding: 0 !important;
-  }
-}

diff --git a/ui/src/views/application/components/ParamSettingDialog.vue b/ui/src/views/application/components/ParamSettingDialog.vue
new file mode 100644
index 000000000..4d05f2f50
--- /dev/null
+++ b/ui/src/views/application/components/ParamSettingDialog.vue
@@ -0,0 +1,152 @@
[new 152-line component; markup not recoverable from this copy of the patch: a dialog that edits the dataset_setting values — search mode (embedding|keywords|blend), "相似度高于", "引用分段数 TOP", "最多引用字符数" — receives the current settings via open() and emits refresh with the updated values]

diff --git a/ui/src/views/hit-test/index.vue b/ui/src/views/hit-test/index.vue
index ad4d4ac62..ffc7bb240 100644
--- a/ui/src/views/hit-test/index.vue
+++ b/ui/src/views/hit-test/index.vue
@@ -78,11 +78,47 @@
[template hunk; markup not fully recoverable: the search form gains a "检索模式" selector ahead of the existing controls, offering three options —
 "向量检索" (vector search): 通过向量距离计算与用户问题最相似的文本分段;
 "全文检索" (full-text search): 通过关键词检索,返回包含关键词最多的文本分段;
 "混合检索" (blend search): 同时执行全文检索和向量检索,再进行重排序,从两类查询结果中选择匹配用户问题的最佳结果 —
 followed by the existing "相似度高于" control]
@@ -103,7 +138,6 @@
           :min="1"
           :max="10"
           controls-position="right"
-          style="width: 145px"
         />
@@ -161,7 +195,8 @@ const title = ref('')
 const inputValue = ref('')
 const formInline = ref({
   similarity: 0.6,
-  top_number: 5
+  top_number: 5,
+  search_mode: 'embedding'
 })

 // first load
@@ -213,8 +248,7 @@ function sendChatHandle(event: any) {
 function getHitTestList() {
   const obj = {
     query_text: inputValue.value,
-    similarity: formInline.value.similarity,
-    top_number: formInline.value.top_number
+    ...formInline.value
   }
   if (isDataset.value) {
     datasetApi.getDatasetHitTest(id, obj, loading).then((res) => {
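End to end, a hit-test request now differs only by one extra query parameter. A hedged sketch with the requests library (the URL path and auth header are illustrative — only the four query parameters are taken from the views in this patch):

    import requests

    resp = requests.get(
        'http://localhost:8080/api/dataset/<dataset_id>/hit_test',  # hypothetical path
        params={'query_text': '全文检索', 'top_number': 5,
                'similarity': 0.6, 'search_mode': 'blend'},
        headers={'Authorization': '<token>'})  # hypothetical auth scheme
    print(resp.json())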