mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
--story=1016652 --user=王孝刚 【南区】应用对话回复支持显示引用分段来知识库的文档来源等信息 https://www.tapd.cn/57709429/s/1616642
158 lines
5.4 KiB
Python
158 lines
5.4 KiB
Python
# coding=utf-8
|
||
"""
|
||
@project: maxkb
|
||
@Author:虎
|
||
@file: I_base_chat_pipeline.py
|
||
@date:2024/1/9 17:25
|
||
@desc:
|
||
"""
|
||
import time
|
||
from abc import abstractmethod
|
||
from typing import Type
|
||
|
||
from rest_framework import serializers
|
||
|
||
from dataset.models import Paragraph
|
||
|
||
|
||
class ParagraphPipelineModel:
    """A retrieved paragraph together with the retrieval data attached to it.

    Instances are usually assembled through the nested ``builder`` class and
    serialized with ``to_dict``.
    """

    def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
                 is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
                 hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
        """
        :param _id: paragraph id
        :param document_id: id of the document the paragraph belongs to
        :param dataset_id: id of the dataset (knowledge base) the paragraph belongs to
        :param content: paragraph text
        :param title: paragraph title
        :param status: paragraph status value, stored as given
        :param is_active: paragraph active flag, stored as given
        :param comprehensive_score: score supplied by the caller
        :param similarity: similarity supplied by the caller
        :param dataset_name: name of the owning dataset
        :param document_name: name of the owning document
        :param hit_handling_method: hit handling method, stored as given
        :param directly_return_similarity: threshold value, stored as given
        :param meta: optional metadata dict (may be None)
        """
        self.id = _id
        self.document_id = document_id
        self.dataset_id = dataset_id
        self.content = content
        self.title = title
        # No trailing comma here: ``status,`` would silently store a 1-tuple.
        self.status = status
        self.is_active = is_active
        self.comprehensive_score = comprehensive_score
        self.similarity = similarity
        self.dataset_name = dataset_name
        self.document_name = document_name
        self.hit_handling_method = hit_handling_method
        self.directly_return_similarity = directly_return_similarity
        self.meta = meta

    def to_dict(self):
        """Return the paragraph fields as a plain dict.

        ``hit_handling_method`` and ``directly_return_similarity`` are
        intentionally not part of the returned mapping.
        """
        return {
            'id': self.id,
            'document_id': self.document_id,
            'dataset_id': self.dataset_id,
            'content': self.content,
            'title': self.title,
            'status': self.status,
            'is_active': self.is_active,
            'comprehensive_score': self.comprehensive_score,
            'similarity': self.similarity,
            'dataset_name': self.dataset_name,
            'document_name': self.document_name,
            'meta': self.meta,
        }

    class builder:
        """Fluent builder for ``ParagraphPipelineModel``; every ``add_*`` returns ``self``."""

        def __init__(self):
            self.similarity = None
            # Raw paragraph fields, read with ``.get`` in ``build``.
            self.paragraph = {}
            self.comprehensive_score = None
            self.document_name = None
            self.dataset_name = None
            self.hit_handling_method = None
            self.directly_return_similarity = 0.9
            self.meta = {}

        def add_paragraph(self, paragraph):
            """Set the paragraph source.

            A ``Paragraph`` model instance is copied into a dict of the fields
            ``build`` needs; anything else is stored as-is and must support
            ``.get`` (e.g. a dict with the same keys).
            """
            if isinstance(paragraph, Paragraph):
                self.paragraph = {'id': paragraph.id,
                                  'document_id': paragraph.document_id,
                                  'dataset_id': paragraph.dataset_id,
                                  'content': paragraph.content,
                                  'title': paragraph.title,
                                  'status': paragraph.status,
                                  'is_active': paragraph.is_active,
                                  }
            else:
                self.paragraph = paragraph
            return self

        def add_dataset_name(self, dataset_name):
            """Set the owning dataset's name."""
            self.dataset_name = dataset_name
            return self

        def add_document_name(self, document_name):
            """Set the owning document's name."""
            self.document_name = document_name
            return self

        def add_hit_handling_method(self, hit_handling_method):
            """Set the hit handling method."""
            self.hit_handling_method = hit_handling_method
            return self

        def add_directly_return_similarity(self, directly_return_similarity):
            """Set the directly-return similarity threshold (defaults to 0.9)."""
            self.directly_return_similarity = directly_return_similarity
            return self

        def add_comprehensive_score(self, comprehensive_score: float):
            """Set the comprehensive score."""
            self.comprehensive_score = comprehensive_score
            return self

        def add_similarity(self, similarity: float):
            """Set the similarity value."""
            self.similarity = similarity
            return self

        def add_meta(self, meta: dict):
            """Set the metadata dict."""
            self.meta = meta
            return self

        def build(self):
            """Create the ``ParagraphPipelineModel``.

            The three id fields are passed through ``str``, so a missing id
            becomes the string ``'None'``.
            """
            return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
                                          str(self.paragraph.get('dataset_id')),
                                          self.paragraph.get('content'), self.paragraph.get('title'),
                                          self.paragraph.get('status'),
                                          self.paragraph.get('is_active'),
                                          self.comprehensive_score, self.similarity, self.dataset_name,
                                          self.document_name, self.hit_handling_method, self.directly_return_similarity,
                                          self.meta)
|
||
|
||
|
||
class IBaseChatPipelineStep:
    """Base class for one step of the chat pipeline.

    A subclass provides the serializer used to validate its arguments
    (``get_step_serializer``) and the step logic (``_run``). ``run`` wraps the
    two with argument validation and timing, recording results in
    ``self.context``.
    """

    def __init__(self):
        # Per-step context: stores this step's arguments, timings and results.
        self.context = {}

    @abstractmethod
    def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate ``manage.context``."""

    def valid_args(self, manage):
        """Validate ``manage.context`` and keep the validated data under ``'step_args'``.

        Validation uses ``raise_exception=True``, so invalid input raises the
        serializer's validation error instead of returning.
        """
        validator = self.get_step_serializer(manage)(data=manage.context)
        validator.is_valid(raise_exception=True)
        self.context['step_args'] = validator.data

    def run(self, manage):
        """Validate arguments, execute the step and record its timing.

        :param manage: pipeline manager driving the steps
        :return: None; results live in ``self.context``
        """
        began = time.time()
        self.context['start_time'] = began
        # Arguments are validated before any step logic runs.
        self.valid_args(manage)
        self._run(manage)
        self.context['run_time'] = time.time() - began

    def _run(self, manage):
        """Step logic hook; the base implementation does nothing."""

    def execute(self, **kwargs):
        """Execution hook; the base implementation does nothing."""

    def get_details(self, manage, **kwargs):
        """Return details about this step's run.

        :return: step details; the base implementation returns None
        """
        return None
|