refactor: fix typo (#64)

This commit is contained in:
Ikko Eltociear Ashimine 2024-04-15 11:59:50 +09:00 committed by GitHub
parent 142c999b95
commit f19e78b692
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 19 additions and 19 deletions

View File

@ -13,7 +13,7 @@ from typing import List, Type, Dict
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
class PiplineManage:
class PipelineManage:
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]):
# 步骤执行器
self.step_list = [step() for step in step_list]
@ -42,4 +42,4 @@ class PiplineManage:
return self
def build(self):
return PiplineManage(step_list=self.step_list)
return PipelineManage(step_list=self.step_list)

View File

@ -14,7 +14,7 @@ from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.pipeline_manage import PipelineManage
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
from dataset.models import Paragraph
@ -78,10 +78,10 @@ class IChatStep(IBaseChatPipelineStep):
if not isinstance(message, BaseMessage):
raise Exception("message 类型错误")
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
def _run(self, manage: PiplineManage):
def _run(self, manage: PipelineManage):
chat_result = self.execute(**self.context['step_args'], manage=manage)
manage.context['chat_result'] = chat_result
@ -91,6 +91,6 @@ class IChatStep(IBaseChatPipelineStep):
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs):
pass

View File

@ -20,7 +20,7 @@ from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
@ -84,7 +84,7 @@ class BaseChatStep(IChatStep):
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
manage: = None,
padding_problem_text: str = None,
stream: bool = True,
client_id=None, client_type=None,
@ -125,7 +125,7 @@ class BaseChatStep(IChatStep):
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
manage: = None,
padding_problem_text: str = None,
client_id=None, client_type=None):
# 调用模型
@ -151,7 +151,7 @@ class BaseChatStep(IChatStep):
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
manage: = None,
padding_problem_text: str = None,
client_id=None, client_type=None):
# 调用模型

View File

@ -13,7 +13,7 @@ from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.models import ChatRecord
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
@ -40,10 +40,10 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
# 补齐问题
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题"))
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
def _run(self, manage: PiplineManage):
def _run(self, manage: PipelineManage):
message_list = self.execute(**self.context['step_args'])
manage.context['message_list'] = message_list

View File

@ -13,7 +13,7 @@ from langchain.chat_models.base import BaseChatModel
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import ModelField
from application.models import ChatRecord
from common.field.common import InstanceField
@ -30,10 +30,10 @@ class IResetProblemStep(IBaseChatPipelineStep):
# 大语言模型
chat_model = ModelField(error_messages=ErrMessage.base("大语言模型"))
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
def _run(self, manage: PiplineManage):
def _run(self, manage: PipelineManage):
padding_problem = self.execute(**self.context.get('step_args'))
# 用户输入问题
source_problem_text = self.context.get('step_args').get('problem_text')

View File

@ -12,7 +12,7 @@ from typing import List, Type
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.pipeline_manage import PipelineManage
from common.util.field_message import ErrMessage
from dataset.models import Paragraph
@ -39,10 +39,10 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
error_messages=ErrMessage.float("引用分段数"))
def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]:
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer
def _run(self, manage: PiplineManage):
def _run(self, manage: PipelineManage):
paragraph_list = self.execute(**self.context['step_args'])
manage.context['paragraph_list'] = paragraph_list
self.context['paragraph_list'] = paragraph_list