diff --git a/apps/application/urls.py b/apps/application/urls.py index 6b9f58b37..b49e9dca4 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -34,6 +34,7 @@ urlpatterns = [ path('workspace//application//speech_to_text', views.SpeechToText.as_view()), path('workspace//application//play_demo_text', views.PlayDemoText.as_view()), path('workspace//application//mcp_tools', views.McpServers.as_view()), + path('workspace//application/model//prompt_generate', views.PromptGenerateView.as_view()), path('chat_message/', views.ChatView.as_view()), ] diff --git a/apps/application/views/application_chat.py b/apps/application/views/application_chat.py index 652f97f6e..c05bb65f3 100644 --- a/apps/application/views/application_chat.py +++ b/apps/application/views/application_chat.py @@ -17,9 +17,9 @@ from application.api.application_chat import ApplicationChatQueryAPI, Applicatio ApplicationChatExportAPI from application.models import ChatUserType from application.serializers.application_chat import ApplicationChatQuerySerializers -from chat.api.chat_api import ChatAPI +from chat.api.chat_api import ChatAPI, PromptGenerateAPI from chat.api.chat_authentication_api import ChatOpenAPI -from chat.serializers.chat import OpenChatSerializers, ChatSerializers, DebugChatSerializers +from chat.serializers.chat import OpenChatSerializers, ChatSerializers, DebugChatSerializers, PromptGenerateSerializer from common.auth import TokenAuth from common.auth.authentication import has_permissions from common.constants.permission_constants import PermissionConstants, RoleConstants, ViewPermission, CompareConstants @@ -144,3 +144,18 @@ class ChatView(APIView): ) def post(self, request: Request, chat_id: str): return DebugChatSerializers(data={'chat_id': chat_id}).chat(request.data) + +class PromptGenerateView(APIView): + + @extend_schema( + methods=['POST'], + description=_("generate prompt"), + summary=_("generate prompt"), + operation_id=_("generate prompt"), # type: ignore + request=PromptGenerateAPI.get_request(), + parameters=PromptGenerateAPI.get_parameters(), + responses=None, + tags=[_('Application')] # type: ignore + ) + def post(self, request: Request, workspace_id: str, model_id:str): + return PromptGenerateSerializer(data={'workspace_id': workspace_id, 'model_id': model_id}).generate_prompt(instance=request.data) \ No newline at end of file diff --git a/apps/chat/api/chat_api.py b/apps/chat/api/chat_api.py index 974846045..b8b1f0e9f 100644 --- a/apps/chat/api/chat_api.py +++ b/apps/chat/api/chat_api.py @@ -10,12 +10,35 @@ from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter from application.serializers.application_chat_record import ChatRecordSerializerModel -from chat.serializers.chat import ChatMessageSerializers +from chat.serializers.chat import ChatMessageSerializers, GeneratePromptSerializers from chat.serializers.chat_record import HistoryChatModel, EditAbstractSerializer from common.mixins.api_mixin import APIMixin from common.result import ResultSerializer, ResultPageSerializer, DefaultResultSerializer +class PromptGenerateAPI(APIMixin): + @staticmethod + def get_parameters(): + return [OpenApiParameter( + name="workspace_id", + description="工作空间id", + type=OpenApiTypes.STR, + location='path', + required=True, + ), + OpenApiParameter( + name="model_id", + description="模型id", + type=OpenApiTypes.STR, + location='path', + required=True,) + ] + + @staticmethod + def get_request(): + return GeneratePromptSerializers + + class ChatAPI(APIMixin): @staticmethod def get_parameters(): diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py index abbe6e527..1a703a41c 100644 --- a/apps/chat/serializers/chat.py +++ b/apps/chat/serializers/chat.py @@ -6,13 +6,14 @@ @date:2025/6/9 11:23 @desc: """ - +import json from gettext import gettext from typing import List, Dict import uuid_utils.compat as uuid from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ +from langchain_core.messages import HumanMessage, AIMessage from rest_framework import serializers from application.chat_pipeline.pipeline_manage import PipelineManage @@ -24,6 +25,7 @@ from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_s from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep from application.flow.common import Answer, Workflow from application.flow.i_step_node import WorkFlowPostHandler +from application.flow.tools import to_stream_response_simple from application.flow.workflow_manage import WorkflowManage from application.models import Application, ApplicationTypeChoices, ApplicationKnowledgeMapping, \ ChatUserType, ApplicationChatUserStats, ApplicationAccessToken, ChatRecord, Chat, ApplicationVersion @@ -37,8 +39,34 @@ from common.handle.impl.response.system_to_response import SystemToResponse from common.utils.common import flat_map from knowledge.models import Document, Paragraph from models_provider.models import Model, Status +from models_provider.tools import get_model_instance_by_model_workspace_id +class ChatMessagesSerializers(serializers.Serializer): + role = serializers.CharField(required=True, label=_("Role")) + content = serializers.CharField(required=True, label=_("Content")) + + +class GeneratePromptSerializers(serializers.Serializer): + prompt = serializers.CharField(required=True, label=_("Prompt template")) + messages = serializers.ListSerializer(child=ChatMessagesSerializers(), required=True, label=_("Chat context")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + messages = self.data.get("messages") + + if len(messages) > 30: + raise AppApiException(400, _("Too many messages")) + + for index in range(len(messages)): + role = messages[index].get('role') + if role == 'ai' and index % 2 != 1: + raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct.")) + if role == 'user' and index % 2 != 0: + raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct.")) + if role not in ['user', 'ai']: + raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct.")) + class ChatMessageSerializers(serializers.Serializer): message = serializers.CharField(required=True, label=_("User Questions")) stream = serializers.BooleanField(required=True, @@ -113,6 +141,37 @@ class DebugChatSerializers(serializers.Serializer): }).chat(instance, base_to_response) +class PromptGenerateSerializer(serializers.Serializer): + workspace_id = serializers.CharField(required=False, label=_('Workspace ID')) + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model")) + + def generate_prompt(self, instance: dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + GeneratePromptSerializers(data=instance).is_valid(raise_exception=True) + workspace_id = self.data.get('workspace_id') + model_id = self.data.get('model_id') + prompt = instance.get('prompt') + messages = instance.get('messages') + + message = messages[-1]['content'] + q = prompt.replace("{userInput}", message) + messages[-1]['content'] = q + + model_exist = QuerySet(Model).filter(workspace_id=workspace_id, id=model_id).exists() + if not model_exist: + raise Exception(_("model does not exists")) + + def process(): + model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id) + + for r in model.stream([HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage( + content=m.get('content')) for m in messages]): + yield 'data: ' + json.dumps({'content': r.content}) + '\n\n' + + return to_stream_response_simple(process()) + + class OpenAIMessage(serializers.Serializer): content = serializers.CharField(required=True, label=_('content')) role = serializers.CharField(required=True, label=_('Role')) diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index e642aa167..2d8c7ee74 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -8672,4 +8672,16 @@ msgid "System resources authorization" msgstr "" msgid "This folder contains resources that you dont have permission" +msgstr "" + +msgid "Authentication failed. Please verify that the parameters are correct" +msgstr "" + +msgid "Chat context" +msgstr "" + +msgid "Prompt template" +msgstr "" + +msgid "generate prompt" msgstr "" \ No newline at end of file diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index 752687ea3..a01eb7ca6 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -8799,3 +8799,15 @@ msgstr "系统资源授权" msgid "This folder contains resources that you dont have permission" msgstr "此文件夹包含您没有权限的资源" + +msgid "Authentication failed. Please verify that the parameters are correct" +msgstr "认证失败,请检查参数是否正确" + +msgid "Chat context" +msgstr "聊天上下文" + +msgid "Prompt template" +msgstr "提示词模板" + +msgid "generate prompt" +msgstr "生成提示词" \ No newline at end of file diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index acedc0a24..893d7e4ad 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -8798,4 +8798,16 @@ msgid "System resources authorization" msgstr "系統資源授權" msgid "This folder contains resources that you dont have permission" -msgstr "此資料夾包含您沒有許可權的資源" \ No newline at end of file +msgstr "此資料夾包含您沒有許可權的資源" + +msgid "Authentication failed. Please verify that the parameters are correct" +msgstr "認證失敗,請檢查參數是否正確" + +msgid "Chat context" +msgstr "聊天上下文" + +msgid "Prompt template" +msgstr "提示詞範本" + +msgid "generate prompt" +msgstr "生成提示詞" \ No newline at end of file