feat: 命中率测试接口

This commit is contained in:
shaohuzhang1 2023-12-25 17:10:59 +08:00
parent 861db4ad11
commit a7704c3a8a
12 changed files with 252 additions and 2 deletions

View File

@ -9,6 +9,7 @@
import hashlib
import os
import uuid
from functools import reduce
from typing import Dict
from django.contrib.postgres.fields import ArrayField
@ -20,12 +21,14 @@ from rest_framework import serializers
from application.models import Application, ApplicationDatasetMapping
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404, AppAuthenticationFailed
from common.exception.app_exception import AppApiException, NotFound404
from common.util.file_util import get_file_content
from dataset.models import DataSet
from dataset.serializers.common_serializers import list_paragraph
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@ -174,6 +177,33 @@ class ApplicationSerializer(serializers.Serializer):
def to_application_dateset_mapping(application_id: str, dataset_id: str):
    """Build an (unsaved) ApplicationDatasetMapping linking an application to a dataset.

    :param application_id: application primary key
    :param dataset_id:     dataset primary key
    :return: new ApplicationDatasetMapping instance with a uuid1 id (not persisted)
    """
    # NOTE(review): "dateset" in the function name looks like a typo for
    # "dataset"; renaming would break existing callers, so it is left as-is.
    return ApplicationDatasetMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id)
class HitTest(serializers.Serializer):
    """Hit-test request for an application: embed the query text, search every
    dataset mapped to the application, and return the matched paragraphs with
    their similarity and comprehensive scores."""
    id = serializers.CharField(required=True)
    user_id = serializers.UUIDField(required=False)
    query_text = serializers.CharField(required=True)
    top_number = serializers.IntegerField(required=True, max_value=10, min_value=1)
    similarity = serializers.FloatField(required=True, max_value=1, min_value=0)

    def is_valid(self, *, raise_exception=False):
        # Field validation always raises; afterwards the application id is
        # checked against the database regardless of ``raise_exception``.
        super().is_valid(raise_exception=True)
        if not QuerySet(Application).filter(id=self.data.get('id')).exists():
            raise AppApiException(500, '不存在的应用id')

    def hit_test(self):
        """Run the vector-store hit test and merge scores into paragraph rows."""
        self.is_valid()
        # All dataset ids bound to this application.
        dataset_id_list = [mapping.dataset_id
                           for mapping in QuerySet(ApplicationDatasetMapping).filter(
                               application_id=self.data.get('id'))]
        # Vector-store retrieval.
        hit_list = VectorStore.get_embedding_vector().hit_test(
            self.data.get('query_text'),
            dataset_id_list,
            self.data.get('top_number'),
            self.data.get('similarity'),
            EmbeddingModel.get_embedding_model())
        # Index the hits by paragraph id for the merge below.
        hit_dict = {hit.get('paragraph_id'): hit for hit in hit_list}
        merged = []
        for paragraph in list_paragraph([hit.get('paragraph_id') for hit in hit_list]):
            hit = hit_dict.get(paragraph.get('id'))
            merged.append({**paragraph,
                           'similarity': hit.get('similarity'),
                           'comprehensive_score': hit.get('comprehensive_score')})
        return merged
class Query(serializers.Serializer):
name = serializers.CharField(required=False)

View File

@ -7,6 +7,7 @@ urlpatterns = [
path('application', views.Application.as_view(), name="application"),
path('application/profile', views.Application.Profile.as_view()),
path('application/authentication', views.Application.Authentication.as_view()),
path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()),
path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),
path("application/<str:application_id>/api_key/<str:api_key_id>",
views.Application.ApplicationKey.Operate.as_view()),

View File

