diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index b6e43a412..efba6d302 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -46,6 +46,11 @@ from smartdoc.conf import PROJECT_DIR chat_cache = caches['model_cache'] +class WorkFlowSerializers(serializers.Serializer): + nodes = serializers.ListSerializer(child=serializers.DictField(), error_messages=ErrMessage.uuid("节点")) + edges = serializers.ListSerializer(child=serializers.DictField(), error_messages=ErrMessage.uuid("连线")) + + class ChatSerializers(serializers.Serializer): class Operate(serializers.Serializer): chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) @@ -251,6 +256,26 @@ class ChatSerializers(serializers.Serializer): application), timeout=60 * 30) return chat_id + class OpenWorkFlowChat(serializers.Serializer): + work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流")) + + def open(self): + self.is_valid(raise_exception=True) + work_flow = self.data.get('work_flow') + chat_id = str(uuid.uuid1()) + application = Application(id=None, dialogue_number=3, model=None, + dataset_setting={}, + model_setting={}, + problem_optimization=None, + type=ApplicationTypeChoices.WORK_FLOW + ) + work_flow_version = WorkFlowVersion(work_flow=work_flow) + chat_cache.set(chat_id, + ChatInfo(chat_id, None, [], + [], + application, work_flow_version), timeout=60 * 30) + return chat_id + class OpenTempChat(serializers.Serializer): user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) @@ -351,7 +376,7 @@ class ChatRecordSerializer(serializers.Serializer): chat_info: ChatInfo = chat_cache.get(chat_id) if chat_info is not None: chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if - chat_record.id == uuid.UUID(chat_record_id)] + str(chat_record.id) == str(chat_record_id)] if chat_record_list is not None and len(chat_record_list): return chat_record_list[-1] return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first() diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index 9c56cd21e..2ff8f8ac5 100644 --- a/apps/application/swagger_api/chat_api.py +++ b/apps/application/swagger_api/chat_api.py @@ -82,6 +82,17 @@ class ChatApi(ApiMixin): ] + class OpenWorkFlowTemp(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=[], + properties={ + 'work_flow': ApplicationApi.WorkFlow.get_request_body_api() + } + ) + class OpenTempChat(ApiMixin): @staticmethod def get_request_body_api(): diff --git a/apps/application/urls.py b/apps/application/urls.py index a39f0d47e..335205d37 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -31,6 +31,7 @@ urlpatterns = [ path('application//', views.Application.Page.as_view(), name='application_page'), path('application//chat/open', views.ChatView.Open.as_view(), name='application/open'), path("application/chat/open", views.ChatView.OpenTemp.as_view()), + path("application/chat_workflow/open", views.ChatView.OpenWorkFlowTemp.as_view()), path("application//chat/client//", views.ChatView.ClientChatHistoryPage.as_view()), path("application//chat/client/", diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index b7f5968f2..577c08838 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -64,6 +64,18 @@ class ChatView(APIView): return result.success(ChatSerializers.OpenChat( data={'user_id': request.user.id, 'application_id': application_id}).open()) + class OpenWorkFlowTemp(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="获取工作流临时会话id", + operation_id="获取工作流临时会话id", + request_body=ChatApi.OpenWorkFlowTemp.get_request_body_api(), + tags=["应用/会话"]) + def post(self, request: Request): + return result.success(ChatSerializers.OpenWorkFlowChat( + data={**request.data, 'user_id': request.user.id}).open()) + class OpenTemp(APIView): authentication_classes = [TokenAuth]