diff --git a/apps/application/chat_pipeline/pipeline_manage.py b/apps/application/chat_pipeline/pipeline_manage.py index 2b94290d2..37d7736b5 100644 --- a/apps/application/chat_pipeline/pipeline_manage.py +++ b/apps/application/chat_pipeline/pipeline_manage.py @@ -13,7 +13,7 @@ from typing import List, Type, Dict from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep -class PiplineManage: +class PipelineManage: def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]): # 步骤执行器 self.step_list = [step() for step in step_list] @@ -42,4 +42,4 @@ class PiplineManage: return self def build(self): - return PiplineManage(step_list=self.step_list) + return PipelineManage(step_list=self.step_list) diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index fe4bc1dfd..b3437390b 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -14,7 +14,7 @@ from langchain.schema import BaseMessage from rest_framework import serializers from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel -from application.chat_pipeline.pipeline_manage import PiplineManage +from application.chat_pipeline.pipeline_manage import PipelineManage from common.field.common import InstanceField from common.util.field_message import ErrMessage from dataset.models import Paragraph @@ -78,10 +78,10 @@ class IChatStep(IBaseChatPipelineStep): if not isinstance(message, BaseMessage): raise Exception("message 类型错误") - def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]: + def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: return self.InstanceSerializer - def _run(self, manage: PiplineManage): + def _run(self, manage: PipelineManage): chat_result = self.execute(**self.context['step_args'], manage=manage) manage.context['chat_result'] = chat_result @@ -91,6 +91,6 @@ class IChatStep(IBaseChatPipelineStep): post_response_handler: PostResponseHandler, chat_model: BaseChatModel = None, paragraph_list=None, - manage: PiplineManage = None, + manage: PipelineManage = None, padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs): pass diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 808b43de3..4c8e27c91 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -20,7 +20,7 @@ from langchain.schema import BaseMessage from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel -from application.chat_pipeline.pipeline_manage import PiplineManage +from application.chat_pipeline.pipeline_manage import PipelineManage from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler from application.models.api_key_model import ApplicationPublicAccessClient from common.constants.authentication_type import AuthenticationType @@ -84,7 +84,7 @@ class BaseChatStep(IChatStep): post_response_handler: PostResponseHandler, chat_model: BaseChatModel = None, paragraph_list=None, - manage: PiplineManage = None, + manage: = None, padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, @@ -125,7 +125,7 @@ class BaseChatStep(IChatStep): post_response_handler: PostResponseHandler, chat_model: BaseChatModel = None, paragraph_list=None, - manage: PiplineManage = None, + manage: = None, padding_problem_text: str = None, client_id=None, client_type=None): # 调用模型 @@ -151,7 +151,7 @@ class BaseChatStep(IChatStep): post_response_handler: PostResponseHandler, chat_model: BaseChatModel = None, paragraph_list=None, - manage: PiplineManage = None, + manage: = None, padding_problem_text: str = None, client_id=None, client_type=None): # 调用模型 diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py index 01ce57ea0..350151659 100644 --- a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py +++ b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py @@ -13,7 +13,7 @@ from langchain.schema import BaseMessage from rest_framework import serializers from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel -from application.chat_pipeline.pipeline_manage import PiplineManage +from application.chat_pipeline.pipeline_manage import PipelineManage from application.models import ChatRecord from common.field.common import InstanceField from common.util.field_message import ErrMessage @@ -40,10 +40,10 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep): # 补齐问题 padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题")) - def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]: + def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: return self.InstanceSerializer - def _run(self, manage: PiplineManage): + def _run(self, manage: PipelineManage): message_list = self.execute(**self.context['step_args']) manage.context['message_list'] = message_list diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py index 4ff578666..930fd2482 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -13,7 +13,7 @@ from langchain.chat_models.base import BaseChatModel from rest_framework import serializers from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep -from application.chat_pipeline.pipeline_manage import PiplineManage +from application.chat_pipeline.pipeline_manage import PipelineManage from application.chat_pipeline.step.chat_step.i_chat_step import ModelField from application.models import ChatRecord from common.field.common import InstanceField @@ -30,10 +30,10 @@ class IResetProblemStep(IBaseChatPipelineStep): # 大语言模型 chat_model = ModelField(error_messages=ErrMessage.base("大语言模型")) - def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]: + def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: return self.InstanceSerializer - def _run(self, manage: PiplineManage): + def _run(self, manage: PipelineManage): padding_problem = self.execute(**self.context.get('step_args')) # 用户输入问题 source_problem_text = self.context.get('step_args').get('problem_text') diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index 3731a6aef..f4e9296af 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -12,7 +12,7 @@ from typing import List, Type from rest_framework import serializers from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel -from application.chat_pipeline.pipeline_manage import PiplineManage +from application.chat_pipeline.pipeline_manage import PipelineManage from common.util.field_message import ErrMessage from dataset.models import Paragraph @@ -39,10 +39,10 @@ class ISearchDatasetStep(IBaseChatPipelineStep): similarity = serializers.FloatField(required=True, max_value=1, min_value=0, error_messages=ErrMessage.float("引用分段数")) - def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]: + def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]: return self.InstanceSerializer - def _run(self, manage: PiplineManage): + def _run(self, manage: PipelineManage): paragraph_list = self.execute(**self.context['step_args']) manage.context['paragraph_list'] = paragraph_list self.context['paragraph_list'] = paragraph_list