mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-27 20:42:52 +00:00
feat: add HitTest API for knowledge base query testing and update SyncWeb API to use knowledge_id
This commit is contained in:
parent
7dcd1a71e8
commit
b6f5a8378a
|
|
@ -5,7 +5,7 @@ from common.mixins.api_mixin import APIMixin
|
|||
from common.result import ResultSerializer, DefaultResultSerializer
|
||||
from knowledge.serializers.common import GenerateRelatedSerializer
|
||||
from knowledge.serializers.knowledge import KnowledgeBaseCreateRequest, KnowledgeModelSerializer, KnowledgeEditRequest, \
|
||||
KnowledgeWebCreateRequest
|
||||
KnowledgeWebCreateRequest, HitTestSerializer
|
||||
|
||||
|
||||
class KnowledgeCreateResponse(ResultSerializer):
|
||||
|
|
@ -238,3 +238,9 @@ class GenerateRelatedAPI(SyncWebAPI):
|
|||
@staticmethod
|
||||
def get_request():
|
||||
return GenerateRelatedSerializer
|
||||
|
||||
|
||||
class HitTestAPI(SyncWebAPI):
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return HitTestSerializer
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from django.db.models.functions import Reverse, Substr
|
|||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.config.embedding_config import VectorStore
|
||||
from common.db.search import native_search, get_dynamics_model, native_page_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.event import ListenerManagement
|
||||
|
|
@ -22,9 +23,9 @@ from common.utils.common import valid_license, post, get_file_content
|
|||
from common.utils.fork import Fork, ChildLink
|
||||
from common.utils.split_model import get_split_model
|
||||
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \
|
||||
ProblemParagraphMapping, ApplicationKnowledgeMapping, TaskType, State
|
||||
ProblemParagraphMapping, ApplicationKnowledgeMapping, TaskType, State, SearchMode
|
||||
from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id, MetaSerializer, \
|
||||
GenerateRelatedSerializer
|
||||
GenerateRelatedSerializer, get_embedding_model_by_knowledge_id, list_paragraph
|
||||
from knowledge.serializers.document import DocumentSerializers
|
||||
from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by_knowledge
|
||||
from knowledge.task.generate import generate_related_by_knowledge_id
|
||||
|
|
@ -79,6 +80,14 @@ class KnowledgeEditRequest(serializers.Serializer):
|
|||
valid_class = knowledge_meta_valid_map.get(knowledge.type)
|
||||
valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
|
||||
|
||||
class HitTestSerializer(serializers.Serializer):
|
||||
query_text = serializers.CharField(required=True, label=_('query text'))
|
||||
top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number"))
|
||||
similarity = serializers.FloatField(required=True, max_value=2, min_value=0, label=_('similarity'))
|
||||
search_mode = serializers.CharField(required=True, label=_('search mode'), validators=[
|
||||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||
message=_('The type only supports embedding|keywords|blend'), code=500)
|
||||
])
|
||||
|
||||
class KnowledgeSerializer(serializers.Serializer):
|
||||
class Query(serializers.Serializer):
|
||||
|
|
@ -152,7 +161,7 @@ class KnowledgeSerializer(serializers.Serializer):
|
|||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
GenerateRelatedSerializer(data=instance).is_valid(raise_exception=True)
|
||||
knowledge_id = self.data.get('id')
|
||||
knowledge_id = self.data.get('knowledge_id')
|
||||
model_id = instance.get("model_id")
|
||||
prompt = instance.get("prompt")
|
||||
state_list = instance.get('state_list')
|
||||
|
|
@ -382,7 +391,8 @@ class KnowledgeSerializer(serializers.Serializer):
|
|||
return {**KnowledgeModelSerializer(knowledge).data, 'document_list': []}
|
||||
|
||||
class SyncWeb(serializers.Serializer):
|
||||
id = serializers.CharField(required=True, label=_('knowledge id'))
|
||||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||||
knowledge_id = serializers.CharField(required=True, label=_('knowledge id'))
|
||||
user_id = serializers.UUIDField(required=False, label=_('user id'))
|
||||
sync_type = serializers.CharField(required=True, label=_('sync type'), validators=[
|
||||
validators.RegexValidator(regex=re.compile("^replace|complete$"),
|
||||
|
|
@ -390,7 +400,7 @@ class KnowledgeSerializer(serializers.Serializer):
|
|||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
first = QuerySet(Knowledge).filter(id=self.data.get("id")).first()
|
||||
first = QuerySet(Knowledge).filter(id=self.data.get("knowledge_id")).first()
|
||||
if first is None:
|
||||
raise AppApiException(300, _('id does not exist'))
|
||||
if first.type != KnowledgeType.WEB:
|
||||
|
|
@ -400,7 +410,7 @@ class KnowledgeSerializer(serializers.Serializer):
|
|||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
sync_type = self.data.get('sync_type')
|
||||
knowledge_id = self.data.get('id')
|
||||
knowledge_id = self.data.get('knowledge_id')
|
||||
knowledge = QuerySet(Knowledge).get(id=knowledge_id)
|
||||
self.__getattribute__(sync_type + '_sync')(knowledge)
|
||||
return True
|
||||
|
|
@ -454,6 +464,52 @@ class KnowledgeSerializer(serializers.Serializer):
|
|||
# 删除段落
|
||||
QuerySet(Paragraph).filter(knowledge=knowledge).delete()
|
||||
# 删除向量
|
||||
delete_embedding_by_knowledge(self.data.get('id'))
|
||||
delete_embedding_by_knowledge(self.data.get('knowledge_id'))
|
||||
# 同步
|
||||
self.replace_sync(knowledge)
|
||||
|
||||
class HitTest(serializers.Serializer):
|
||||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||||
knowledge_id = serializers.UUIDField(required=True, label=_("id"))
|
||||
user_id = serializers.UUIDField(required=False, label=_('user id'))
|
||||
query_text = serializers.CharField(required=True, label=_('query text'))
|
||||
top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number"))
|
||||
similarity = serializers.FloatField(required=True, max_value=2, min_value=0, label=_('similarity'))
|
||||
search_mode = serializers.CharField(required=True, label=_('search mode'), validators=[
|
||||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||
message=_('The type only supports embedding|keywords|blend'), code=500)
|
||||
])
|
||||
|
||||
def is_valid(self, *, raise_exception=True):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(Knowledge).filter(id=self.data.get("knowledge_id")).exists():
|
||||
raise AppApiException(300, _('id does not exist'))
|
||||
|
||||
def hit_test(self):
|
||||
self.is_valid()
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
exclude_document_id_list = [
|
||||
str(
|
||||
document.id
|
||||
) for document in QuerySet(Document).filter(knowledge_id=self.data.get('knowledge_id'), is_active=False)
|
||||
]
|
||||
model = get_embedding_model_by_knowledge_id(self.data.get('knowledge_id'))
|
||||
# 向量库检索
|
||||
hit_list = vector.hit_test(
|
||||
self.data.get('query_text'),
|
||||
[self.data.get('knowledge_id')],
|
||||
exclude_document_id_list,
|
||||
self.data.get('top_number'),
|
||||
self.data.get('similarity'),
|
||||
SearchMode(self.data.get('search_mode')),
|
||||
model
|
||||
)
|
||||
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
||||
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
||||
return [
|
||||
{
|
||||
**p,
|
||||
'similarity': hit_dict.get(p.get('id')).get('similarity'),
|
||||
'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')
|
||||
} for p in p_list
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ urlpatterns = [
|
|||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/sync', views.KnowledgeView.SyncWeb.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/generate_related', views.KnowledgeView.GenerateRelated.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/hit_test', views.KnowledgeView.HitTest.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split_pattern', views.DocumentView.SplitPattern.as_view()),
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from common.auth.authentication import has_permissions
|
|||
from common.constants.permission_constants import PermissionConstants
|
||||
from common.result import result
|
||||
from knowledge.api.knowledge import KnowledgeBaseCreateAPI, KnowledgeWebCreateAPI, KnowledgeTreeReadAPI, \
|
||||
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI
|
||||
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI
|
||||
from knowledge.serializers.knowledge import KnowledgeSerializer
|
||||
|
||||
|
||||
|
|
@ -128,11 +128,38 @@ class KnowledgeView(APIView):
|
|||
data={
|
||||
'workspace_id': workspace_id,
|
||||
'sync_type': request.query_params.get('sync_type'),
|
||||
'id': knowledge_id,
|
||||
'knowledge_id': knowledge_id,
|
||||
'user_id': str(request.user.id)
|
||||
}
|
||||
).sync())
|
||||
|
||||
class HitTest(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@extend_schema(
|
||||
methods=['PUT'],
|
||||
summary=_('Hit test list'),
|
||||
description=_('Hit test list'),
|
||||
operation_id=_('Hit test list'),
|
||||
parameters=HitTestAPI.get_parameters(),
|
||||
request=HitTestAPI.get_request(),
|
||||
responses=HitTestAPI.get_response(),
|
||||
tags=[_('Knowledge Base')]
|
||||
)
|
||||
@has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission())
|
||||
def put(self, request: Request, workspace_id: str, knowledge_id: str):
|
||||
return result.success(KnowledgeSerializer.HitTest(
|
||||
data={
|
||||
'workspace_id': workspace_id,
|
||||
'knowledge_id': knowledge_id,
|
||||
'user_id': request.user.id,
|
||||
"query_text": request.query_params.get("query_text"),
|
||||
"top_number": request.query_params.get("top_number"),
|
||||
'similarity': request.query_params.get('similarity'),
|
||||
'search_mode': request.query_params.get('search_mode')
|
||||
}
|
||||
).hit_test())
|
||||
|
||||
class GenerateRelated(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue