feat: add Embedding API for re-vectorization of knowledge objects

This commit is contained in:
CaptainB 2025-05-08 17:17:09 +08:00
parent b6f5a8378a
commit 54836d549c
4 changed files with 57 additions and 1 deletions

View File

@ -244,3 +244,7 @@ class HitTestAPI(SyncWebAPI):
@staticmethod
def get_request():
    """Return the serializer class that describes this API's request."""
    # The class object itself is returned, not an instance.
    return HitTestSerializer
class EmbeddingAPI(SyncWebAPI):
    """API description for the knowledge re-vectorization endpoint.

    Adds nothing of its own: every attribute and method is inherited
    unchanged from ``SyncWebAPI``.
    """

View File

@ -31,6 +31,7 @@ from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by
from knowledge.task.generate import generate_related_by_knowledge_id
from knowledge.task.sync import sync_web_knowledge, sync_replace_web_knowledge
from maxkb.conf import PROJECT_DIR
from models_provider.models import Model
class KnowledgeModelSerializer(serializers.ModelSerializer):
@ -80,6 +81,7 @@ class KnowledgeEditRequest(serializers.Serializer):
valid_class = knowledge_meta_valid_map.get(knowledge.type)
valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
class HitTestSerializer(serializers.Serializer):
query_text = serializers.CharField(required=True, label=_('query text'))
top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number"))
@ -89,6 +91,7 @@ class HitTestSerializer(serializers.Serializer):
message=_('The type only supports embedding|keywords|blend'), code=500)
])
class KnowledgeSerializer(serializers.Serializer):
class Query(serializers.Serializer):
workspace_id = serializers.CharField(required=True)
@ -157,6 +160,36 @@ class KnowledgeSerializer(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
@transaction.atomic
def embedding(self, with_valid=True):
    """Queue a re-vectorization (embedding) task for a knowledge base.

    Steps, in order:
      1. Optionally validate the serializer data.
      2. Load the knowledge base and the embedding model it references.
      3. Reject the request if the model is missing, or if it is PRIVATE
         and owned by a user other than the knowledge base owner.
      4. Mark every Document and Paragraph of the knowledge base as
         PENDING for the EMBEDDING task type and refresh the aggregated
         document status.
      5. Dispatch the ``embedding_by_knowledge`` task.

    :param with_valid: when true, run ``self.is_valid(raise_exception=True)``
        before doing anything else.
    :raises AppApiException: if the knowledge base or model does not exist,
        the model may not be used, or the task is already queued.
    """
    if with_valid:
        self.is_valid(raise_exception=True)

    knowledge_id = self.data.get('knowledge_id')
    knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
    # .first() yields None for an unknown id; fail with an API error rather
    # than an AttributeError on the attribute access below.
    if knowledge is None:
        raise AppApiException(500, _('Knowledge base does not exist'))

    # NOTE(review): attribute is spelled `embedding_mode_id` (not
    # `embedding_model_id`) -- confirm against the Knowledge model definition.
    embedding_model_id = knowledge.embedding_mode_id
    embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
    if embedding_model is None:
        raise AppApiException(500, _('Model does not exist'))
    if embedding_model.permission_type == 'PRIVATE' and knowledge.user_id != embedding_model.user_id:
        raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")

    # Reset the embedding status of documents first, then paragraphs.
    for model_class in (Document, Paragraph):
        ListenerManagement.update_status(
            QuerySet(model_class).filter(knowledge_id=knowledge_id),
            TaskType.EMBEDDING,
            State.PENDING
        )
    # The helper returns a callable; invoking it performs the aggregation.
    ListenerManagement.get_aggregation_document_status_by_knowledge_id(knowledge_id)()

    # The model id handed to the task comes from the dedicated helper, as in
    # the original code, rather than from the attribute read above.
    embedding_model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
    try:
        embedding_by_knowledge.delay(knowledge_id, embedding_model_id)
    except AlreadyQueued as e:
        # Keep the original exception as the cause for easier debugging.
        raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) from e
def generate_related(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)

View File

@ -10,6 +10,7 @@ urlpatterns = [
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/sync', views.KnowledgeView.SyncWeb.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/generate_related', views.KnowledgeView.GenerateRelated.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/embedding', views.KnowledgeView.Embedding.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/hit_test', views.KnowledgeView.HitTest.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),

View File

@ -8,7 +8,7 @@ from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants
from common.result import result
from knowledge.api.knowledge import KnowledgeBaseCreateAPI, KnowledgeWebCreateAPI, KnowledgeTreeReadAPI, \
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI, EmbeddingAPI
from knowledge.serializers.knowledge import KnowledgeSerializer
@ -160,6 +160,24 @@ class KnowledgeView(APIView):
}
).hit_test())
class Embedding(APIView):
    """View exposing the re-vectorize operation for one knowledge base."""
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['PUT'],
        summary=_('Re-vectorize'),
        description=_('Re-vectorize'),
        operation_id=_('Re-vectorize'),
        parameters=EmbeddingAPI.get_parameters(),
        responses=EmbeddingAPI.get_response(),
        tags=[_('Knowledge Base')]
    )
    @has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str):
        """Trigger re-vectorization of ``knowledge_id`` and wrap the outcome.

        The path parameters and the requesting user's id are handed to
        ``KnowledgeSerializer.Operate``; whatever its ``embedding()`` method
        returns is wrapped with ``result.success``.
        """
        payload = {
            'knowledge_id': knowledge_id,
            'workspace_id': workspace_id,
            'user_id': request.user.id,
        }
        operate_serializer = KnowledgeSerializer.Operate(data=payload)
        outcome = operate_serializer.embedding()
        return result.success(outcome)
class GenerateRelated(APIView):
authentication_classes = [TokenAuth]