mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 命中率测试接口
This commit is contained in:
parent
861db4ad11
commit
a7704c3a8a
|
|
@ -9,6 +9,7 @@
|
|||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from functools import reduce
|
||||
from typing import Dict
|
||||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
|
@ -20,12 +21,14 @@ from rest_framework import serializers
|
|||
|
||||
from application.models import Application, ApplicationDatasetMapping
|
||||
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.db.search import get_dynamics_model, native_search, native_page_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.exception.app_exception import AppApiException, NotFound404, AppAuthenticationFailed
|
||||
from common.exception.app_exception import AppApiException, NotFound404
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import DataSet
|
||||
from dataset.serializers.common_serializers import list_paragraph
|
||||
from setting.models import AuthOperate
|
||||
from setting.models.model_management import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
|
@ -174,6 +177,33 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
def to_application_dateset_mapping(application_id: str, dataset_id: str):
    """Build an (unsaved) ApplicationDatasetMapping linking an application to a dataset.

    Uses ``uuid.uuid1()`` for the primary key, matching the id convention used
    elsewhere in this file.
    NOTE(review): "dateset" looks like a typo for "dataset" — name kept so
    existing callers keep working.
    """
    return ApplicationDatasetMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id)
|
||||
|
||||
class HitTest(serializers.Serializer):
    """Hit-rate test across every dataset mapped to an application."""
    # Application id (existence is checked in is_valid).
    id = serializers.CharField(required=True)
    user_id = serializers.UUIDField(required=False)
    query_text = serializers.CharField(required=True)
    top_number = serializers.IntegerField(required=True, max_value=10, min_value=1)
    similarity = serializers.FloatField(required=True, max_value=1, min_value=0)

    def is_valid(self, *, raise_exception=True):
        """Run field validation, then verify the application id exists.

        Always raises on invalid input regardless of *raise_exception*
        (the default is now True, matching DataSetSerializers.HitTest —
        the original default of False was ignored anyway).

        :raises AppApiException: when the application id does not exist.
        """
        super().is_valid(raise_exception=True)
        if not QuerySet(Application).filter(id=self.data.get('id')).exists():
            raise AppApiException(500, '不存在的应用id')

    def hit_test(self):
        """Run the vector-store hit test and attach scores to each paragraph.

        :return: list of paragraph dicts, each extended with the matching
                 hit's 'similarity' and 'comprehensive_score'.
        """
        self.is_valid()
        vector = VectorStore.get_embedding_vector()
        # All datasets bound to this application take part in the search.
        dataset_id_list = [ad.dataset_id for ad in
                           QuerySet(ApplicationDatasetMapping).filter(application_id=self.data.get('id'))]
        # Vector-store retrieval.
        hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list,
                                   self.data.get('top_number'),
                                   self.data.get('similarity'),
                                   EmbeddingModel.get_embedding_model())
        # Index hits by paragraph id (dict comprehension replaces the original reduce-merge).
        hit_dict = {hit.get('paragraph_id'): hit for hit in hit_list}
        p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
        return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
                 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list]
