From a7704c3a8af8255cb222e912c13b189f20038d8a Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Mon, 25 Dec 2023 17:10:59 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=91=BD=E4=B8=AD=E7=8E=87=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/application_serializers.py | 32 +++++++- apps/application/urls.py | 1 + apps/application/views/application_views.py | 23 ++++++ apps/common/swagger_api/common_api.py | 78 +++++++++++++++++++ .../dataset/serializers/common_serializers.py | 12 +++ .../serializers/dataset_serializers.py | 28 +++++++ apps/dataset/sql/list_paragraph.sql | 6 ++ apps/dataset/urls.py | 1 + apps/dataset/views/dataset.py | 19 +++++ apps/embedding/sql/hit_test.sql | 34 ++++++++ apps/embedding/vector/base_vector.py | 5 ++ apps/embedding/vector/pg_vector.py | 15 +++- 12 files changed, 252 insertions(+), 2 deletions(-) create mode 100644 apps/common/swagger_api/common_api.py create mode 100644 apps/dataset/sql/list_paragraph.sql create mode 100644 apps/embedding/sql/hit_test.sql diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index cb91e465d..8bd858358 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -9,6 +9,7 @@ import hashlib import os import uuid +from functools import reduce from typing import Dict from django.contrib.postgres.fields import ArrayField @@ -20,12 +21,14 @@ from rest_framework import serializers from application.models import Application, ApplicationDatasetMapping from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey +from common.config.embedding_config import VectorStore, EmbeddingModel from common.constants.authentication_type import AuthenticationType from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list -from common.exception.app_exception import AppApiException, NotFound404, AppAuthenticationFailed +from common.exception.app_exception import AppApiException, NotFound404 from common.util.file_util import get_file_content from dataset.models import DataSet +from dataset.serializers.common_serializers import list_paragraph from setting.models import AuthOperate from setting.models.model_management import Model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants @@ -174,6 +177,33 @@ class ApplicationSerializer(serializers.Serializer): def to_application_dateset_mapping(application_id: str, dataset_id: str): return ApplicationDatasetMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id) + class HitTest(serializers.Serializer): + id = serializers.CharField(required=True) + user_id = serializers.UUIDField(required=False) + query_text = serializers.CharField(required=True) + top_number = serializers.IntegerField(required=True, max_value=10, min_value=1) + similarity = serializers.FloatField(required=True, max_value=1, min_value=0) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not QuerySet(Application).filter(id=self.data.get('id')).exists(): + raise AppApiException(500, '不存在的应用id') + + def hit_test(self): + self.is_valid() + vector = VectorStore.get_embedding_vector() + # 向量库检索 + hit_list = vector.hit_test(self.data.get('query_text'), [ad.dataset_id for ad in + QuerySet(ApplicationDatasetMapping).filter( + application_id=self.data.get('id'))], + self.data.get('top_number'), + self.data.get('similarity'), + EmbeddingModel.get_embedding_model()) + hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) + p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) + return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), + 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list] + class Query(serializers.Serializer): name = serializers.CharField(required=False) diff --git a/apps/application/urls.py b/apps/application/urls.py index 89da2788e..ef74a8830 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -7,6 +7,7 @@ urlpatterns = [ path('application', views.Application.as_view(), name="application"), path('application/profile', views.Application.Profile.as_view()), path('application/authentication', views.Application.Authentication.as_view()), + path('application//hit_test', views.Application.HitTest.as_view()), path('application//api_key', views.Application.ApplicationKey.as_view()), path("application//api_key/", views.Application.ApplicationKey.Operate.as_view()), diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index afe802a5e..fb234c665 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -19,6 +19,7 @@ from common.constants.permission_constants import CompareConstants, PermissionCo ViewPermission, RoleConstants from common.exception.app_exception import AppAuthenticationFailed from common.response import result +from common.swagger_api.common_api import CommonApi from common.util.common import query_params_to_single_dict from dataset.serializers.dataset_serializers import DataSetSerializers @@ -180,6 +181,28 @@ class Application(APIView): ApplicationSerializer.Query( data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list()) + class HitTest(APIView): + authentication_classes = [TokenAuth] + + @action(methods="GET", detail=False) + @swagger_auto_schema(operation_summary="命中测试列表", operation_id="命中测试列表", + manual_parameters=CommonApi.HitTestApi.get_request_params_api(), + responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()), + tags=["应用"]) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN, + RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id, + "query_text": request.query_params.get("query_text"), + "top_number": request.query_params.get("top_number"), + 'similarity': request.query_params.get('similarity')}).hit_test( + )) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/apps/common/swagger_api/common_api.py b/apps/common/swagger_api/common_api.py new file mode 100644 index 000000000..71a876ae0 --- /dev/null +++ b/apps/common/swagger_api/common_api.py @@ -0,0 +1,78 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common.py + @date:2023/12/25 16:17 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class CommonApi: + class HitTestApi(ApiMixin): + @staticmethod + def get_request_params_api(): + return [ + openapi.Parameter(name='query_text', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='问题文本'), + openapi.Parameter(name='top_number', + in_=openapi.IN_QUERY, + type=openapi.TYPE_NUMBER, + default=10, + required=True, + description='topN'), + openapi.Parameter(name='similarity', + in_=openapi.IN_QUERY, + type=openapi.TYPE_NUMBER, + default=0.6, + required=True, + description='相关性'), + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id', + 'document_id', 'title', + 'similarity', 'comprehensive_score', + 'create_time', 'update_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容", + description="段落内容", default='段落内容'), + 'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题", + description="标题", default="xxx的描述"), + 'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量", + default=1), + 'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量", + description="点赞数量", default=1), + 'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量", + description="点踩数", default=1), + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id", + description="知识库id", default='xxx'), + 'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id", + description="文档id", default='xxx'), + 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", + description="是否可用", default=True), + 'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="相关性得分", + description="相关性得分", default=True), + 'comprehensive_score': openapi.Schema(type=openapi.TYPE_NUMBER, title="综合得分,用于排序", + description="综合得分,用于排序", default=True), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ), + + } + ) diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 7d888fdb0..77f3bca85 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -7,9 +7,14 @@ @desc: """ import os +from typing import List +from django.db.models import QuerySet + +from common.db.search import native_search from common.db.sql_execute import update_execute from common.util.file_util import get_file_content +from dataset.models import Paragraph from smartdoc.conf import PROJECT_DIR @@ -17,3 +22,10 @@ def update_document_char_length(document_id: str): update_execute(get_file_content( os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_char_length.sql')), (document_id, document_id)) + + +def list_paragraph(paragraph_list: List[str]): + if paragraph_list is None or len(paragraph_list) == 0: + return [] + return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql'))) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index fad4a1a45..5c0b5d4f9 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -8,6 +8,8 @@ """ import os.path import uuid +from functools import reduce +from itertools import groupby from typing import Dict from django.contrib.postgres.fields import ArrayField @@ -18,6 +20,7 @@ from drf_yasg import openapi from rest_framework import serializers from application.models import ApplicationDatasetMapping +from common.config.embedding_config import VectorStore, EmbeddingModel from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.sql_execute import select_list from common.event.listener_manage import ListenerManagement @@ -26,6 +29,7 @@ from common.mixins.api_mixin import ApiMixin from common.util.common import post from common.util.file_util import get_file_content from dataset.models.data_set import DataSet, Document, Paragraph, Problem +from dataset.serializers.common_serializers import list_paragraph from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer from setting.models import AuthOperate from smartdoc.conf import PROJECT_DIR @@ -300,6 +304,30 @@ class DataSetSerializers(serializers.ModelSerializer): desc = serializers.CharField(required=False) application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + class HitTest(ApiMixin, serializers.Serializer): + id = serializers.CharField(required=True) + user_id = serializers.UUIDField(required=False) + query_text = serializers.CharField(required=True) + top_number = serializers.IntegerField(required=True, max_value=10, min_value=1) + similarity = serializers.FloatField(required=True, max_value=1, min_value=0) + + def is_valid(self, *, raise_exception=True): + super().is_valid(raise_exception=True) + if not QuerySet(DataSet).filter(id=self.data.get("id")).exists(): + raise AppApiException(300, "id不存在") + + def hit_test(self): + self.is_valid() + vector = VectorStore.get_embedding_vector() + # 向量库检索 + hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], self.data.get('top_number'), + self.data.get('similarity'), + EmbeddingModel.get_embedding_model()) + hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) + p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) + return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), + 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list] + class Operate(ApiMixin, serializers.Serializer): id = serializers.CharField(required=True) user_id = serializers.UUIDField(required=False) diff --git a/apps/dataset/sql/list_paragraph.sql b/apps/dataset/sql/list_paragraph.sql new file mode 100644 index 000000000..2256f3f92 --- /dev/null +++ b/apps/dataset/sql/list_paragraph.sql @@ -0,0 +1,6 @@ +SELECT + (SELECT "name" FROM "document" WHERE "id"=document_id) as document_name, + (SELECT "name" FROM "dataset" WHERE "id"=dataset_id) as dataset_name, + * +FROM + "paragraph" diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index bceccf274..16078842e 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -8,6 +8,7 @@ urlpatterns = [ path('dataset/', views.Dataset.Operate.as_view(), name="dataset_key"), path('dataset//application', views.Dataset.Application.as_view()), path('dataset//', views.Dataset.Page.as_view(), name="dataset"), + path('dataset//hit_test', views.Dataset.HitTest.as_view()), path('dataset//document', views.Document.as_view(), name='document'), path('dataset//document/_bach', views.Document.Batch.as_view()), path('dataset//document//', views.Document.Page.as_view()), diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index 4db678ea7..0721ee025 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -16,6 +16,7 @@ from common.auth import TokenAuth, has_permissions from common.constants.permission_constants import PermissionConstants, CompareConstants, Permission, Group, Operate from common.response import result from common.response.result import get_page_request_params, get_page_api_response, get_api_response +from common.swagger_api.common_api import CommonApi from dataset.serializers.dataset_serializers import DataSetSerializers @@ -61,6 +62,24 @@ class Dataset(APIView): s.is_valid(raise_exception=True) return result.success(s.save(request.user)) + class HitTest(APIView): + authentication_classes = [TokenAuth] + + @action(methods="GET", detail=False) + @swagger_auto_schema(operation_summary="命中测试列表", operation_id="命中测试列表", + manual_parameters=CommonApi.HitTestApi.get_request_params_api(), + responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()), + tags=["知识库"]) + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=keywords.get('dataset_id'))) + def get(self, request: Request, dataset_id: str): + return result.success( + DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id, + "query_text": request.query_params.get("query_text"), + "top_number": request.query_params.get("top_number"), + 'similarity': request.query_params.get('similarity')}).hit_test( + )) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/apps/embedding/sql/hit_test.sql b/apps/embedding/sql/hit_test.sql new file mode 100644 index 000000000..141feee4e --- /dev/null +++ b/apps/embedding/sql/hit_test.sql @@ -0,0 +1,34 @@ +SELECT + similarity, + paragraph_id, + comprehensive_score +FROM + ( + SELECT DISTINCT ON + ( "paragraph_id" ) ( similarity + score ),*, + ( similarity + score ) AS comprehensive_score + FROM + ( + SELECT + *, + ( 1 - ( embedding.embedding <=> %s ) ) AS similarity, + CASE + + WHEN embedding.star_num - embedding.trample_num = 0 THEN + 0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) ) + END AS score + FROM + embedding, + ( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs + ${embedding_query} + ) TEMP + WHERE + similarity > %s + ORDER BY + paragraph_id, + ( similarity + score ) + DESC + ) ss +ORDER BY + comprehensive_score DESC + LIMIT %s \ No newline at end of file diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index b1dda12a8..a5c14f413 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -115,6 +115,11 @@ class BaseVectorStore(ABC): embedding: HuggingFaceEmbeddings): pass + @abstractmethod + def hit_test(self, query_text, dataset_id: list[str], top_number: int, similarity: float, + embedding: HuggingFaceEmbeddings): + pass + @abstractmethod def update_by_paragraph_id(self, paragraph_id: str, instance: Dict): pass diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index b58e964d5..79f8be9bf 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -15,7 +15,7 @@ from django.db.models import QuerySet from langchain.embeddings import HuggingFaceEmbeddings from common.db.search import native_search, generate_sql_by_query_dict -from common.db.sql_execute import select_one +from common.db.sql_execute import select_one, select_list from common.util.file_util import get_file_content from embedding.models import Embedding, SourceType from embedding.vector.base_vector import BaseVectorStore @@ -68,6 +68,19 @@ class PGVector(BaseVectorStore): QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None return True + def hit_test(self, query_text, dataset_id_list: list[str], top_number: int, similarity: float, + embedding: HuggingFaceEmbeddings): + embedding_query = embedding.embed_query(query_text) + query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=True) + exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', + 'hit_test.sql')), + with_table_name=True) + embedding_model = select_list(exec_sql, + [json.dumps(embedding_query), *exec_params, *exec_params, similarity, top_number]) + return embedding_model + def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_id_list: list[str], is_active: bool,