mirror of https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-30 01:32:49 +00:00

fix: conversation log storage, store paragraph details

This commit is contained in:
parent e83afe74c5
commit 9822f82593
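In brief, as read from the diff: rather than persisting only referenced paragraph ids on ChatRecord and re-querying the dataset tables whenever a log is viewed, the chat pipeline now snapshots full paragraph details at answer time. A new ParagraphPipelineModel (with a fluent builder) flows through the search, human-message, and chat steps; the search step serializes it into ChatRecord.details under 'search_step'; the paragraph_id_list column is dropped by a new migration; and the paragraph lookup SQL additionally returns document_name.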
@@ -8,10 +8,91 @@
 """
 import time
 from abc import abstractmethod
-from typing import Type
+from typing import Type, Dict
 
 from rest_framework import serializers
 
+from dataset.models import Paragraph
+
+
+class ParagraphPipelineModel:
+
+    def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
+                 is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str):
+        self.id = _id
+        self.document_id = document_id
+        self.dataset_id = dataset_id
+        self.content = content
+        self.title = title
+        self.status = status
+        self.is_active = is_active
+        self.comprehensive_score = comprehensive_score
+        self.similarity = similarity
+        self.dataset_name = dataset_name
+        self.document_name = document_name
+
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'document_id': self.document_id,
+            'dataset_id': self.dataset_id,
+            'content': self.content,
+            'title': self.title,
+            'status': self.status,
+            'is_active': self.is_active,
+            'comprehensive_score': self.comprehensive_score,
+            'similarity': self.similarity,
+            'dataset_name': self.dataset_name,
+            'document_name': self.document_name
+        }
+
+    class builder:
+        def __init__(self):
+            self.similarity = None
+            self.paragraph = {}
+            self.comprehensive_score = None
+            self.document_name = None
+            self.dataset_name = None
+
+        def add_paragraph(self, paragraph):
+            if isinstance(paragraph, Paragraph):
+                self.paragraph = {'id': paragraph.id,
+                                  'document_id': paragraph.document_id,
+                                  'dataset_id': paragraph.dataset_id,
+                                  'content': paragraph.content,
+                                  'title': paragraph.title,
+                                  'status': paragraph.status,
+                                  'is_active': paragraph.is_active,
+                                  }
+            else:
+                self.paragraph = paragraph
+            return self
+
+        def add_dataset_name(self, dataset_name):
+            self.dataset_name = dataset_name
+            return self
+
+        def add_document_name(self, document_name):
+            self.document_name = document_name
+            return self
+
+        def add_comprehensive_score(self, comprehensive_score: float):
+            self.comprehensive_score = comprehensive_score
+            return self
+
+        def add_similarity(self, similarity: float):
+            self.similarity = similarity
+            return self
+
+        def build(self):
+            return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
+                                          str(self.paragraph.get('dataset_id')),
+                                          self.paragraph.get('content'), self.paragraph.get('title'),
+                                          self.paragraph.get('status'),
+                                          self.paragraph.get('is_active'),
+                                          self.comprehensive_score, self.similarity, self.dataset_name,
+                                          self.document_name)
+
+
 class IBaseChatPipelineStep:
     def __init__(self):
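For orientation, a minimal usage sketch of the new model and its builder (the input dict and scores are illustrative, not taken from the commit):

    # Illustrative inputs; in the pipeline these come from the search step's SQL rows.
    paragraph_row = {'id': 'p-1', 'document_id': 'd-1', 'dataset_id': 's-1',
                     'content': 'MaxKB is ...', 'title': 'Intro', 'status': '1', 'is_active': True}
    model = (ParagraphPipelineModel.builder()
             .add_paragraph(paragraph_row)
             .add_similarity(0.83)
             .add_comprehensive_score(0.79)
             .add_dataset_name('product-docs')
             .add_document_name('intro.md')
             .build())
    model.to_dict()  # plain dict, safe to store in ChatRecord.details

Because add_paragraph accepts either a Paragraph ORM instance or a row dict, callers can feed it from the ORM or from native SQL results.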
@@ -13,7 +13,7 @@ from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage
 from rest_framework import serializers
 
-from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PiplineManage
 from common.field.common import InstanceField
 from dataset.models import Paragraph
@@ -41,7 +41,8 @@ class MessageField(serializers.Field):
 
 class PostResponseHandler:
     @abstractmethod
-    def handler(self, chat_id, chat_record_id, paragraph_list: List[Paragraph], problem_text: str, answer_text,
+    def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
+                answer_text,
                 manage, step, padding_problem_text: str = None, **kwargs):
        pass
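Given the widened signature, a conforming handler might look like this minimal sketch (the subclass and its body are hypothetical; the real handler is produced by get_post_handler further down in this commit):

    class LoggingPostResponseHandler(PostResponseHandler):
        def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
                    answer_text, manage, step, padding_problem_text: str = None, **kwargs):
            # Each ParagraphPipelineModel already carries dataset_name/document_name and scores,
            # so no extra lookup against the dataset tables is needed here.
            print(chat_record_id, [p.to_dict() for p in paragraph_list])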
@@ -18,15 +18,15 @@ from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage
 from langchain.schema.messages import BaseMessageChunk, HumanMessage
 
+from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PiplineManage
 from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
-from dataset.models import Paragraph
 
 
 def event_content(response,
                   chat_id,
                   chat_record_id,
-                  paragraph_list: List[Paragraph],
+                  paragraph_list: List[ParagraphPipelineModel],
                   post_response_handler: PostResponseHandler,
                   manage,
                   step,
@@ -12,7 +12,7 @@ from typing import Type, List
 from langchain.schema import BaseMessage
 from rest_framework import serializers
 
-from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PiplineManage
 from application.models import ChatRecord
 from common.field.common import InstanceField
@@ -24,7 +24,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
         # question
         problem_text = serializers.CharField(required=True)
         # paragraph list
-        paragraph_list = serializers.ListField(child=InstanceField(model_type=Paragraph, required=True))
+        paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True))
         # chat history
         history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
         # number of dialogue rounds
@@ -46,7 +46,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
     @abstractmethod
     def execute(self,
                 problem_text: str,
-                paragraph_list: List[Paragraph],
+                paragraph_list: List[ParagraphPipelineModel],
                 history_chat_record: List[ChatRecord],
                 dialogue_number: int,
                 max_paragraph_char_number: int,
@@ -10,17 +10,17 @@ from typing import List
 
 from langchain.schema import BaseMessage, HumanMessage
 
+from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
     IGenerateHumanMessageStep
 from application.models import ChatRecord
 from common.util.split_model import flat_map
-from dataset.models import Paragraph
 
 
 class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
 
     def execute(self, problem_text: str,
-                paragraph_list: List[Paragraph],
+                paragraph_list: List[ParagraphPipelineModel],
                 history_chat_record: List[ChatRecord],
                 dialogue_number: int,
                 max_paragraph_char_number: int,
@@ -39,7 +39,7 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
     def to_human_message(prompt: str,
                          problem: str,
                          max_paragraph_char_number: int,
-                         paragraph_list: List[Paragraph]):
+                         paragraph_list: List[ParagraphPipelineModel]):
         if paragraph_list is None or len(paragraph_list) == 0:
             return HumanMessage(content=problem)
         temp_data = ""
@@ -11,7 +11,7 @@ from typing import List, Type
 
 from rest_framework import serializers
 
-from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PiplineManage
 from dataset.models import Paragraph
@@ -39,11 +39,12 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
     def _run(self, manage: PiplineManage):
         paragraph_list = self.execute(**self.context['step_args'])
         manage.context['paragraph_list'] = paragraph_list
+        self.context['paragraph_list'] = paragraph_list
 
     @abstractmethod
     def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                 exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
-                **kwargs) -> List[Paragraph]:
+                **kwargs) -> List[ParagraphPipelineModel]:
         """
         On the user question vs. the padded question: if a padded question exists, use it for the query; otherwise query with the user's original question.
         :param similarity: relevance
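Design note: _run now writes the search result twice on purpose. manage.context['paragraph_list'] keeps feeding the downstream pipeline steps as before, while the new self.context['paragraph_list'] is what this step's get_details serializes (see the hunks below), which in turn is persisted via details=manage.get_details() when the ChatRecord is created.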
@@ -6,20 +6,25 @@
 @date:2024/1/10 10:33
 @desc:
 """
-from typing import List
+import os
+from typing import List, Dict
 
 from django.db.models import QuerySet
 
+from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
 from common.config.embedding_config import VectorStore, EmbeddingModel
+from common.db.search import native_search
+from common.util.file_util import get_file_content
 from dataset.models import Paragraph
+from smartdoc.conf import PROJECT_DIR
 
 
 class BaseSearchDatasetStep(ISearchDatasetStep):
 
     def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                 exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
-                **kwargs) -> List[Paragraph]:
+                **kwargs) -> List[ParagraphPipelineModel]:
         exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
         embedding_model = EmbeddingModel.get_embedding_model()
         embedding_value = embedding_model.embed_query(exec_problem_text)
@@ -28,16 +33,35 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
                                                  exclude_paragraph_id_list, True, top_n, similarity)
         if embedding_list is None:
             return []
-        return self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
+        paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
+        return [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
+
+    @staticmethod
+    def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel:
+        filter_embedding_list = [embedding for embedding in embedding_list if
+                                 str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
+        if filter_embedding_list is not None and len(filter_embedding_list) > 0:
+            find_embedding = filter_embedding_list[-1]
+            return (ParagraphPipelineModel.builder()
+                    .add_paragraph(paragraph)
+                    .add_similarity(find_embedding.get('similarity'))
+                    .add_comprehensive_score(find_embedding.get('comprehensive_score'))
+                    .add_dataset_name(paragraph.get('dataset_name'))
+                    .add_document_name(paragraph.get('document_name'))
+                    .build())
 
     @staticmethod
     def list_paragraph(paragraph_id_list: List, vector):
         if paragraph_id_list is None or len(paragraph_id_list) == 0:
             return []
-        paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list)
+        paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
+                                       get_file_content(
+                                           os.path.join(PROJECT_DIR, "apps", "application", 'sql',
+                                                        'list_dataset_paragraph_by_paragraph_id.sql')),
+                                       with_table_name=True)
         # if the vector store contains dirty data, delete it outright
         if len(paragraph_list) != len(paragraph_id_list):
-            exist_paragraph_list = [str(row.id) for row in paragraph_list]
+            exist_paragraph_list = [row.get('id') for row in paragraph_list]
             for paragraph_id in paragraph_id_list:
                 if not exist_paragraph_list.__contains__(paragraph_id):
                     vector.delete_by_paragraph_id(paragraph_id)
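To make the join explicit, a sketch of reset_paragraph with illustrative row shapes (field values are made up; the dict keys follow the SQL at the end of this commit):

    embedding_list = [{'paragraph_id': 'p-1', 'similarity': 0.91, 'comprehensive_score': 0.88}]
    paragraph = {'id': 'p-1', 'document_id': 'd-1', 'dataset_id': 's-1', 'content': '...',
                 'title': 'Intro', 'status': '1', 'is_active': True,
                 'dataset_name': 'product-docs', 'document_name': 'intro.md'}
    model = BaseSearchDatasetStep.reset_paragraph(paragraph, embedding_list)
    assert model.similarity == 0.91 and model.document_name == 'intro.md'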
@@ -48,6 +72,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
 
         return {
             'step_type': 'search_step',
+            'paragraph_list': [row.to_dict() for row in self.context['paragraph_list']],
             'run_time': self.context['run_time'],
             'problem_text': step_args.get(
                 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
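The practical effect of the new 'paragraph_list' entry: ChatRecord.details['search_step'] now carries full paragraph snapshots rather than ids. The stored shape, with illustrative values and keys as produced by to_dict:

    {'step_type': 'search_step',
     'paragraph_list': [{'id': 'p-1', 'document_id': 'd-1', 'dataset_id': 's-1',
                         'content': '...', 'title': 'Intro', 'status': '1', 'is_active': True,
                         'comprehensive_score': 0.88, 'similarity': 0.91,
                         'dataset_name': 'product-docs', 'document_name': 'intro.md'}],
     'run_time': 0.12,
     'problem_text': '...'}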
@@ -0,0 +1,17 @@
+# Generated by Django 4.1.10 on 2024-01-19 14:02
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('application', '0005_alter_chatrecord_details'),
+    ]
+
+    operations = [
+        migrations.RemoveField(
+            model_name='chatrecord',
+            name='paragraph_id_list',
+        ),
+    ]
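Since this migration drops a column, deployments need a migrate run after pulling the commit; the usual Django invocation (not part of the diff) is:

    python manage.py migrate application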
@@ -89,9 +89,6 @@ class ChatRecord(AppModelMixin):
     chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
     vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
                                    default=VoteChoices.UN_VOTE)
-    paragraph_id_list = ArrayField(verbose_name="引用段落id列表",
-                                   base_field=models.UUIDField(max_length=128, blank=True)
-                                   , default=list)
     problem_text = models.CharField(max_length=1024, verbose_name="问题")
     answer_text = models.CharField(max_length=4096, verbose_name="答案")
     message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
@@ -108,7 +108,6 @@ def get_post_handler(chat_info: ChatInfo):
                     **kwargs):
             chat_record = ChatRecord(id=chat_record_id,
                                      chat_id=chat_id,
-                                     paragraph_id_list=[str(p.id) for p in paragraph_list],
                                      problem_text=problem_text,
                                      answer_text=answer_text,
                                      details=manage.get_details(),
@@ -192,28 +192,7 @@ class ChatRecordSerializer(serializers.Serializer):
             chat_record = self.get_chat_record()
             if chat_record is None:
                 raise AppApiException(500, "对话不存在")
-            dataset_list = []
-            paragraph_list = []
-            if len(chat_record.paragraph_id_list) > 0:
-                paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=chat_record.paragraph_id_list),
-                                               get_file_content(
-                                                   os.path.join(PROJECT_DIR, "apps", "application", 'sql',
-                                                                'list_dataset_paragraph_by_paragraph_id.sql')),
-                                               with_table_name=True)
-            dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
-                                                                                            [{row.get(
-                                                                                                'dataset_id'): row.get(
-                                                                                                "dataset_name")} for
-                                                                                             row in
-                                                                                             paragraph_list],
-                                                                                            {}).items()]
-
-            return {
-                **ChatRecordSerializerModel(chat_record).data,
-                'padding_problem_text': chat_record.details.get('problem_padding').get(
-                    'padding_problem_text') if 'problem_padding' in chat_record.details else None,
-                'dataset_list': dataset_list,
-                'paragraph_list': paragraph_list}
+            return ChatRecordSerializer.Query.reset_chat_record(chat_record)
 
     class Query(serializers.Serializer):
         application_id = serializers.UUIDField(required=True)
@@ -226,37 +205,22 @@ class ChatRecordSerializer(serializers.Serializer):
             return [ChatRecordSerializerModel(chat_record).data for chat_record in
                     QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
 
-        def reset_chat_record_list(self, chat_record_list: List[ChatRecord]):
-            paragraph_id_list = flat_map([chat_record.paragraph_id_list for chat_record in chat_record_list])
-            # deduplicate
-            paragraph_id_list = list(set(paragraph_id_list))
-            paragraph_list = self.search_paragraph(paragraph_id_list)
-            return [self.reset_chat_record(chat_record, paragraph_list) for chat_record in chat_record_list]
-
         @staticmethod
-        def search_paragraph(paragraph_id_list: List[str]):
-            paragraph_list = []
-            if len(paragraph_id_list) > 0:
-                paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
-                                               get_file_content(
-                                                   os.path.join(PROJECT_DIR, "apps", "application", 'sql',
-                                                                'list_dataset_paragraph_by_paragraph_id.sql')),
-                                               with_table_name=True)
-            return paragraph_list
-
-        @staticmethod
-        def reset_chat_record(chat_record, all_paragraph_list):
-            paragraph_list = list(
-                filter(
-                    lambda paragraph: chat_record.paragraph_id_list.__contains__(uuid.UUID(str(paragraph.get('id')))),
-                    all_paragraph_list))
+        def reset_chat_record(chat_record):
+            dataset_list = []
+            paragraph_list = []
+            if 'search_step' in chat_record.details and chat_record.details.get('search_step').get(
+                    'paragraph_list') is not None:
+                paragraph_list = chat_record.details.get('search_step').get(
+                    'paragraph_list')
             dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
                                                                                             [{row.get(
                                                                                                 'dataset_id'): row.get(
                                                                                                 "dataset_name")} for
                                                                                              row in
                                                                                              paragraph_list],
                                                                                             {}).items()]
             return {
                 **ChatRecordSerializerModel(chat_record).data,
                 'padding_problem_text': chat_record.details.get('problem_padding').get(
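The upshot of this refactor: the detail view is rebuilt from the snapshot saved in chat_record.details at answer time, so the per-record native_search round trip disappears and the log keeps rendering even if the source paragraphs are later edited or deleted; reset_chat_record_list and search_paragraph become dead code and are removed.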
@@ -270,9 +234,7 @@ class ChatRecordSerializer(serializers.Serializer):
             self.is_valid(raise_exception=True)
             page = page_search(current_page, page_size,
                                QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
-                               post_records_handler=lambda chat_record: chat_record)
-            records = page.get('records')
-            page['records'] = self.reset_chat_record_list(records)
+                               post_records_handler=lambda chat_record: self.reset_chat_record(chat_record))
             return page
 
     class Vote(serializers.Serializer):
@@ -1,6 +1,8 @@
 SELECT
     paragraph.*,
-    dataset."name" AS "dataset_name"
+    dataset."name" AS "dataset_name",
+    "document"."name" AS "document_name"
 FROM
     paragraph paragraph
     LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id
+    LEFT JOIN "document" "document" ON "document"."id" = paragraph.document_id
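With the extra join, every row this query returns carries document_name alongside the existing dataset_name, which is what builder.add_document_name consumes in reset_paragraph above.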