diff --git a/apps/knowledge/api/knowledge.py b/apps/knowledge/api/knowledge.py index 077595a4f..1ac00180a 100644 --- a/apps/knowledge/api/knowledge.py +++ b/apps/knowledge/api/knowledge.py @@ -5,7 +5,7 @@ from common.mixins.api_mixin import APIMixin from common.result import ResultSerializer, DefaultResultSerializer from knowledge.serializers.common import GenerateRelatedSerializer from knowledge.serializers.knowledge import KnowledgeBaseCreateRequest, KnowledgeModelSerializer, KnowledgeEditRequest, \ - KnowledgeWebCreateRequest + KnowledgeWebCreateRequest, HitTestSerializer class KnowledgeCreateResponse(ResultSerializer): @@ -238,3 +238,9 @@ class GenerateRelatedAPI(SyncWebAPI): @staticmethod def get_request(): return GenerateRelatedSerializer + + +class HitTestAPI(SyncWebAPI): + @staticmethod + def get_request(): + return HitTestSerializer diff --git a/apps/knowledge/serializers/knowledge.py b/apps/knowledge/serializers/knowledge.py index 630ba1eb7..3477b506c 100644 --- a/apps/knowledge/serializers/knowledge.py +++ b/apps/knowledge/serializers/knowledge.py @@ -14,6 +14,7 @@ from django.db.models.functions import Reverse, Substr from django.utils.translation import gettext_lazy as _ from rest_framework import serializers +from common.config.embedding_config import VectorStore from common.db.search import native_search, get_dynamics_model, native_page_search from common.db.sql_execute import select_list from common.event import ListenerManagement @@ -22,9 +23,9 @@ from common.utils.common import valid_license, post, get_file_content from common.utils.fork import Fork, ChildLink from common.utils.split_model import get_split_model from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \ - ProblemParagraphMapping, ApplicationKnowledgeMapping, TaskType, State + ProblemParagraphMapping, ApplicationKnowledgeMapping, TaskType, State, SearchMode from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id, MetaSerializer, \ - GenerateRelatedSerializer + GenerateRelatedSerializer, get_embedding_model_by_knowledge_id, list_paragraph from knowledge.serializers.document import DocumentSerializers from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by_knowledge from knowledge.task.generate import generate_related_by_knowledge_id @@ -79,6 +80,14 @@ class KnowledgeEditRequest(serializers.Serializer): valid_class = knowledge_meta_valid_map.get(knowledge.type) valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) +class HitTestSerializer(serializers.Serializer): + query_text = serializers.CharField(required=True, label=_('query text')) + top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number")) + similarity = serializers.FloatField(required=True, max_value=2, min_value=0, label=_('similarity')) + search_mode = serializers.CharField(required=True, label=_('search mode'), validators=[ + validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), + message=_('The type only supports embedding|keywords|blend'), code=500) + ]) class KnowledgeSerializer(serializers.Serializer): class Query(serializers.Serializer): @@ -152,7 +161,7 @@ class KnowledgeSerializer(serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) GenerateRelatedSerializer(data=instance).is_valid(raise_exception=True) - knowledge_id = self.data.get('id') + knowledge_id = self.data.get('knowledge_id') model_id = instance.get("model_id") prompt = instance.get("prompt") state_list = instance.get('state_list') @@ -382,7 +391,8 @@ class KnowledgeSerializer(serializers.Serializer): return {**KnowledgeModelSerializer(knowledge).data, 'document_list': []} class SyncWeb(serializers.Serializer): - id = serializers.CharField(required=True, label=_('knowledge id')) + workspace_id = serializers.CharField(required=True, label=_('workspace id')) + knowledge_id = serializers.CharField(required=True, label=_('knowledge id')) user_id = serializers.UUIDField(required=False, label=_('user id')) sync_type = serializers.CharField(required=True, label=_('sync type'), validators=[ validators.RegexValidator(regex=re.compile("^replace|complete$"), @@ -390,7 +400,7 @@ class KnowledgeSerializer(serializers.Serializer): def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) - first = QuerySet(Knowledge).filter(id=self.data.get("id")).first() + first = QuerySet(Knowledge).filter(id=self.data.get("knowledge_id")).first() if first is None: raise AppApiException(300, _('id does not exist')) if first.type != KnowledgeType.WEB: @@ -400,7 +410,7 @@ class KnowledgeSerializer(serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) sync_type = self.data.get('sync_type') - knowledge_id = self.data.get('id') + knowledge_id = self.data.get('knowledge_id') knowledge = QuerySet(Knowledge).get(id=knowledge_id) self.__getattribute__(sync_type + '_sync')(knowledge) return True @@ -454,6 +464,52 @@ class KnowledgeSerializer(serializers.Serializer): # 删除段落 QuerySet(Paragraph).filter(knowledge=knowledge).delete() # 删除向量 - delete_embedding_by_knowledge(self.data.get('id')) + delete_embedding_by_knowledge(self.data.get('knowledge_id')) # 同步 self.replace_sync(knowledge) + + class HitTest(serializers.Serializer): + workspace_id = serializers.CharField(required=True, label=_('workspace id')) + knowledge_id = serializers.UUIDField(required=True, label=_("id")) + user_id = serializers.UUIDField(required=False, label=_('user id')) + query_text = serializers.CharField(required=True, label=_('query text')) + top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number")) + similarity = serializers.FloatField(required=True, max_value=2, min_value=0, label=_('similarity')) + search_mode = serializers.CharField(required=True, label=_('search mode'), validators=[ + validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), + message=_('The type only supports embedding|keywords|blend'), code=500) + ]) + + def is_valid(self, *, raise_exception=True): + super().is_valid(raise_exception=True) + if not QuerySet(Knowledge).filter(id=self.data.get("knowledge_id")).exists(): + raise AppApiException(300, _('id does not exist')) + + def hit_test(self): + self.is_valid() + vector = VectorStore.get_embedding_vector() + exclude_document_id_list = [ + str( + document.id + ) for document in QuerySet(Document).filter(knowledge_id=self.data.get('knowledge_id'), is_active=False) + ] + model = get_embedding_model_by_knowledge_id(self.data.get('knowledge_id')) + # 向量库检索 + hit_list = vector.hit_test( + self.data.get('query_text'), + [self.data.get('knowledge_id')], + exclude_document_id_list, + self.data.get('top_number'), + self.data.get('similarity'), + SearchMode(self.data.get('search_mode')), + model + ) + hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) + p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) + return [ + { + **p, + 'similarity': hit_dict.get(p.get('id')).get('similarity'), + 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score') + } for p in p_list + ] diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py index 09e87177a..4463b3288 100644 --- a/apps/knowledge/urls.py +++ b/apps/knowledge/urls.py @@ -10,6 +10,7 @@ urlpatterns = [ path('workspace//knowledge/', views.KnowledgeView.Operate.as_view()), path('workspace//knowledge//sync', views.KnowledgeView.SyncWeb.as_view()), path('workspace//knowledge//generate_related', views.KnowledgeView.GenerateRelated.as_view()), + path('workspace//knowledge//hit_test', views.KnowledgeView.HitTest.as_view()), path('workspace//knowledge//document', views.DocumentView.as_view()), path('workspace//knowledge//document/split', views.DocumentView.Split.as_view()), path('workspace//knowledge//document/split_pattern', views.DocumentView.SplitPattern.as_view()), diff --git a/apps/knowledge/views/knowledge.py b/apps/knowledge/views/knowledge.py index f53473518..7d8cde6a8 100644 --- a/apps/knowledge/views/knowledge.py +++ b/apps/knowledge/views/knowledge.py @@ -8,7 +8,7 @@ from common.auth.authentication import has_permissions from common.constants.permission_constants import PermissionConstants from common.result import result from knowledge.api.knowledge import KnowledgeBaseCreateAPI, KnowledgeWebCreateAPI, KnowledgeTreeReadAPI, \ - KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI + KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI from knowledge.serializers.knowledge import KnowledgeSerializer @@ -128,11 +128,38 @@ class KnowledgeView(APIView): data={ 'workspace_id': workspace_id, 'sync_type': request.query_params.get('sync_type'), - 'id': knowledge_id, + 'knowledge_id': knowledge_id, 'user_id': str(request.user.id) } ).sync()) + class HitTest(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['PUT'], + summary=_('Hit test list'), + description=_('Hit test list'), + operation_id=_('Hit test list'), + parameters=HitTestAPI.get_parameters(), + request=HitTestAPI.get_request(), + responses=HitTestAPI.get_response(), + tags=[_('Knowledge Base')] + ) + @has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission()) + def put(self, request: Request, workspace_id: str, knowledge_id: str): + return result.success(KnowledgeSerializer.HitTest( + data={ + 'workspace_id': workspace_id, + 'knowledge_id': knowledge_id, + 'user_id': request.user.id, + "query_text": request.query_params.get("query_text"), + "top_number": request.query_params.get("top_number"), + 'similarity': request.query_params.get('similarity'), + 'search_mode': request.query_params.get('search_mode') + } + ).hit_test()) + class GenerateRelated(APIView): authentication_classes = [TokenAuth]