From a106ff09d011ec344938e74ca2a71dfa23d82961 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Wed, 6 Mar 2024 18:11:11 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=89=93=E5=BC=80=E4=B8=B4=E6=97=B6?= =?UTF-8?q?=E4=BC=9A=E8=AF=9D=E6=97=B6,=E9=9C=80=E8=A6=81=E5=8C=BA?= =?UTF-8?q?=E5=88=86=E5=BA=94=E7=94=A8=E7=94=A8=E6=88=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/chat_serializers.py | 18 +++++++++++++++--- apps/application/swagger_api/chat_api.py | 2 ++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 3c30da72c..d48e4c899 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -158,6 +158,8 @@ class ChatSerializers(serializers.Serializer): class OpenTempChat(serializers.Serializer): user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + id = serializers.UUIDField(required=False, allow_null=True, + error_messages=ErrMessage.uuid("应用id")) model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) multiple_rounds_dialogue = serializers.BooleanField(required=True, @@ -174,14 +176,24 @@ class ChatSerializers(serializers.Serializer): def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) + user_id = self.get_user_id() ModelDatasetAssociation( - data={'user_id': self.data.get('user_id'), 'model_id': self.data.get('model_id'), + data={'user_id': user_id, 'model_id': self.data.get('model_id'), 'dataset_id_list': self.data.get('dataset_id_list')}).is_valid() + return user_id + + def get_user_id(self): + if 'id' in self.data and self.data.get('id') is not None: + application = QuerySet(Application).filter(id=self.data.get('id')).first() + if application is None: + raise AppApiException(500, "应用不存在") + return application.user_id + return self.data.get('user_id') def open(self): - self.is_valid(raise_exception=True) + user_id = self.is_valid(raise_exception=True) chat_id = str(uuid.uuid1()) - model = QuerySet(Model).filter(user_id=self.data.get('user_id'), id=self.data.get('model_id')).first() + model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() if model is None: raise AppApiException(500, "模型不存在") dataset_id_list = self.data.get('dataset_id_list') diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index 8ae2b64cd..29f60c3d0 100644 --- a/apps/application/swagger_api/chat_api.py +++ b/apps/application/swagger_api/chat_api.py @@ -79,6 +79,8 @@ class ChatApi(ApiMixin): required=['model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting', 'problem_optimization'], properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="应用id", + description="应用id,修改的时候传,创建的时候不传"), 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),