mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: Generate Prompt
This commit is contained in:
parent
ec5c076557
commit
ff6714e505
|
|
@ -34,6 +34,7 @@ urlpatterns = [
|
|||
path('workspace/<str:workspace_id>/application/<str:application_id>/speech_to_text', views.SpeechToText.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/play_demo_text', views.PlayDemoText.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/mcp_tools', views.McpServers.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/model/<str:model_id>/prompt_generate', views.PromptGenerateView.as_view()),
|
||||
path('chat_message/<str:chat_id>', views.ChatView.as_view()),
|
||||
|
||||
]
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ from application.api.application_chat import ApplicationChatQueryAPI, Applicatio
|
|||
ApplicationChatExportAPI
|
||||
from application.models import ChatUserType
|
||||
from application.serializers.application_chat import ApplicationChatQuerySerializers
|
||||
from chat.api.chat_api import ChatAPI
|
||||
from chat.api.chat_api import ChatAPI, PromptGenerateAPI
|
||||
from chat.api.chat_authentication_api import ChatOpenAPI
|
||||
from chat.serializers.chat import OpenChatSerializers, ChatSerializers, DebugChatSerializers
|
||||
from chat.serializers.chat import OpenChatSerializers, ChatSerializers, DebugChatSerializers, PromptGenerateSerializer
|
||||
from common.auth import TokenAuth
|
||||
from common.auth.authentication import has_permissions
|
||||
from common.constants.permission_constants import PermissionConstants, RoleConstants, ViewPermission, CompareConstants
|
||||
|
|
@ -144,3 +144,18 @@ class ChatView(APIView):
|
|||
)
|
||||
def post(self, request: Request, chat_id: str):
|
||||
return DebugChatSerializers(data={'chat_id': chat_id}).chat(request.data)
|
||||
|
||||
class PromptGenerateView(APIView):
|
||||
|
||||
@extend_schema(
|
||||
methods=['POST'],
|
||||
description=_("generate prompt"),
|
||||
summary=_("generate prompt"),
|
||||
operation_id=_("generate prompt"), # type: ignore
|
||||
request=PromptGenerateAPI.get_request(),
|
||||
parameters=PromptGenerateAPI.get_parameters(),
|
||||
responses=None,
|
||||
tags=[_('Application')] # type: ignore
|
||||
)
|
||||
def post(self, request: Request, workspace_id: str, model_id:str):
|
||||
return PromptGenerateSerializer(data={'workspace_id': workspace_id, 'model_id': model_id}).generate_prompt(instance=request.data)
|
||||
|
|
@ -10,12 +10,35 @@ from drf_spectacular.types import OpenApiTypes
|
|||
from drf_spectacular.utils import OpenApiParameter
|
||||
|
||||
from application.serializers.application_chat_record import ChatRecordSerializerModel
|
||||
from chat.serializers.chat import ChatMessageSerializers
|
||||
from chat.serializers.chat import ChatMessageSerializers, GeneratePromptSerializers
|
||||
from chat.serializers.chat_record import HistoryChatModel, EditAbstractSerializer
|
||||
from common.mixins.api_mixin import APIMixin
|
||||
from common.result import ResultSerializer, ResultPageSerializer, DefaultResultSerializer
|
||||
|
||||
|
||||
class PromptGenerateAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="model_id",
|
||||
description="模型id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return GeneratePromptSerializers
|
||||
|
||||
|
||||
class ChatAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@
|
|||
@date:2025/6/9 11:23
|
||||
@desc:
|
||||
"""
|
||||
|
||||
import json
|
||||
from gettext import gettext
|
||||
from typing import List, Dict
|
||||
|
||||
import uuid_utils.compat as uuid
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
|
|
@ -24,6 +25,7 @@ from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_s
|
|||
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
|
||||
from application.flow.common import Answer, Workflow
|
||||
from application.flow.i_step_node import WorkFlowPostHandler
|
||||
from application.flow.tools import to_stream_response_simple
|
||||
from application.flow.workflow_manage import WorkflowManage
|
||||
from application.models import Application, ApplicationTypeChoices, ApplicationKnowledgeMapping, \
|
||||
ChatUserType, ApplicationChatUserStats, ApplicationAccessToken, ChatRecord, Chat, ApplicationVersion
|
||||
|
|
@ -37,8 +39,34 @@ from common.handle.impl.response.system_to_response import SystemToResponse
|
|||
from common.utils.common import flat_map
|
||||
from knowledge.models import Document, Paragraph
|
||||
from models_provider.models import Model, Status
|
||||
from models_provider.tools import get_model_instance_by_model_workspace_id
|
||||
|
||||
|
||||
class ChatMessagesSerializers(serializers.Serializer):
|
||||
role = serializers.CharField(required=True, label=_("Role"))
|
||||
content = serializers.CharField(required=True, label=_("Content"))
|
||||
|
||||
|
||||
class GeneratePromptSerializers(serializers.Serializer):
|
||||
prompt = serializers.CharField(required=True, label=_("Prompt template"))
|
||||
messages = serializers.ListSerializer(child=ChatMessagesSerializers(), required=True, label=_("Chat context"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
messages = self.data.get("messages")
|
||||
|
||||
if len(messages) > 30:
|
||||
raise AppApiException(400, _("Too many messages"))
|
||||
|
||||
for index in range(len(messages)):
|
||||
role = messages[index].get('role')
|
||||
if role == 'ai' and index % 2 != 1:
|
||||
raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct."))
|
||||
if role == 'user' and index % 2 != 0:
|
||||
raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct."))
|
||||
if role not in ['user', 'ai']:
|
||||
raise AppApiException(400, _("Authentication failed. Please verify that the parameters are correct."))
|
||||
|
||||
class ChatMessageSerializers(serializers.Serializer):
|
||||
message = serializers.CharField(required=True, label=_("User Questions"))
|
||||
stream = serializers.BooleanField(required=True,
|
||||
|
|
@ -113,6 +141,37 @@ class DebugChatSerializers(serializers.Serializer):
|
|||
}).chat(instance, base_to_response)
|
||||
|
||||
|
||||
class PromptGenerateSerializer(serializers.Serializer):
|
||||
workspace_id = serializers.CharField(required=False, label=_('Workspace ID'))
|
||||
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model"))
|
||||
|
||||
def generate_prompt(self, instance: dict, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
GeneratePromptSerializers(data=instance).is_valid(raise_exception=True)
|
||||
workspace_id = self.data.get('workspace_id')
|
||||
model_id = self.data.get('model_id')
|
||||
prompt = instance.get('prompt')
|
||||
messages = instance.get('messages')
|
||||
|
||||
message = messages[-1]['content']
|
||||
q = prompt.replace("{userInput}", message)
|
||||
messages[-1]['content'] = q
|
||||
|
||||
model_exist = QuerySet(Model).filter(workspace_id=workspace_id, id=model_id).exists()
|
||||
if not model_exist:
|
||||
raise Exception(_("model does not exists"))
|
||||
|
||||
def process():
|
||||
model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id)
|
||||
|
||||
for r in model.stream([HumanMessage(content=m.get('content')) if m.get('role') == 'user' else AIMessage(
|
||||
content=m.get('content')) for m in messages]):
|
||||
yield 'data: ' + json.dumps({'content': r.content}) + '\n\n'
|
||||
|
||||
return to_stream_response_simple(process())
|
||||
|
||||
|
||||
class OpenAIMessage(serializers.Serializer):
|
||||
content = serializers.CharField(required=True, label=_('content'))
|
||||
role = serializers.CharField(required=True, label=_('Role'))
|
||||
|
|
|
|||
|
|
@ -8672,4 +8672,16 @@ msgid "System resources authorization"
|
|||
msgstr ""
|
||||
|
||||
msgid "This folder contains resources that you dont have permission"
|
||||
msgstr ""
|
||||
|
||||
msgid "Authentication failed. Please verify that the parameters are correct"
|
||||
msgstr ""
|
||||
|
||||
msgid "Chat context"
|
||||
msgstr ""
|
||||
|
||||
msgid "Prompt template"
|
||||
msgstr ""
|
||||
|
||||
msgid "generate prompt"
|
||||
msgstr ""
|
||||
|
|
@ -8799,3 +8799,15 @@ msgstr "系统资源授权"
|
|||
|
||||
msgid "This folder contains resources that you dont have permission"
|
||||
msgstr "此文件夹包含您没有权限的资源"
|
||||
|
||||
msgid "Authentication failed. Please verify that the parameters are correct"
|
||||
msgstr "认证失败,请检查参数是否正确"
|
||||
|
||||
msgid "Chat context"
|
||||
msgstr "聊天上下文"
|
||||
|
||||
msgid "Prompt template"
|
||||
msgstr "提示词模板"
|
||||
|
||||
msgid "generate prompt"
|
||||
msgstr "生成提示词"
|
||||
|
|
@ -8798,4 +8798,16 @@ msgid "System resources authorization"
|
|||
msgstr "系統資源授權"
|
||||
|
||||
msgid "This folder contains resources that you dont have permission"
|
||||
msgstr "此資料夾包含您沒有許可權的資源"
|
||||
msgstr "此資料夾包含您沒有許可權的資源"
|
||||
|
||||
msgid "Authentication failed. Please verify that the parameters are correct"
|
||||
msgstr "認證失敗,請檢查參數是否正確"
|
||||
|
||||
msgid "Chat context"
|
||||
msgstr "聊天上下文"
|
||||
|
||||
msgid "Prompt template"
|
||||
msgstr "提示詞範本"
|
||||
|
||||
msgid "generate prompt"
|
||||
msgstr "生成提示詞"
|
||||
Loading…
Reference in New Issue