mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 10:02:46 +00:00
feat: add Embedding API for re-vectorization of knowledge objects
This commit is contained in:
parent
b6f5a8378a
commit
54836d549c
|
|
@ -244,3 +244,7 @@ class HitTestAPI(SyncWebAPI):
|
|||
@staticmethod
|
||||
def get_request():
|
||||
return HitTestSerializer
|
||||
|
||||
|
||||
class EmbeddingAPI(SyncWebAPI):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by
|
|||
from knowledge.task.generate import generate_related_by_knowledge_id
|
||||
from knowledge.task.sync import sync_web_knowledge, sync_replace_web_knowledge
|
||||
from maxkb.conf import PROJECT_DIR
|
||||
from models_provider.models import Model
|
||||
|
||||
|
||||
class KnowledgeModelSerializer(serializers.ModelSerializer):
|
||||
|
|
@ -80,6 +81,7 @@ class KnowledgeEditRequest(serializers.Serializer):
|
|||
valid_class = knowledge_meta_valid_map.get(knowledge.type)
|
||||
valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
|
||||
|
||||
|
||||
class HitTestSerializer(serializers.Serializer):
|
||||
query_text = serializers.CharField(required=True, label=_('query text'))
|
||||
top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number"))
|
||||
|
|
@ -89,6 +91,7 @@ class HitTestSerializer(serializers.Serializer):
|
|||
message=_('The type only supports embedding|keywords|blend'), code=500)
|
||||
])
|
||||
|
||||
|
||||
class KnowledgeSerializer(serializers.Serializer):
|
||||
class Query(serializers.Serializer):
|
||||
workspace_id = serializers.CharField(required=True)
|
||||
|
|
@ -157,6 +160,36 @@ class KnowledgeSerializer(serializers.Serializer):
|
|||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||||
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
|
||||
|
||||
@transaction.atomic
|
||||
def embedding(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
knowledge_id = self.data.get('knowledge_id')
|
||||
knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
|
||||
embedding_model_id = knowledge.embedding_mode_id
|
||||
knowledge_user_id = knowledge.user_id
|
||||
embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
|
||||
if embedding_model is None:
|
||||
raise AppApiException(500, _('Model does not exist'))
|
||||
if embedding_model.permission_type == 'PRIVATE' and knowledge_user_id != embedding_model.user_id:
|
||||
raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
|
||||
ListenerManagement.update_status(
|
||||
QuerySet(Document).filter(knowledge_id=self.data.get('knowledge_id')),
|
||||
TaskType.EMBEDDING,
|
||||
State.PENDING
|
||||
)
|
||||
ListenerManagement.update_status(
|
||||
QuerySet(Paragraph).filter(knowledge_id=self.data.get('knowledge_id')),
|
||||
TaskType.EMBEDDING,
|
||||
State.PENDING
|
||||
)
|
||||
ListenerManagement.get_aggregation_document_status_by_knowledge_id(self.data.get('knowledge_id'))()
|
||||
embedding_model_id = get_embedding_model_id_by_knowledge_id(self.data.get('knowledge_id'))
|
||||
try:
|
||||
embedding_by_knowledge.delay(knowledge_id, embedding_model_id)
|
||||
except AlreadyQueued as e:
|
||||
raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))
|
||||
|
||||
def generate_related(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ urlpatterns = [
|
|||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/sync', views.KnowledgeView.SyncWeb.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/generate_related', views.KnowledgeView.GenerateRelated.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/embedding', views.KnowledgeView.Embedding.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/hit_test', views.KnowledgeView.HitTest.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from common.auth.authentication import has_permissions
|
|||
from common.constants.permission_constants import PermissionConstants
|
||||
from common.result import result
|
||||
from knowledge.api.knowledge import KnowledgeBaseCreateAPI, KnowledgeWebCreateAPI, KnowledgeTreeReadAPI, \
|
||||
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI
|
||||
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI, EmbeddingAPI
|
||||
from knowledge.serializers.knowledge import KnowledgeSerializer
|
||||
|
||||
|
||||
|
|
@ -160,6 +160,24 @@ class KnowledgeView(APIView):
|
|||
}
|
||||
).hit_test())
|
||||
|
||||
class Embedding(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@extend_schema(
|
||||
methods=['PUT'],
|
||||
summary=_('Re-vectorize'),
|
||||
description=_('Re-vectorize'),
|
||||
operation_id=_('Re-vectorize'),
|
||||
parameters=EmbeddingAPI.get_parameters(),
|
||||
responses=EmbeddingAPI.get_response(),
|
||||
tags=[_('Knowledge Base')]
|
||||
)
|
||||
@has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission())
|
||||
def put(self, request: Request, workspace_id: str, knowledge_id: str):
|
||||
return result.success(KnowledgeSerializer.Operate(
|
||||
data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id, 'user_id': request.user.id}
|
||||
).embedding())
|
||||
|
||||
class GenerateRelated(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue