diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 314cfc787..88957d040 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -53,11 +53,12 @@ class Flow: class WorkflowManage: - def __init__(self, flow: Flow, params): + def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler): self.params = params self.flow = flow self.context = {} self.node_context = [] + self.work_flow_post_handler = work_flow_post_handler self.current_node = None self.current_result = None @@ -74,12 +75,8 @@ class WorkflowManage: else: r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'], self.current_node, self, - WorkFlowPostHandler(client_id=self.params['client_id'], - chat_info=None, - client_type='APPLICATION_ACCESS_TOKEN')) - for row in r: - print(row) - print(self) + self.work_flow_post_handler) + return r def has_next_node(self, node_result: NodeResult | None): """ diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 073e980c7..dd8967718 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -18,6 +18,12 @@ from setting.models.model_management import Model from users.models import User +class ApplicationTypeChoices(models.TextChoices): + """订单类型""" + SIMPLE = 'SIMPLE', '简易' + WORK_FLOW = 'WORK_FLOW', '工作流' + + def get_dataset_setting_dict(): return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding', 'no_references_setting': { @@ -42,6 +48,9 @@ class Application(AppModelMixin): model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict) problem_optimization = models.BooleanField(verbose_name="问题优化", default=False) icon = models.CharField(max_length=256, verbose_name="应用icon", default="/ui/favicon.ico") + work_flow = models.JSONField(verbose_name="工作流数据", default={}) + type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices, + default=ApplicationTypeChoices.SIMPLE) @staticmethod def get_default_model_prompt(): diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index cc658f71d..eb3f0bac4 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -7,6 +7,7 @@ @desc: """ import json +import uuid from typing import List from uuid import UUID @@ -22,7 +23,9 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera BaseGenerateHumanMessageStep from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep -from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping +from application.flow.i_step_node import WorkFlowPostHandler +from application.flow.workflow_manage import WorkflowManage, Flow +from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping, ApplicationTypeChoices from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken from common.constants.authentication_type import AuthenticationType from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed @@ -146,8 +149,17 @@ class ChatMessageSerializer(serializers.Serializer): client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) - def is_valid(self, *, raise_exception=False): - super().is_valid(raise_exception=True) + def is_valid_application_workflow(self, *, raise_exception=False): + self.is_valid_intraday_access_num() + chat_id = self.data.get('chat_id') + chat_info: ChatInfo = chat_cache.get(chat_id) + if chat_info is None: + chat_info = self.re_open_chat(chat_id) + chat_cache.set(chat_id, + chat_info, timeout=60 * 30) + return chat_info + + def is_valid_intraday_access_num(self): if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first() if access_client is None: @@ -161,6 +173,9 @@ class ChatMessageSerializer(serializers.Serializer): application_id=self.data.get('application_id')).first() if application_access_token.access_num <= access_client.intraday_access_num: raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量") + + def is_valid_application_simple(self, *, raise_exception=False): + self.is_valid_intraday_access_num() chat_id = self.data.get('chat_id') chat_info: ChatInfo = chat_cache.get(chat_id) if chat_info is None: @@ -179,8 +194,7 @@ class ChatMessageSerializer(serializers.Serializer): raise AppApiException(500, "模型正在下载中,请稍后再发起对话") return chat_info - def chat(self): - chat_info = self.is_valid(raise_exception=True) + def chat_simple(self, chat_info): message = self.data.get('message') re_chat = self.data.get('re_chat') stream = self.data.get('stream') @@ -211,6 +225,31 @@ class ChatMessageSerializer(serializers.Serializer): pipeline_message.run(params) return pipeline_message.context['chat_result'] + def chat_work_flow(self, chat_info: ChatInfo): + message = self.data.get('message') + re_chat = self.data.get('re_chat') + stream = self.data.get('stream') + client_id = self.data.get('client_id') + client_type = self.data.get('client_type') + work_flow_manage = WorkflowManage(Flow.new_instance(json.loads(chat_info.application.work_flow)), + {'history_chat_record': chat_info.chat_record_list, 'question': message, + 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()), + 'stream': stream, + 're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type)) + r = work_flow_manage.run() + return r + + def chat(self): + super().is_valid(raise_exception=True) + application = QuerySet(Application).filter(self.data.get('application_id')) + if application.type == ApplicationTypeChoices.SIMPLE: + chat_info = self.is_valid_application_simple(raise_exception=True) + return self.chat_simple(chat_info) + + else: + chat_info = self.is_valid_application_workflow(raise_exception=True) + return self.chat_work_flow(chat_info) + @staticmethod def re_open_chat(chat_id: str): chat = QuerySet(Chat).filter(id=chat_id).first()