diff --git a/apps/common/swagger_api/common_api.py b/apps/common/swagger_api/common_api.py index 3134db0d0..9e7d19762 100644 --- a/apps/common/swagger_api/common_api.py +++ b/apps/common/swagger_api/common_api.py @@ -15,33 +15,21 @@ from django.utils.translation import gettext_lazy as _ class CommonApi: class HitTestApi(ApiMixin): @staticmethod - def get_request_params_api(): - return [ - openapi.Parameter(name='query_text', - in_=openapi.IN_QUERY, - type=openapi.TYPE_STRING, - required=True, - description=_('query text')), - openapi.Parameter(name='top_number', - in_=openapi.IN_QUERY, - type=openapi.TYPE_NUMBER, - default=10, - required=True, - description='topN'), - openapi.Parameter(name='similarity', - in_=openapi.IN_QUERY, - type=openapi.TYPE_NUMBER, - default=0.6, - required=True, - description=_('similarity')), - openapi.Parameter(name='search_mode', - in_=openapi.IN_QUERY, - type=openapi.TYPE_STRING, - default="embedding", - required=True, - description=_('Retrieval pattern embedding|keywords|blend') - ) - ] + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['query_text', 'top_number', 'similarity', 'search_mode'], + properties={ + 'query_text': openapi.Schema(type=openapi.TYPE_STRING, title=_('query text'), + description=_('query text')), + 'top_number': openapi.Schema(type=openapi.TYPE_NUMBER, title=_('top number'), + description=_('top number')), + 'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title=_('similarity'), + description=_('similarity')), + 'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title=_('search mode'), + description=_('search mode')) + } + ) @staticmethod def get_response_body_api(): diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index ad28bc198..aeb1af289 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -142,7 +142,7 @@ class Dataset(APIView): @action(methods="PUT", detail=False) @swagger_auto_schema(operation_summary=_('Hit test list'), operation_id=_('Hit test list'), - manual_parameters=CommonApi.HitTestApi.get_request_params_api(), + request_body=CommonApi.HitTestApi.get_request_body_api(), responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()), tags=[_('Knowledge Base')]) @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,