mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
refactor: fix typo (#64)
This commit is contained in:
parent
142c999b95
commit
f19e78b692
|
|
@ -13,7 +13,7 @@ from typing import List, Type, Dict
|
|||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||
|
||||
|
||||
class PiplineManage:
|
||||
class PipelineManage:
|
||||
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]):
|
||||
# 步骤执行器
|
||||
self.step_list = [step() for step in step_list]
|
||||
|
|
@ -42,4 +42,4 @@ class PiplineManage:
|
|||
return self
|
||||
|
||||
def build(self):
|
||||
return PiplineManage(step_list=self.step_list)
|
||||
return PipelineManage(step_list=self.step_list)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from langchain.schema import BaseMessage
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
from dataset.models import Paragraph
|
||||
|
|
@ -78,10 +78,10 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
if not isinstance(message, BaseMessage):
|
||||
raise Exception("message 类型错误")
|
||||
|
||||
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PiplineManage):
|
||||
def _run(self, manage: PipelineManage):
|
||||
chat_result = self.execute(**self.context['step_args'], manage=manage)
|
||||
manage.context['chat_result'] = chat_result
|
||||
|
||||
|
|
@ -91,6 +91,6 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PiplineManage = None,
|
||||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from langchain.schema import BaseMessage
|
|||
from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
||||
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
|
|
@ -84,7 +84,7 @@ class BaseChatStep(IChatStep):
|
|||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PiplineManage = None,
|
||||
manage: = None,
|
||||
padding_problem_text: str = None,
|
||||
stream: bool = True,
|
||||
client_id=None, client_type=None,
|
||||
|
|
@ -125,7 +125,7 @@ class BaseChatStep(IChatStep):
|
|||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PiplineManage = None,
|
||||
manage: = None,
|
||||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None):
|
||||
# 调用模型
|
||||
|
|
@ -151,7 +151,7 @@ class BaseChatStep(IChatStep):
|
|||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PiplineManage = None,
|
||||
manage: = None,
|
||||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None):
|
||||
# 调用模型
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from langchain.schema import BaseMessage
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.models import ChatRecord
|
||||
from common.field.common import InstanceField
|
||||
from common.util.field_message import ErrMessage
|
||||
|
|
@ -40,10 +40,10 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
|
|||
# 补齐问题
|
||||
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题"))
|
||||
|
||||
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PiplineManage):
|
||||
def _run(self, manage: PipelineManage):
|
||||
message_list = self.execute(**self.context['step_args'])
|
||||
manage.context['message_list'] = message_list
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from langchain.chat_models.base import BaseChatModel
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from application.chat_pipeline.step.chat_step.i_chat_step import ModelField
|
||||
from application.models import ChatRecord
|
||||
from common.field.common import InstanceField
|
||||
|
|
@ -30,10 +30,10 @@ class IResetProblemStep(IBaseChatPipelineStep):
|
|||
# 大语言模型
|
||||
chat_model = ModelField(error_messages=ErrMessage.base("大语言模型"))
|
||||
|
||||
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PiplineManage):
|
||||
def _run(self, manage: PipelineManage):
|
||||
padding_problem = self.execute(**self.context.get('step_args'))
|
||||
# 用户输入问题
|
||||
source_problem_text = self.context.get('step_args').get('problem_text')
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing import List, Type
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.chat_pipeline.pipeline_manage import PipelineManage
|
||||
from common.util.field_message import ErrMessage
|
||||
from dataset.models import Paragraph
|
||||
|
||||
|
|
@ -39,10 +39,10 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||
similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
|
||||
error_messages=ErrMessage.float("引用分段数"))
|
||||
|
||||
def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]:
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PiplineManage):
|
||||
def _run(self, manage: PipelineManage):
|
||||
paragraph_list = self.execute(**self.context['step_args'])
|
||||
manage.context['paragraph_list'] = paragraph_list
|
||||
self.context['paragraph_list'] = paragraph_list
|
||||
|
|
|
|||
Loading…
Reference in New Issue