feat: 工作流

This commit is contained in:
shaohuzhang1 2024-06-14 19:18:07 +08:00
parent 9eaa8a3963
commit 9dbf5f1cd0
3 changed files with 57 additions and 12 deletions

View File

@ -53,11 +53,12 @@ class Flow:
class WorkflowManage:
def __init__(self, flow: Flow, params):
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler):
self.params = params
self.flow = flow
self.context = {}
self.node_context = []
self.work_flow_post_handler = work_flow_post_handler
self.current_node = None
self.current_result = None
@ -74,12 +75,8 @@ class WorkflowManage:
else:
r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
self.current_node, self,
WorkFlowPostHandler(client_id=self.params['client_id'],
chat_info=None,
client_type='APPLICATION_ACCESS_TOKEN'))
for row in r:
print(row)
print(self)
self.work_flow_post_handler)
return r
def has_next_node(self, node_result: NodeResult | None):
"""

View File

@ -18,6 +18,12 @@ from setting.models.model_management import Model
from users.models import User
class ApplicationTypeChoices(models.TextChoices):
"""订单类型"""
SIMPLE = 'SIMPLE', '简易'
WORK_FLOW = 'WORK_FLOW', '工作流'
def get_dataset_setting_dict():
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding',
'no_references_setting': {
@ -42,6 +48,9 @@ class Application(AppModelMixin):
model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
icon = models.CharField(max_length=256, verbose_name="应用icon", default="/ui/favicon.ico")
work_flow = models.JSONField(verbose_name="工作流数据", default={})
type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices,
default=ApplicationTypeChoices.SIMPLE)
@staticmethod
def get_default_model_prompt():

View File

@ -7,6 +7,7 @@
@desc:
"""
import json
import uuid
from typing import List
from uuid import UUID
@ -22,7 +23,9 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera
BaseGenerateHumanMessageStep
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
from application.flow.i_step_node import WorkFlowPostHandler
from application.flow.workflow_manage import WorkflowManage, Flow
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping, ApplicationTypeChoices
from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
@ -146,8 +149,17 @@ class ChatMessageSerializer(serializers.Serializer):
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
def is_valid_application_workflow(self, *, raise_exception=False):
self.is_valid_intraday_access_num()
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
if chat_info is None:
chat_info = self.re_open_chat(chat_id)
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
return chat_info
def is_valid_intraday_access_num(self):
if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
if access_client is None:
@ -161,6 +173,9 @@ class ChatMessageSerializer(serializers.Serializer):
application_id=self.data.get('application_id')).first()
if application_access_token.access_num <= access_client.intraday_access_num:
raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")
def is_valid_application_simple(self, *, raise_exception=False):
self.is_valid_intraday_access_num()
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
if chat_info is None:
@ -179,8 +194,7 @@ class ChatMessageSerializer(serializers.Serializer):
raise AppApiException(500, "模型正在下载中,请稍后再发起对话")
return chat_info
def chat(self):
chat_info = self.is_valid(raise_exception=True)
def chat_simple(self, chat_info):
message = self.data.get('message')
re_chat = self.data.get('re_chat')
stream = self.data.get('stream')
@ -211,6 +225,31 @@ class ChatMessageSerializer(serializers.Serializer):
pipeline_message.run(params)
return pipeline_message.context['chat_result']
def chat_work_flow(self, chat_info: ChatInfo):
message = self.data.get('message')
re_chat = self.data.get('re_chat')
stream = self.data.get('stream')
client_id = self.data.get('client_id')
client_type = self.data.get('client_type')
work_flow_manage = WorkflowManage(Flow.new_instance(json.loads(chat_info.application.work_flow)),
{'history_chat_record': chat_info.chat_record_list, 'question': message,
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
'stream': stream,
're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type))
r = work_flow_manage.run()
return r
def chat(self):
super().is_valid(raise_exception=True)
application = QuerySet(Application).filter(self.data.get('application_id'))
if application.type == ApplicationTypeChoices.SIMPLE:
chat_info = self.is_valid_application_simple(raise_exception=True)
return self.chat_simple(chat_info)
else:
chat_info = self.is_valid_application_workflow(raise_exception=True)
return self.chat_work_flow(chat_info)
@staticmethod
def re_open_chat(chat_id: str):
chat = QuerySet(Chat).filter(id=chat_id).first()