mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-28 05:42:51 +00:00
feat: 工作流
This commit is contained in:
parent
9eaa8a3963
commit
9dbf5f1cd0
|
|
@ -53,11 +53,12 @@ class Flow:
|
|||
|
||||
|
||||
class WorkflowManage:
|
||||
def __init__(self, flow: Flow, params):
|
||||
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler):
|
||||
self.params = params
|
||||
self.flow = flow
|
||||
self.context = {}
|
||||
self.node_context = []
|
||||
self.work_flow_post_handler = work_flow_post_handler
|
||||
self.current_node = None
|
||||
self.current_result = None
|
||||
|
||||
|
|
@ -74,12 +75,8 @@ class WorkflowManage:
|
|||
else:
|
||||
r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
|
||||
self.current_node, self,
|
||||
WorkFlowPostHandler(client_id=self.params['client_id'],
|
||||
chat_info=None,
|
||||
client_type='APPLICATION_ACCESS_TOKEN'))
|
||||
for row in r:
|
||||
print(row)
|
||||
print(self)
|
||||
self.work_flow_post_handler)
|
||||
return r
|
||||
|
||||
def has_next_node(self, node_result: NodeResult | None):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -18,6 +18,12 @@ from setting.models.model_management import Model
|
|||
from users.models import User
|
||||
|
||||
|
||||
class ApplicationTypeChoices(models.TextChoices):
|
||||
"""订单类型"""
|
||||
SIMPLE = 'SIMPLE', '简易'
|
||||
WORK_FLOW = 'WORK_FLOW', '工作流'
|
||||
|
||||
|
||||
def get_dataset_setting_dict():
|
||||
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding',
|
||||
'no_references_setting': {
|
||||
|
|
@ -42,6 +48,9 @@ class Application(AppModelMixin):
|
|||
model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
|
||||
problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
|
||||
icon = models.CharField(max_length=256, verbose_name="应用icon", default="/ui/favicon.ico")
|
||||
work_flow = models.JSONField(verbose_name="工作流数据", default={})
|
||||
type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices,
|
||||
default=ApplicationTypeChoices.SIMPLE)
|
||||
|
||||
@staticmethod
|
||||
def get_default_model_prompt():
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
@desc:
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
|
|
@ -22,7 +23,9 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera
|
|||
BaseGenerateHumanMessageStep
|
||||
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
|
||||
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
|
||||
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
|
||||
from application.flow.i_step_node import WorkFlowPostHandler
|
||||
from application.flow.workflow_manage import WorkflowManage, Flow
|
||||
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping, ApplicationTypeChoices
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
|
||||
|
|
@ -146,8 +149,17 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
|
||||
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
def is_valid_application_workflow(self, *, raise_exception=False):
|
||||
self.is_valid_intraday_access_num()
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_info: ChatInfo = chat_cache.get(chat_id)
|
||||
if chat_info is None:
|
||||
chat_info = self.re_open_chat(chat_id)
|
||||
chat_cache.set(chat_id,
|
||||
chat_info, timeout=60 * 30)
|
||||
return chat_info
|
||||
|
||||
def is_valid_intraday_access_num(self):
|
||||
if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
||||
access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
|
||||
if access_client is None:
|
||||
|
|
@ -161,6 +173,9 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
application_id=self.data.get('application_id')).first()
|
||||
if application_access_token.access_num <= access_client.intraday_access_num:
|
||||
raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")
|
||||
|
||||
def is_valid_application_simple(self, *, raise_exception=False):
|
||||
self.is_valid_intraday_access_num()
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_info: ChatInfo = chat_cache.get(chat_id)
|
||||
if chat_info is None:
|
||||
|
|
@ -179,8 +194,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
raise AppApiException(500, "模型正在下载中,请稍后再发起对话")
|
||||
return chat_info
|
||||
|
||||
def chat(self):
|
||||
chat_info = self.is_valid(raise_exception=True)
|
||||
def chat_simple(self, chat_info):
|
||||
message = self.data.get('message')
|
||||
re_chat = self.data.get('re_chat')
|
||||
stream = self.data.get('stream')
|
||||
|
|
@ -211,6 +225,31 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
pipeline_message.run(params)
|
||||
return pipeline_message.context['chat_result']
|
||||
|
||||
def chat_work_flow(self, chat_info: ChatInfo):
|
||||
message = self.data.get('message')
|
||||
re_chat = self.data.get('re_chat')
|
||||
stream = self.data.get('stream')
|
||||
client_id = self.data.get('client_id')
|
||||
client_type = self.data.get('client_type')
|
||||
work_flow_manage = WorkflowManage(Flow.new_instance(json.loads(chat_info.application.work_flow)),
|
||||
{'history_chat_record': chat_info.chat_record_list, 'question': message,
|
||||
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
|
||||
'stream': stream,
|
||||
're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type))
|
||||
r = work_flow_manage.run()
|
||||
return r
|
||||
|
||||
def chat(self):
|
||||
super().is_valid(raise_exception=True)
|
||||
application = QuerySet(Application).filter(self.data.get('application_id'))
|
||||
if application.type == ApplicationTypeChoices.SIMPLE:
|
||||
chat_info = self.is_valid_application_simple(raise_exception=True)
|
||||
return self.chat_simple(chat_info)
|
||||
|
||||
else:
|
||||
chat_info = self.is_valid_application_workflow(raise_exception=True)
|
||||
return self.chat_work_flow(chat_info)
|
||||
|
||||
@staticmethod
|
||||
def re_open_chat(chat_id: str):
|
||||
chat = QuerySet(Chat).filter(id=chat_id).first()
|
||||
|
|
|
|||
Loading…
Reference in New Issue