|
||||
|
||||
class Query(serializers.Serializer):
|
||||
name = serializers.CharField(required=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ urlpatterns = [
|
|||
path('application', views.Application.as_view(), name="application"),
|
||||
path('application/profile', views.Application.Profile.as_view()),
|
||||
path('application/authentication', views.Application.Authentication.as_view()),
|
||||
path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()),
|
||||
path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),
|
||||
path("application/<str:application_id>/api_key/<str:api_key_id>",
|
||||
views.Application.ApplicationKey.Operate.as_view()),
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from common.constants.permission_constants import CompareConstants, PermissionCo
|
|||
ViewPermission, RoleConstants
|
||||
from common.exception.app_exception import AppAuthenticationFailed
|
||||
from common.response import result
|
||||
from common.swagger_api.common_api import CommonApi
|
||||
from common.util.common import query_params_to_single_dict
|
||||
from dataset.serializers.dataset_serializers import DataSetSerializers
|
||||
|
||||
|
|
@ -180,6 +181,28 @@ class Application(APIView):
|
|||
ApplicationSerializer.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list())
|
||||
|
||||
class HitTest(APIView):
    """Hit-rate test endpoint for a single application."""
    authentication_classes = [TokenAuth]

    @action(methods="GET", detail=False)
    @swagger_auto_schema(operation_summary="命中测试列表", operation_id="命中测试列表",
                         manual_parameters=CommonApi.HitTestApi.get_request_params_api(),
                         responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()),
                         tags=["应用"])
    @has_permissions(ViewPermission(
        [RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN,
         RoleConstants.APPLICATION_KEY],
        [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                        dynamic_tag=keywords.get('application_id'))],
        compare=CompareConstants.AND))
    def get(self, request: Request, application_id: str):
        """Build the serializer payload from query params and run the hit test."""
        payload = {'id': application_id,
                   'user_id': request.user.id,
                   "query_text": request.query_params.get("query_text"),
                   "top_number": request.query_params.get("top_number"),
                   'similarity': request.query_params.get('similarity')}
        serializer = ApplicationSerializer.HitTest(data=payload)
        return result.success(serializer.hit_test())
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: common.py
|
||||
@date:2023/12/25 16:17
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg import openapi
|
||||
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
|
||||
|
||||
class CommonApi:
    """Shared swagger/OpenAPI definitions reused by application and dataset views."""

    class HitTestApi(ApiMixin):
        """Request parameters and response schema for the hit-test endpoints."""

        @staticmethod
        def get_request_params_api():
            """Query-string parameters accepted by a hit-test request."""
            return [
                openapi.Parameter(name='query_text',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='问题文本'),
                openapi.Parameter(name='top_number',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_NUMBER,
                                  default=10,
                                  required=True,
                                  description='topN'),
                openapi.Parameter(name='similarity',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_NUMBER,
                                  default=0.6,
                                  required=True,
                                  description='相关性'),
            ]

        @staticmethod
        def get_response_body_api():
            """Schema of one hit-test result row: a paragraph plus its scores."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
                          'document_id', 'title',
                          'similarity', 'comprehensive_score',
                          'create_time', 'update_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
                                              description="段落内容", default='段落内容'),
                    'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
                                            description="标题", default="xxx的描述"),
                    'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                              default=1),
                    'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                               description="点赞数量", default=1),
                    'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
                                                  description="点踩数", default=1),
                    'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
                                                 description="知识库id", default='xxx'),
                    'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
                                                  description="文档id", default='xxx'),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                                description="是否可用", default=True),
                    # Bug fix: these two fields are TYPE_NUMBER but originally declared
                    # default=True (a boolean); use representative numeric defaults.
                    'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="相关性得分",
                                                 description="相关性得分", default=0.6),
                    'comprehensive_score': openapi.Schema(type=openapi.TYPE_NUMBER, title="综合得分,用于排序",
                                                          description="综合得分,用于排序", default=0.6),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  ),
                }
            )
|
||||
|
|
@ -7,9 +7,14 @@
|
|||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.db.search import native_search
|
||||
from common.db.sql_execute import update_execute
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Paragraph
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
|
|
@ -17,3 +22,10 @@ def update_document_char_length(document_id: str):
|
|||
update_execute(get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_char_length.sql')),
|
||||
(document_id, document_id))
|
||||
|
||||
|
||||
def list_paragraph(paragraph_list: List[str]):
    """Fetch paragraph rows (with document/dataset names) for the given ids.

    :param paragraph_list: paragraph ids to look up; may be None or empty.
    :return: rows produced by list_paragraph.sql restricted to those ids;
             [] for a None/empty id list, without touching the database.
    """
    # Idiomatic empty check replaces `is None or len(...) == 0`.
    if not paragraph_list:
        return []
    return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content(
        os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@
|
|||
"""
|
||||
import os.path
|
||||
import uuid
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from typing import Dict
|
||||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
|
@ -18,6 +20,7 @@ from drf_yasg import openapi
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.models import ApplicationDatasetMapping
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.db.search import get_dynamics_model, native_page_search, native_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.event.listener_manage import ListenerManagement
|
||||
|
|
@ -26,6 +29,7 @@ from common.mixins.api_mixin import ApiMixin
|
|||
from common.util.common import post
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
|
||||
from dataset.serializers.common_serializers import list_paragraph
|
||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
||||
from setting.models import AuthOperate
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
|
@ -300,6 +304,30 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
desc = serializers.CharField(required=False)
|
||||
application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||
|
||||
class HitTest(ApiMixin, serializers.Serializer):
    """Hit-rate test scoped to a single dataset."""
    # Dataset id (existence is checked in is_valid).
    id = serializers.CharField(required=True)
    user_id = serializers.UUIDField(required=False)
    query_text = serializers.CharField(required=True)
    top_number = serializers.IntegerField(required=True, max_value=10, min_value=1)
    similarity = serializers.FloatField(required=True, max_value=1, min_value=0)

    def is_valid(self, *, raise_exception=True):
        """Validate fields, then make sure the dataset id exists."""
        super().is_valid(raise_exception=True)
        dataset_exists = QuerySet(DataSet).filter(id=self.data.get("id")).exists()
        if not dataset_exists:
            raise AppApiException(300, "id不存在")

    def hit_test(self):
        """Query the vector store for this dataset and merge scores into paragraphs."""
        self.is_valid()
        store = VectorStore.get_embedding_vector()
        # Vector-store retrieval, limited to this single dataset.
        hits = store.hit_test(self.data.get('query_text'), [self.data.get('id')],
                              self.data.get('top_number'),
                              self.data.get('similarity'),
                              EmbeddingModel.get_embedding_model())
        by_paragraph = {hit.get('paragraph_id'): hit for hit in hits}
        paragraphs = list_paragraph([hit.get('paragraph_id') for hit in hits])
        return [{**paragraph,
                 'similarity': by_paragraph.get(paragraph.get('id')).get('similarity'),
                 'comprehensive_score': by_paragraph.get(paragraph.get('id')).get('comprehensive_score')}
                for paragraph in paragraphs]