@ -19,6 +19,7 @@ from common.constants.permission_constants import CompareConstants, PermissionCo
ViewPermission, RoleConstants
from common.exception.app_exception import AppAuthenticationFailed
from common.response import result
from common.swagger_api.common_api import CommonApi
from common.util.common import query_params_to_single_dict
from dataset.serializers.dataset_serializers import DataSetSerializers
@ -180,6 +181,28 @@ class Application(APIView):
ApplicationSerializer.Query(
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list())
class HitTest(APIView):
    """GET endpoint: run a hit test against all datasets of one application."""
    authentication_classes = [TokenAuth]

    @action(methods="GET", detail=False)
    @swagger_auto_schema(operation_summary="命中测试列表", operation_id="命中测试列表",
                         manual_parameters=CommonApi.HitTestApi.get_request_params_api(),
                         responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()),
                         tags=["应用"])
    @has_permissions(ViewPermission(
        [RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN,
         RoleConstants.APPLICATION_KEY],
        [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                        dynamic_tag=keywords.get('application_id'))],
        compare=CompareConstants.AND))
    def get(self, request: Request, application_id: str):
        """Collect query params, delegate to the serializer, and wrap the result."""
        payload = {'id': application_id,
                   'user_id': request.user.id,
                   'query_text': request.query_params.get('query_text'),
                   'top_number': request.query_params.get('top_number'),
                   'similarity': request.query_params.get('similarity')}
        return result.success(ApplicationSerializer.HitTest(data=payload).hit_test())
class Operate(APIView):
authentication_classes = [TokenAuth]

View File

@ -0,0 +1,78 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file common.py
@date: 2023/12/25 16:17
@desc:
"""
from drf_yasg import openapi
from common.mixins.api_mixin import ApiMixin
class CommonApi:
    """Swagger request/response definitions shared by the application and
    dataset hit-test endpoints."""

    class HitTestApi(ApiMixin):

        @staticmethod
        def get_request_params_api():
            """Query parameters accepted by the hit-test endpoints.

            :return: list of drf_yasg ``openapi.Parameter`` objects
            """
            return [
                openapi.Parameter(name='query_text',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='问题文本'),
                openapi.Parameter(name='top_number',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_NUMBER,
                                  default=10,
                                  required=True,
                                  description='topN'),
                openapi.Parameter(name='similarity',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_NUMBER,
                                  default=0.6,
                                  required=True,
                                  description='相关性'),
            ]

        @staticmethod
        def get_response_body_api():
            """Response schema for one hit-test result row (a paragraph plus
            its similarity and comprehensive score).

            :return: drf_yasg ``openapi.Schema`` describing a single element of
                the returned array
            """
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
                          'document_id', 'title',
                          'similarity', 'comprehensive_score',
                          'create_time', 'update_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
                                              description="段落内容", default='段落内容'),
                    'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
                                            description="标题", default="xxx的描述"),
                    'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                              default=1),
                    'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                               description="点赞数量", default=1),
                    'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
                                                  description="点踩数", default=1),
                    'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
                                                 description="知识库id", default='xxx'),
                    'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
                                                  description="文档id", default='xxx'),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                                description="是否可用", default=True),
                    # BUG FIX: these two fields are TYPE_NUMBER but previously
                    # declared default=True (a boolean), which rendered a wrong
                    # example value in the generated swagger docs.
                    'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="相关性得分",
                                                 description="相关性得分", default=0.95),
                    'comprehensive_score': openapi.Schema(type=openapi.TYPE_NUMBER, title="综合得分,用于排序",
                                                          description="综合得分,用于排序", default=0.95),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  ),
                }
            )

View File

@ -7,9 +7,14 @@
@desc:
"""
import os
from typing import List
from django.db.models import QuerySet
from common.db.search import native_search
from common.db.sql_execute import update_execute
from common.util.file_util import get_file_content
from dataset.models import Paragraph
from smartdoc.conf import PROJECT_DIR
@ -17,3 +22,10 @@ def update_document_char_length(document_id: str):
update_execute(get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_char_length.sql')),
(document_id, document_id))
def list_paragraph(paragraph_list: List[str]):
    """Return paragraph rows (including owning document and dataset names)
    for the given paragraph ids.

    :param paragraph_list: paragraph id list; ``None`` or empty short-circuits
        to ``[]`` without touching the database.
    :return: list of paragraph row dicts produced by ``native_search``
    """
    # Idiomatic empty/None guard (replaces `is None or len(...) == 0`).
    if not paragraph_list:
        return []
    return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content(
        os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))

View File

@ -8,6 +8,8 @@
"""
import os.path
import uuid
from functools import reduce
from itertools import groupby
from typing import Dict
from django.contrib.postgres.fields import ArrayField
@ -18,6 +20,7 @@ from drf_yasg import openapi
from rest_framework import serializers
from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event.listener_manage import ListenerManagement
@ -26,6 +29,7 @@ from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from dataset.serializers.common_serializers import list_paragraph
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR
@ -300,6 +304,30 @@ class DataSetSerializers(serializers.ModelSerializer):
desc = serializers.CharField(required=False)
application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
class HitTest(ApiMixin, serializers.Serializer):
    """Hit-test request for a single dataset: embed the query text, search the
    dataset's vectors, and return matched paragraphs with their scores."""
    id = serializers.CharField(required=True)
    user_id = serializers.UUIDField(required=False)
    query_text = serializers.CharField(required=True)
    top_number = serializers.IntegerField(required=True, max_value=10, min_value=1)
    similarity = serializers.FloatField(required=True, max_value=1, min_value=0)

    def is_valid(self, *, raise_exception=True):
        # Field validation always raises; afterwards the dataset id is checked
        # against the database regardless of ``raise_exception``.
        super().is_valid(raise_exception=True)
        if not QuerySet(DataSet).filter(id=self.data.get("id")).exists():
            raise AppApiException(300, "id不存在")

    def hit_test(self):
        """Run the vector-store hit test and merge scores into paragraph rows."""
        self.is_valid()
        # Vector-store retrieval, restricted to this single dataset.
        hit_list = VectorStore.get_embedding_vector().hit_test(
            self.data.get('query_text'),
            [self.data.get('id')],
            self.data.get('top_number'),
            self.data.get('similarity'),
            EmbeddingModel.get_embedding_model())
        # Index the hits by paragraph id for the merge below.
        hit_dict = {hit.get('paragraph_id'): hit for hit in hit_list}
        merged = []
        for paragraph in list_paragraph([hit.get('paragraph_id') for hit in hit_list]):
            hit = hit_dict.get(paragraph.get('id'))
            merged.append({**paragraph,
                           'similarity': hit.get('similarity'),
                           'comprehensive_score': hit.get('comprehensive_score')})
        return merged
class Operate(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=True)
user_id = serializers.UUIDField(required=False)

View File

@ -0,0 +1,6 @@
/* Select paragraphs together with the owning document and dataset names.
   NOTE(review): the caller (native_search) appends the filter clause generated
   from the QuerySet after this text — confirm against common/db/search.py. */
SELECT
(SELECT "name" FROM "document" WHERE "id"=document_id) as document_name,
(SELECT "name" FROM "dataset" WHERE "id"=dataset_id) as dataset_name,
*
FROM
"paragraph"

View File

@ -8,6 +8,7 @@ urlpatterns = [
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),

View File

@ -16,6 +16,7 @@ from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import PermissionConstants, CompareConstants, Permission, Group, Operate
from common.response import result
from common.response.result import get_page_request_params, get_page_api_response, get_api_response
from common.swagger_api.common_api import CommonApi
from dataset.serializers.dataset_serializers import DataSetSerializers
@ -61,6 +62,24 @@ class Dataset(APIView):
s.is_valid(raise_exception=True)
return result.success(s.save(request.user))
class HitTest(APIView):
    """GET endpoint: run a hit test against a single dataset."""
    authentication_classes = [TokenAuth]

    @action(methods="GET", detail=False)
    @swagger_auto_schema(operation_summary="命中测试列表", operation_id="命中测试列表",
                         manual_parameters=CommonApi.HitTestApi.get_request_params_api(),
                         responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()),
                         tags=["知识库"])
    @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
                                                    dynamic_tag=keywords.get('dataset_id')))
    def get(self, request: Request, dataset_id: str):
        """Collect query params, delegate to the serializer, and wrap the result."""
        payload = {'id': dataset_id,
                   'user_id': request.user.id,
                   'query_text': request.query_params.get('query_text'),
                   'top_number': request.query_params.get('top_number'),
                   'similarity': request.query_params.get('similarity')}
        return result.success(DataSetSerializers.HitTest(data=payload).hit_test())
