diff --git a/apps/knowledge/api/knowledge.py b/apps/knowledge/api/knowledge.py index 1ac00180a..d334ac3b1 100644 --- a/apps/knowledge/api/knowledge.py +++ b/apps/knowledge/api/knowledge.py @@ -244,3 +244,7 @@ class HitTestAPI(SyncWebAPI): @staticmethod def get_request(): return HitTestSerializer + + +class EmbeddingAPI(SyncWebAPI): + pass diff --git a/apps/knowledge/serializers/knowledge.py b/apps/knowledge/serializers/knowledge.py index 3477b506c..7aee5bfde 100644 --- a/apps/knowledge/serializers/knowledge.py +++ b/apps/knowledge/serializers/knowledge.py @@ -31,6 +31,7 @@ from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by from knowledge.task.generate import generate_related_by_knowledge_id from knowledge.task.sync import sync_web_knowledge, sync_replace_web_knowledge from maxkb.conf import PROJECT_DIR +from models_provider.models import Model class KnowledgeModelSerializer(serializers.ModelSerializer): @@ -80,6 +81,7 @@ class KnowledgeEditRequest(serializers.Serializer): valid_class = knowledge_meta_valid_map.get(knowledge.type) valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) + class HitTestSerializer(serializers.Serializer): query_text = serializers.CharField(required=True, label=_('query text')) top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number")) @@ -89,6 +91,7 @@ class HitTestSerializer(serializers.Serializer): message=_('The type only supports embedding|keywords|blend'), code=500) ]) + class KnowledgeSerializer(serializers.Serializer): class Query(serializers.Serializer): workspace_id = serializers.CharField(required=True) @@ -157,6 +160,36 @@ class KnowledgeSerializer(serializers.Serializer): workspace_id = serializers.CharField(required=True, label=_('workspace id')) knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) + @transaction.atomic + def embedding(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + knowledge_id = self.data.get('knowledge_id') + knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first() + embedding_model_id = knowledge.embedding_mode_id + knowledge_user_id = knowledge.user_id + embedding_model = QuerySet(Model).filter(id=embedding_model_id).first() + if embedding_model is None: + raise AppApiException(500, _('Model does not exist')) + if embedding_model.permission_type == 'PRIVATE' and knowledge_user_id != embedding_model.user_id: + raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}") + ListenerManagement.update_status( + QuerySet(Document).filter(knowledge_id=self.data.get('knowledge_id')), + TaskType.EMBEDDING, + State.PENDING + ) + ListenerManagement.update_status( + QuerySet(Paragraph).filter(knowledge_id=self.data.get('knowledge_id')), + TaskType.EMBEDDING, + State.PENDING + ) + ListenerManagement.get_aggregation_document_status_by_knowledge_id(self.data.get('knowledge_id'))() + embedding_model_id = get_embedding_model_id_by_knowledge_id(self.data.get('knowledge_id')) + try: + embedding_by_knowledge.delay(knowledge_id, embedding_model_id) + except AlreadyQueued as e: + raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) + def generate_related(self, instance: Dict, with_valid=True): if with_valid: self.is_valid(raise_exception=True) diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py index 4463b3288..12187e38a 100644 --- a/apps/knowledge/urls.py +++ b/apps/knowledge/urls.py @@ -10,6 +10,7 @@ urlpatterns = [ path('workspace//knowledge/', views.KnowledgeView.Operate.as_view()), path('workspace//knowledge//sync', views.KnowledgeView.SyncWeb.as_view()), path('workspace//knowledge//generate_related', views.KnowledgeView.GenerateRelated.as_view()), + path('workspace//knowledge//embedding', views.KnowledgeView.Embedding.as_view()), path('workspace//knowledge//hit_test', views.KnowledgeView.HitTest.as_view()), path('workspace//knowledge//document', views.DocumentView.as_view()), path('workspace//knowledge//document/split', views.DocumentView.Split.as_view()), diff --git a/apps/knowledge/views/knowledge.py b/apps/knowledge/views/knowledge.py index 7d8cde6a8..64b7f4644 100644 --- a/apps/knowledge/views/knowledge.py +++ b/apps/knowledge/views/knowledge.py @@ -8,7 +8,7 @@ from common.auth.authentication import has_permissions from common.constants.permission_constants import PermissionConstants from common.result import result from knowledge.api.knowledge import KnowledgeBaseCreateAPI, KnowledgeWebCreateAPI, KnowledgeTreeReadAPI, \ - KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI + KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI, EmbeddingAPI from knowledge.serializers.knowledge import KnowledgeSerializer @@ -160,6 +160,24 @@ class KnowledgeView(APIView): } ).hit_test()) + class Embedding(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['PUT'], + summary=_('Re-vectorize'), + description=_('Re-vectorize'), + operation_id=_('Re-vectorize'), + parameters=EmbeddingAPI.get_parameters(), + responses=EmbeddingAPI.get_response(), + tags=[_('Knowledge Base')] + ) + @has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission()) + def put(self, request: Request, workspace_id: str, knowledge_id: str): + return result.success(KnowledgeSerializer.Operate( + data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id, 'user_id': request.user.id} + ).embedding()) + class GenerateRelated(APIView): authentication_classes = [TokenAuth]