|
||||
|
||||
class Operate(ApiMixin, serializers.Serializer):
|
||||
id = serializers.CharField(required=True)
|
||||
user_id = serializers.UUIDField(required=False)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
-- Fetch paragraph rows, attaching the owning document's and dataset's names
-- via correlated subqueries. No WHERE clause here: the caller (native_search
-- in dataset/serializers/common_serializers.py) combines this SELECT with a
-- Django QuerySet filter at execution time.
SELECT
    (SELECT "name" FROM "document" WHERE "id"=document_id) as document_name,
    (SELECT "name" FROM "dataset" WHERE "id"=dataset_id) as dataset_name,
    *
FROM
    "paragraph"
|
||||
|
|
@ -8,6 +8,7 @@ urlpatterns = [
|
|||
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
|
||||
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
|
||||
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
|
||||
path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
|
||||
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
|
||||
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from common.auth import TokenAuth, has_permissions
|
|||
from common.constants.permission_constants import PermissionConstants, CompareConstants, Permission, Group, Operate
|
||||
from common.response import result
|
||||
from common.response.result import get_page_request_params, get_page_api_response, get_api_response
|
||||
from common.swagger_api.common_api import CommonApi
|
||||
from dataset.serializers.dataset_serializers import DataSetSerializers
|
||||
|
||||
|
||||
|
|
@ -61,6 +62,24 @@ class Dataset(APIView):
|
|||
s.is_valid(raise_exception=True)
|
||||
return result.success(s.save(request.user))
|
||||
|
||||
class HitTest(APIView):
    """Hit-rate test endpoint for a single dataset."""
    authentication_classes = [TokenAuth]

    @action(methods="GET", detail=False)
    @swagger_auto_schema(operation_summary="命中测试列表", operation_id="命中测试列表",
                         manual_parameters=CommonApi.HitTestApi.get_request_params_api(),
                         responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()),
                         tags=["知识库"])
    @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
                                                    dynamic_tag=keywords.get('dataset_id')))
    def get(self, request: Request, dataset_id: str):
        """Build the serializer payload from query params and run the hit test."""
        payload = {'id': dataset_id,
                   'user_id': request.user.id,
                   "query_text": request.query_params.get("query_text"),
                   "top_number": request.query_params.get("top_number"),
                   'similarity': request.query_params.get('similarity')}
        return result.success(DataSetSerializers.HitTest(data=payload).hit_test())
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
-- Hit test over the "embedding" table.
-- Bind order (see PGVector.hit_test): query vector (JSON), ${embedding_query}
-- params twice (the placeholder appears twice below), similarity threshold,
-- LIMIT. ${embedding_query} is a filter fragment substituted by
-- generate_sql_by_query_dict before execution.
SELECT
    similarity,
    paragraph_id,
    comprehensive_score