class Operate(APIView):
authentication_classes = [TokenAuth]

View File

@ -0,0 +1,34 @@
/* Hit-test retrieval query (PostgreSQL).
   Bind order: query embedding (json), the parameters of ${embedding_query}
   twice (the placeholder is expanded in two places below), the similarity
   threshold, and the result limit.
   NOTE(review): ${embedding_query} is substituted by generate_sql_by_query_dict
   before execution — confirm against common/db/search.py. */
SELECT
similarity,
paragraph_id,
comprehensive_score
FROM
(
SELECT DISTINCT ON
( "paragraph_id" ) ( similarity + score ),*,
( similarity + score ) AS comprehensive_score
FROM
(
SELECT
*,
/* vector similarity of the stored embedding to the query vector */
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
/* min-max normalised feedback score: (star_num - trample_num) scaled by the
   dataset-wide min/max from the aggs subquery */
CASE
WHEN embedding.star_num - embedding.trample_num = 0 THEN
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
END AS score
FROM
embedding,
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
${embedding_query}
) TEMP
WHERE
similarity > %s
ORDER BY
paragraph_id,
( similarity + score )
DESC
) ss
ORDER BY
comprehensive_score DESC
LIMIT %s

View File

@ -115,6 +115,11 @@ class BaseVectorStore(ABC):
embedding: HuggingFaceEmbeddings):
pass
@abstractmethod
def hit_test(self, query_text, dataset_id: list[str], top_number: int, similarity: float,
             embedding: HuggingFaceEmbeddings):
    """Retrieve the best-matching rows for ``query_text`` within the given datasets.

    :param query_text: natural-language query to embed and search with
    :param dataset_id: list of dataset ids to restrict the search to
        (NOTE(review): implementations name this parameter ``dataset_id_list``
        — confirm and align to avoid keyword-argument mismatches)
    :param top_number: maximum number of rows to return
    :param similarity: minimum similarity threshold (0..1)
    :param embedding: embedding model used to vectorise ``query_text``
    """
    pass
@abstractmethod
def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
pass

View File

@ -15,7 +15,7 @@ from django.db.models import QuerySet
from langchain.embeddings import HuggingFaceEmbeddings
from common.db.search import native_search, generate_sql_by_query_dict
from common.db.sql_execute import select_one
from common.db.sql_execute import select_one, select_list
from common.util.file_util import get_file_content
from embedding.models import Embedding, SourceType
from embedding.vector.base_vector import BaseVectorStore
@ -68,6 +68,19 @@ class PGVector(BaseVectorStore):
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True
def hit_test(self, query_text, dataset_id_list: list[str], top_number: int, similarity: float,
             embedding: HuggingFaceEmbeddings):
    """Embed the query and run the hit-test SQL against active embedding rows.

    :param query_text: natural-language query to embed and search with
    :param dataset_id_list: dataset ids to restrict the search to
    :param top_number: maximum number of rows to return
    :param similarity: minimum similarity threshold (0..1)
    :param embedding: embedding model used to vectorise ``query_text``
    :return: rows produced by hit_test.sql (similarity, paragraph_id,
        comprehensive_score)
    """
    query_vector = embedding.embed_query(query_text)
    # Only active rows from the requested datasets participate in the search.
    active_rows = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=True)
    sql_template = get_file_content(
        os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'hit_test.sql'))
    exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': active_rows},
                                                       select_string=sql_template,
                                                       with_table_name=True)
    # ${embedding_query} appears twice in hit_test.sql (aggregate subquery and
    # candidate scan), so its bind parameters are passed twice, in order.
    return select_list(exec_sql,
                       [json.dumps(query_vector), *exec_params, *exec_params, similarity, top_number])
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_id_list: list[str],
is_active: bool,