mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
254 lines
15 KiB
Python
254 lines
15 KiB
Python
# coding=utf-8
|
||
"""
|
||
@project: MaxKB
|
||
@Author:虎虎
|
||
@file: application.py
|
||
@date:2025/5/26 17:03
|
||
@desc:
|
||
"""
|
||
import hashlib
|
||
import re
|
||
from typing import Dict
|
||
|
||
import uuid_utils.compat as uuid
|
||
from django.core import validators
|
||
from django.db import models
|
||
from django.db.models import QuerySet
|
||
from django.utils.translation import gettext_lazy as _
|
||
from rest_framework import serializers
|
||
|
||
from application.models.application import Application, ApplicationTypeChoices, ApplicationKnowledgeMapping
|
||
from application.models.application_access_token import ApplicationAccessToken
|
||
from common.exception.app_exception import AppApiException
|
||
from knowledge.models import Knowledge
|
||
from models_provider.models import Model
|
||
|
||
|
||
class NoReferencesChoices(models.TextChoices):
|
||
"""订单类型"""
|
||
ai_questioning = 'ai_questioning', 'ai回答'
|
||
designated_answer = 'designated_answer', '指定回答'
|
||
|
||
|
||
class NoReferencesSetting(serializers.Serializer):
|
||
status = serializers.ChoiceField(required=True, choices=NoReferencesChoices.choices,
|
||
label=_("No reference status"))
|
||
value = serializers.CharField(required=True, label=_("Prompt word"))
|
||
|
||
|
||
class KnowledgeSettingSerializer(serializers.Serializer):
|
||
top_n = serializers.FloatField(required=True, max_value=10000, min_value=1,
|
||
label=_("Reference segment number"))
|
||
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
||
label=_("Acquaintance"))
|
||
max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=100000,
|
||
label=_("Maximum number of quoted characters"))
|
||
search_mode = serializers.CharField(required=True, validators=[
|
||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||
message=_("The type only supports embedding|keywords|blend"), code=500)
|
||
], label=_("Retrieval Mode"))
|
||
|
||
no_references_setting = NoReferencesSetting(required=True,
|
||
label=_("Segment settings not referenced"))
|
||
|
||
|
||
class ModelKnowledgeAssociation(serializers.Serializer):
|
||
user_id = serializers.UUIDField(required=True, label=_("User ID"))
|
||
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||
label=_("Model id"))
|
||
Knowledge_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
|
||
label=_(
|
||
"Knowledge base id")),
|
||
label=_("Knowledge Base List"))
|
||
|
||
def is_valid(self, *, raise_exception=True):
|
||
super().is_valid(raise_exception=True)
|
||
model_id = self.data.get('model_id')
|
||
user_id = self.data.get('user_id')
|
||
if model_id is not None and len(model_id) > 0:
|
||
if not QuerySet(Model).filter(id=model_id).exists():
|
||
raise AppApiException(500, f'{_("Model does not exist")}【{model_id}】')
|
||
knowledge_id_list = list(set(self.data.get('knowledge_id_list')))
|
||
exist_knowledge_id_list = [str(knowledge.id) for knowledge in
|
||
QuerySet(Knowledge).filter(id__in=knowledge_id_list, user_id=user_id)]
|
||
for knowledge_id in knowledge_id_list:
|
||
if not exist_knowledge_id_list.__contains__(knowledge_id):
|
||
raise AppApiException(500, f'{_("The knowledge base id does not exist")}【{knowledge_id}】')
|
||
|
||
|
||
class ModelSettingSerializer(serializers.Serializer):
|
||
prompt = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
|
||
label=_("Prompt word"))
|
||
system = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
|
||
label=_("Role prompts"))
|
||
no_references_prompt = serializers.CharField(required=True, max_length=102400, allow_null=True, allow_blank=True,
|
||
label=_("No citation segmentation prompt"))
|
||
reasoning_content_enable = serializers.BooleanField(required=False,
|
||
label=_("Thinking process switch"))
|
||
reasoning_content_start = serializers.CharField(required=False, allow_null=True, default="<think>",
|
||
allow_blank=True, max_length=256,
|
||
trim_whitespace=False,
|
||
label=_("The thinking process begins to mark"))
|
||
reasoning_content_end = serializers.CharField(required=False, allow_null=True, allow_blank=True, default="</think>",
|
||
max_length=256,
|
||
trim_whitespace=False,
|
||
label=_("End of thinking process marker"))
|
||
|
||
|
||
class ApplicationCreateSerializer(serializers.Serializer):
|
||
class ApplicationResponse(serializers.ModelSerializer):
|
||
class Meta:
|
||
model = Application
|
||
fields = "__all__"
|
||
|
||
class WorkflowRequest(serializers.Serializer):
|
||
name = serializers.CharField(required=True, max_length=64, min_length=1,
|
||
label=_("Application Name"))
|
||
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||
max_length=256, min_length=1,
|
||
label=_("Application Description"))
|
||
work_flow = serializers.DictField(required=True, label=_("Workflow Objects"))
|
||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
|
||
label=_("Opening remarks"))
|
||
|
||
@staticmethod
|
||
def to_application_model(user_id: str, workspace_id: str, application: Dict):
|
||
default_workflow = application.get('work_flow')
|
||
for node in default_workflow.get('nodes'):
|
||
if node.get('id') == 'base-node':
|
||
node.get('properties')['node_data']['desc'] = application.get('desc')
|
||
node.get('properties')['node_data']['name'] = application.get('name')
|
||
node.get('properties')['node_data']['prologue'] = application.get('prologue')
|
||
return Application(id=uuid.uuid7(),
|
||
name=application.get('name'),
|
||
desc=application.get('desc'),
|
||
workspace_id=workspace_id,
|
||
prologue="",
|
||
dialogue_number=0,
|
||
user_id=user_id, model_id=None,
|
||
knowledge_setting={},
|
||
model_setting={},
|
||
problem_optimization=False,
|
||
type=ApplicationTypeChoices.WORK_FLOW,
|
||
stt_model_enable=application.get('stt_model_enable', False),
|
||
stt_model_id=application.get('stt_model', None),
|
||
tts_model_id=application.get('tts_model', None),
|
||
tts_model_enable=application.get('tts_model_enable', False),
|
||
tts_model_params_setting=application.get('tts_model_params_setting', {}),
|
||
tts_type=application.get('tts_type', None),
|
||
file_upload_enable=application.get('file_upload_enable', False),
|
||
file_upload_setting=application.get('file_upload_setting', {}),
|
||
work_flow=default_workflow
|
||
)
|
||
|
||
class SimplateRequest(serializers.Serializer):
|
||
name = serializers.CharField(required=True, max_length=64, min_length=1,
|
||
label=_("application name"))
|
||
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||
max_length=256, min_length=1,
|
||
label=_("application describe"))
|
||
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||
label=_("Model"))
|
||
dialogue_number = serializers.IntegerField(required=True,
|
||
min_value=0,
|
||
max_value=1024,
|
||
label=_("Historical chat records"))
|
||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
|
||
label=_("Opening remarks"))
|
||
knowledge_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
|
||
allow_null=True,
|
||
label=_("Related Knowledge Base"))
|
||
# 数据集相关设置
|
||
knowledge_setting = KnowledgeSettingSerializer(required=True)
|
||
# 模型相关设置
|
||
model_setting = ModelSettingSerializer(required=True)
|
||
# 问题补全
|
||
problem_optimization = serializers.BooleanField(required=True,
|
||
label=_("Question completion"))
|
||
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
|
||
label=_("Question completion prompt"))
|
||
# 应用类型
|
||
type = serializers.CharField(required=True, label=_("Application Type"),
|
||
validators=[
|
||
validators.RegexValidator(regex=re.compile("^SIMPLE|WORK_FLOW$"),
|
||
message=_(
|
||
"Application type only supports SIMPLE|WORK_FLOW"),
|
||
code=500)
|
||
]
|
||
)
|
||
model_params_setting = serializers.DictField(required=False,
|
||
label=_('Model parameters'))
|
||
|
||
def is_valid(self, *, user_id=None, raise_exception=False):
|
||
super().is_valid(raise_exception=True)
|
||
ModelKnowledgeAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
|
||
'knowledge_id_list': self.data.get('knowledge_id_list')}).is_valid()
|
||
|
||
@staticmethod
|
||
def to_application_model(user_id: str, application: Dict):
|
||
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
|
||
prologue=application.get('prologue'),
|
||
dialogue_number=application.get('dialogue_number', 0),
|
||
user_id=user_id, model_id=application.get('model_id'),
|
||
dataset_setting=application.get('dataset_setting'),
|
||
model_setting=application.get('model_setting'),
|
||
problem_optimization=application.get('problem_optimization'),
|
||
type=ApplicationTypeChoices.SIMPLE,
|
||
model_params_setting=application.get('model_params_setting', {}),
|
||
problem_optimization_prompt=application.get('problem_optimization_prompt', None),
|
||
stt_model_enable=application.get('stt_model_enable', False),
|
||
stt_model_id=application.get('stt_model', None),
|
||
tts_model_id=application.get('tts_model', None),
|
||
tts_model_enable=application.get('tts_model_enable', False),
|
||
tts_model_params_setting=application.get('tts_model_params_setting', {}),
|
||
tts_type=application.get('tts_type', None),
|
||
file_upload_enable=application.get('file_upload_enable', False),
|
||
file_upload_setting=application.get('file_upload_setting', {}),
|
||
work_flow={}
|
||
)
|
||
|
||
|
||
class ApplicationSerializer(serializers.Serializer):
|
||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||
user_id = serializers.UUIDField(required=True, label=_("User ID"))
|
||
|
||
def insert(self, instance: Dict, with_valid=True):
|
||
application_type = instance.get('type')
|
||
if 'WORK_FLOW' == application_type:
|
||
return self.insert_workflow(instance)
|
||
else:
|
||
return self.insert_simple(instance)
|
||
|
||
def insert_workflow(self, instance: Dict):
|
||
self.is_valid(raise_exception=True)
|
||
user_id = self.data.get('user_id')
|
||
ApplicationCreateSerializer.WorkflowRequest(data=instance).is_valid(raise_exception=True)
|
||
application_model = ApplicationCreateSerializer.WorkflowRequest.to_application_model(user_id, instance)
|
||
application_model.save()
|
||
# 插入认证信息
|
||
ApplicationAccessToken(application_id=application_model.id,
|
||
access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save()
|
||
return ApplicationCreateSerializer.ApplicationResponse(application_model).data
|
||
|
||
@staticmethod
|
||
def to_application_knowledge_mapping(application_id: str, dataset_id: str):
|
||
return ApplicationKnowledgeMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id)
|
||
|
||
def insert_simple(self, instance: Dict):
|
||
self.is_valid(raise_exception=True)
|
||
user_id = self.data.get('user_id')
|
||
ApplicationCreateSerializer.SimplateRequest(data=instance).is_valid(user_id=user_id, raise_exception=True)
|
||
application_model = ApplicationCreateSerializer.SimplateRequest.to_application_model(user_id, instance)
|
||
dataset_id_list = instance.get('knowledge_id_list', [])
|
||
application_knowledge_mapping_model_list = [
|
||
self.to_application_knowledge_mapping(application_model.id, dataset_id) for
|
||
dataset_id in dataset_id_list]
|
||
# 插入应用
|
||
application_model.save()
|
||
# 插入认证信息
|
||
ApplicationAccessToken(application_id=application_model.id,
|
||
access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save()
|
||
# 插入关联数据
|
||
QuerySet(ApplicationKnowledgeMapping).bulk_create(application_knowledge_mapping_model_list)
|
||
return ApplicationCreateSerializer.ApplicationResponse(application_model).data
|