FROM
    (
    -- Keep only the best row per paragraph: DISTINCT ON (paragraph_id)
    -- combined with ORDER BY paragraph_id, (similarity + score) DESC
    -- selects the highest comprehensive_score for each paragraph.
    SELECT DISTINCT ON
        ( "paragraph_id" ) ( similarity + score ),*,
        ( similarity + score ) AS comprehensive_score
    FROM
        (
        SELECT
            *,
            -- Vector similarity: 1 - distance between the stored embedding and
            -- the query vector (<=> is the pgvector distance operator;
            -- presumably cosine distance — confirm the index operator class).
            ( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
            CASE

                -- Feedback score: (star_num - trample_num) min-max normalized
                -- over the filtered set; 0 when the row's own difference is 0.
                WHEN embedding.star_num - embedding.trample_num = 0 THEN
                0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
            END AS score
        FROM
            embedding,
            ( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
            ${embedding_query}
        ) TEMP
    WHERE
        similarity > %s
    ORDER BY
        paragraph_id,
        ( similarity + score )
        DESC
    ) ss
-- Final global ordering and truncation to the requested top-N.
ORDER BY
    comprehensive_score DESC
LIMIT %s
|
||||
|
|
@ -115,6 +115,11 @@ class BaseVectorStore(ABC):
|
|||
embedding: HuggingFaceEmbeddings):
|
||||
pass
|
||||
|
||||
@abstractmethod
def hit_test(self, query_text, dataset_id: list[str], top_number: int, similarity: float,
             embedding: HuggingFaceEmbeddings):
    """Search the vector store for *query_text* within the given datasets.

    :param query_text: raw query string to embed and match.
    :param dataset_id: dataset ids to restrict the search to.
    :param top_number: maximum number of hits to return.
    :param similarity: minimum similarity threshold (0..1).
    :param embedding: model used to embed the query text.
    :return: hit rows; implementations define the exact row shape
             (PGVector returns dicts with paragraph_id/similarity/comprehensive_score).
    """
    pass
|
||||
|
||||
@abstractmethod
|
||||
def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from django.db.models import QuerySet
|
|||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from common.db.search import native_search, generate_sql_by_query_dict
|
||||
from common.db.sql_execute import select_one
|
||||
from common.db.sql_execute import select_one, select_list
|
||||
from common.util.file_util import get_file_content
|
||||
from embedding.models import Embedding, SourceType
|
||||
from embedding.vector.base_vector import BaseVectorStore
|
||||
|
|
@ -68,6 +68,19 @@ class PGVector(BaseVectorStore):
|
|||
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
||||
return True
|
||||
|
||||
def hit_test(self, query_text, dataset_id_list: list[str], top_number: int, similarity: float,
             embedding: HuggingFaceEmbeddings):
    """pgvector implementation of BaseVectorStore.hit_test.

    Embeds the query, renders hit_test.sql with a dataset/active filter, and
    returns the scored rows (similarity, paragraph_id, comprehensive_score).
    """
    # Embed the query once; the vector is passed to SQL as a JSON literal.
    embedding_query = embedding.embed_query(query_text)
    # Only active embeddings in the requested datasets; this queryset is
    # rendered into the ${embedding_query} placeholder of hit_test.sql.
    query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=True)
    exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                       select_string=get_file_content(
                                                           os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
                                                                        'hit_test.sql')),
                                                       with_table_name=True)
    # exec_params is bound twice because ${embedding_query} occurs twice in the
    # SQL: once in the min/max aggregate subquery and once in the main scan.
    embedding_model = select_list(exec_sql,
                                  [json.dumps(embedding_query), *exec_params, *exec_params, similarity, top_number])
    return embedding_model
|
||||
|
||||
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_id_list: list[str],
|
||||
is_active: bool,
|
||||
|
|
|
|||
Loading…
Reference in New Issue