diff --git a/apps/application/urls.py b/apps/application/urls.py index b49e9dca4..fb6a8ac3d 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -34,7 +34,7 @@ urlpatterns = [ path('workspace//application//speech_to_text', views.SpeechToText.as_view()), path('workspace//application//play_demo_text', views.PlayDemoText.as_view()), path('workspace//application//mcp_tools', views.McpServers.as_view()), - path('workspace//application/model//prompt_generate', views.PromptGenerateView.as_view()), + path('workspace//application//model//prompt_generate', views.PromptGenerateView.as_view()), path('chat_message/', views.ChatView.as_view()), ] diff --git a/apps/application/views/application_chat.py b/apps/application/views/application_chat.py index c05bb65f3..c9175bebb 100644 --- a/apps/application/views/application_chat.py +++ b/apps/application/views/application_chat.py @@ -7,6 +7,7 @@ @desc: """ import uuid_utils.compat as uuid +from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ from drf_spectacular.utils import extend_schema @@ -15,7 +16,7 @@ from rest_framework.views import APIView from application.api.application_chat import ApplicationChatQueryAPI, ApplicationChatQueryPageAPI, \ ApplicationChatExportAPI -from application.models import ChatUserType +from application.models import ChatUserType, Application from application.serializers.application_chat import ApplicationChatQuerySerializers from chat.api.chat_api import ChatAPI, PromptGenerateAPI from chat.api.chat_authentication_api import ChatOpenAPI @@ -23,9 +24,17 @@ from chat.serializers.chat import OpenChatSerializers, ChatSerializers, DebugCha from common.auth import TokenAuth from common.auth.authentication import has_permissions from common.constants.permission_constants import PermissionConstants, RoleConstants, ViewPermission, CompareConstants +from common.log.log import log from common.result import result from common.utils.common import query_params_to_single_dict +def get_application_operation_object(application_id): + application_model = QuerySet(model=Application).filter(id=application_id).first() + if application_model is not None: + return { + 'name': application_model.name + } + return {} class ApplicationChat(APIView): authentication_classes = [TokenAuth] @@ -146,6 +155,7 @@ class ChatView(APIView): return DebugChatSerializers(data={'chat_id': chat_id}).chat(request.data) class PromptGenerateView(APIView): + authentication_classes = [TokenAuth] @extend_schema( methods=['POST'], @@ -157,5 +167,13 @@ class PromptGenerateView(APIView): responses=None, tags=[_('Application')] # type: ignore ) - def post(self, request: Request, workspace_id: str, model_id:str): - return PromptGenerateSerializer(data={'workspace_id': workspace_id, 'model_id': model_id}).generate_prompt(instance=request.data) \ No newline at end of file + @has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(), + PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(), + ViewPermission([RoleConstants.USER.get_workspace_role()], + [PermissionConstants.APPLICATION.get_workspace_application_permission()], + CompareConstants.AND), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role()) + @log(menu='Application', operate='Generate prompt', + get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id'))) + def post(self, request: Request, workspace_id: str, model_id:str, application_id: str): + return PromptGenerateSerializer(data={'workspace_id': workspace_id, 'model_id': model_id, 'application_id': application_id}).generate_prompt(instance=request.data) \ No newline at end of file diff --git a/apps/chat/api/chat_api.py b/apps/chat/api/chat_api.py index b8b1f0e9f..c3cb5c028 100644 --- a/apps/chat/api/chat_api.py +++ b/apps/chat/api/chat_api.py @@ -31,7 +31,15 @@ class PromptGenerateAPI(APIMixin): description="模型id", type=OpenApiTypes.STR, location='path', - required=True,) + required=True, + ), + OpenApiParameter( + name="application_id", + description="应用id", + type=OpenApiTypes.STR, + location='path', + required=True, + ), ] @staticmethod diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py index 47c658ee9..e61b15e88 100644 --- a/apps/chat/serializers/chat.py +++ b/apps/chat/serializers/chat.py @@ -144,6 +144,7 @@ class DebugChatSerializers(serializers.Serializer): class PromptGenerateSerializer(serializers.Serializer): workspace_id = serializers.CharField(required=False, label=_('Workspace ID')) model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model")) + application_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Application")) def generate_prompt(self, instance: dict, with_valid=True): if with_valid: