fix: chat log storage, store paragraph details

shaohuzhang1 2024-01-19 16:07:12 +08:00
parent e83afe74c5
commit 9822f82593
12 changed files with 162 additions and 77 deletions

View File

@@ -8,10 +8,91 @@
"""
import time
from abc import abstractmethod
from typing import Type
from typing import Type, Dict
from rest_framework import serializers
from dataset.models import Paragraph
# Pipeline-level view of a paragraph: content plus dataset/document names and retrieval scores
class ParagraphPipelineModel:
def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str):
self.id = _id
self.document_id = document_id
self.dataset_id = dataset_id
self.content = content
self.title = title
self.status = status
self.is_active = is_active
self.comprehensive_score = comprehensive_score
self.similarity = similarity
self.dataset_name = dataset_name
self.document_name = document_name
def to_dict(self):
return {
'id': self.id,
'document_id': self.document_id,
'dataset_id': self.dataset_id,
'content': self.content,
'title': self.title,
'status': self.status,
'is_active': self.is_active,
'comprehensive_score': self.comprehensive_score,
'similarity': self.similarity,
'dataset_name': self.dataset_name,
'document_name': self.document_name
}
# Fluent builder, used as ParagraphPipelineModel.builder()
class builder:
def __init__(self):
self.similarity = None
self.paragraph = {}
self.comprehensive_score = None
self.document_name = None
self.dataset_name = None
def add_paragraph(self, paragraph):
if isinstance(paragraph, Paragraph):
self.paragraph = {'id': paragraph.id,
'document_id': paragraph.document_id,
'dataset_id': paragraph.dataset_id,
'content': paragraph.content,
'title': paragraph.title,
'status': paragraph.status,
'is_active': paragraph.is_active,
}
else:
self.paragraph = paragraph
return self
def add_dataset_name(self, dataset_name):
self.dataset_name = dataset_name
return self
def add_document_name(self, document_name):
self.document_name = document_name
return self
def add_comprehensive_score(self, comprehensive_score: float):
self.comprehensive_score = comprehensive_score
return self
def add_similarity(self, similarity: float):
self.similarity = similarity
return self
def build(self):
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
str(self.paragraph.get('dataset_id')),
self.paragraph.get('content'), self.paragraph.get('title'),
self.paragraph.get('status'),
self.paragraph.get('is_active'),
self.comprehensive_score, self.similarity, self.dataset_name,
self.document_name)
class IBaseChatPipelineStep:
def __init__(self):

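For orientation, a minimal sketch of how the new builder is meant to be used, mirroring the call chain added in BaseSearchDatasetStep further down; the row values here are illustrative, not from the commit:

# Hypothetical usage of ParagraphPipelineModel.builder(); all values are made up.
row = {'id': '...', 'document_id': '...', 'dataset_id': '...',
'content': 'paragraph text', 'title': 'Intro', 'status': 'ok', 'is_active': True}
model = (ParagraphPipelineModel.builder()
.add_paragraph(row)  # accepts a Paragraph instance or a plain dict
.add_similarity(0.83)
.add_comprehensive_score(0.79)
.add_dataset_name('demo dataset')
.add_document_name('demo document')
.build())
print(model.to_dict())  # plain dict, safe to persist in chat_record.details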
View File

@@ -13,7 +13,7 @@ from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from common.field.common import InstanceField
from dataset.models import Paragraph
@@ -41,7 +41,8 @@ class MessageField(serializers.Field):
class PostResponseHandler:
@abstractmethod
def handler(self, chat_id, chat_record_id, paragraph_list: List[Paragraph], problem_text: str, answer_text,
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
answer_text,
manage, step, padding_problem_text: str = None, **kwargs):
pass
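A concrete handler now receives ParagraphPipelineModel objects rather than bare Paragraph rows. A minimal sketch of an implementation (the class name and body are illustrative; only the signature comes from this diff):

class LoggingPostResponseHandler(PostResponseHandler):
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
answer_text, manage, step, padding_problem_text: str = None, **kwargs):
# Each entry carries dataset/document names and retrieval scores directly.
for paragraph in paragraph_list:
print(paragraph.dataset_name, paragraph.document_name, paragraph.similarity)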

View File

@@ -18,15 +18,15 @@ from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessageChunk, HumanMessage
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from dataset.models import Paragraph
def event_content(response,
chat_id,
chat_record_id,
paragraph_list: List[Paragraph],
paragraph_list: List[ParagraphPipelineModel],
post_response_handler: PostResponseHandler,
manage,
step,

View File

@@ -12,7 +12,7 @@ from typing import Type, List
from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.models import ChatRecord
from common.field.common import InstanceField
@@ -24,7 +24,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
# Question
problem_text = serializers.CharField(required=True)
# Paragraph list
paragraph_list = serializers.ListField(child=InstanceField(model_type=Paragraph, required=True))
paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True))
# Chat history
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
# Number of multi-turn dialogue rounds
@@ -46,7 +46,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
@abstractmethod
def execute(self,
problem_text: str,
paragraph_list: List[Paragraph],
paragraph_list: List[ParagraphPipelineModel],
history_chat_record: List[ChatRecord],
dialogue_number: int,
max_paragraph_char_number: int,

View File

@@ -10,17 +10,17 @@ from typing import List
from langchain.schema import BaseMessage, HumanMessage
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
IGenerateHumanMessageStep
from application.models import ChatRecord
from common.util.split_model import flat_map
from dataset.models import Paragraph
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
def execute(self, problem_text: str,
paragraph_list: List[Paragraph],
paragraph_list: List[ParagraphPipelineModel],
history_chat_record: List[ChatRecord],
dialogue_number: int,
max_paragraph_char_number: int,
@@ -39,7 +39,7 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
def to_human_message(prompt: str,
problem: str,
max_paragraph_char_number: int,
paragraph_list: List[Paragraph]):
paragraph_list: List[ParagraphPipelineModel]):
if paragraph_list is None or len(paragraph_list) == 0:
return HumanMessage(content=problem)
temp_data = ""

View File

@@ -11,7 +11,7 @@ from typing import List, Type
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage
from dataset.models import Paragraph
@@ -39,11 +39,12 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
def _run(self, manage: PiplineManage):
paragraph_list = self.execute(**self.context['step_args'])
manage.context['paragraph_list'] = paragraph_list
self.context['paragraph_list'] = paragraph_list
@abstractmethod
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
**kwargs) -> List[Paragraph]:
**kwargs) -> List[ParagraphPipelineModel]:
"""
Note on the user question vs. the padded question: if a padded question exists, use it for the search; otherwise fall back to the user's original question.
:param similarity: relevance threshold

View File

@@ -6,20 +6,25 @@
@date: 2024/1/10 10:33
@desc:
"""
from typing import List
import os
from typing import List, Dict
from django.db.models import QuerySet
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Paragraph
from smartdoc.conf import PROJECT_DIR
class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
**kwargs) -> List[Paragraph]:
**kwargs) -> List[ParagraphPipelineModel]:
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
embedding_model = EmbeddingModel.get_embedding_model()
embedding_value = embedding_model.embed_query(exec_problem_text)
@@ -28,16 +28,35 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
exclude_paragraph_id_list, True, top_n, similarity)
if embedding_list is None:
return []
return self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
return [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
@staticmethod
def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel:
filter_embedding_list = [embedding for embedding in embedding_list if
str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
if len(filter_embedding_list) > 0:
find_embedding = filter_embedding_list[-1]
return (ParagraphPipelineModel.builder()
.add_paragraph(paragraph)
.add_similarity(find_embedding.get('similarity'))
.add_comprehensive_score(find_embedding.get('comprehensive_score'))
.add_dataset_name(paragraph.get('dataset_name'))
.add_document_name(paragraph.get('document_name'))
.build())
# No matching embedding row: fall back to a model without scores instead of returning None
return (ParagraphPipelineModel.builder()
.add_paragraph(paragraph)
.add_dataset_name(paragraph.get('dataset_name'))
.add_document_name(paragraph.get('document_name'))
.build())
@staticmethod
def list_paragraph(paragraph_id_list: List, vector):
if paragraph_id_list is None or len(paragraph_id_list) == 0:
return []
paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list)
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
# If the vector store contains dirty data, delete it directly
if len(paragraph_list) != len(paragraph_id_list):
exist_paragraph_list = [str(row.id) for row in paragraph_list]
exist_paragraph_list = [row.get('id') for row in paragraph_list]
for paragraph_id in paragraph_id_list:
if paragraph_id not in exist_paragraph_list:
vector.delete_by_paragraph_id(paragraph_id)
@@ -48,6 +72,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
return {
'step_type': 'search_step',
'paragraph_list': [row.to_dict() for row in self.context['paragraph_list']],
'run_time': self.context['run_time'],
'problem_text': step_args.get(
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
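After this change, the search step persisted in chat_record.details has roughly the following shape (values illustrative; each paragraph entry is the output of ParagraphPipelineModel.to_dict()):

# Illustrative shape of the stored step details, not literal output.
details = {'search_step': {
'step_type': 'search_step',
'run_time': 0.42,
'problem_text': 'example question',
'paragraph_list': [{'id': '...', 'document_id': '...', 'dataset_id': '...',
'content': '...', 'title': '...', 'status': '...', 'is_active': True,
'comprehensive_score': 0.79, 'similarity': 0.83,
'dataset_name': '...', 'document_name': '...'}]}}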

View File

@@ -0,0 +1,17 @@
# Generated by Django 4.1.10 on 2024-01-19 14:02
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('application', '0005_alter_chatrecord_details'),
]
operations = [
migrations.RemoveField(
model_name='chatrecord',
name='paragraph_id_list',
),
]

View File

@@ -89,9 +89,6 @@ class ChatRecord(AppModelMixin):
chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
default=VoteChoices.UN_VOTE)
paragraph_id_list = ArrayField(verbose_name="引用段落id列表",
base_field=models.UUIDField(max_length=128, blank=True)
, default=list)
problem_text = models.CharField(max_length=1024, verbose_name="问题")
answer_text = models.CharField(max_length=4096, verbose_name="答案")
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)

View File

@@ -108,7 +108,6 @@ def get_post_handler(chat_info: ChatInfo):
**kwargs):
chat_record = ChatRecord(id=chat_record_id,
chat_id=chat_id,
paragraph_id_list=[str(p.id) for p in paragraph_list],
problem_text=problem_text,
answer_text=answer_text,
details=manage.get_details(),
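These stored details are what the serializer changes below read back instead of re-querying paragraphs by id. A sketch of the read side (illustrative, using the QuerySet style seen elsewhere in this diff):

# Illustrative read-back of the persisted paragraph details.
record = QuerySet(ChatRecord).filter(id=chat_record_id).first()
paragraph_dicts = record.details.get('search_step', {}).get('paragraph_list') or []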

View File

@@ -192,28 +192,7 @@ class ChatRecordSerializer(serializers.Serializer):
chat_record = self.get_chat_record()
if chat_record is None:
raise AppApiException(500, "对话不存在")
dataset_list = []
paragraph_list = []
if len(chat_record.paragraph_id_list) > 0:
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=chat_record.paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
[{row.get(
'dataset_id'): row.get(
"dataset_name")} for
row in
paragraph_list],
{}).items()]
return {
**ChatRecordSerializerModel(chat_record).data,
'padding_problem_text': chat_record.details.get('problem_padding').get(
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
'dataset_list': dataset_list,
'paragraph_list': paragraph_list}
return ChatRecordSerializer.Query.reset_chat_record(chat_record)
class Query(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
@@ -226,37 +205,22 @@ class ChatRecordSerializer(serializers.Serializer):
return [ChatRecordSerializerModel(chat_record).data for chat_record in
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
def reset_chat_record_list(self, chat_record_list: List[ChatRecord]):
paragraph_id_list = flat_map([chat_record.paragraph_id_list for chat_record in chat_record_list])
# Deduplicate
paragraph_id_list = list(set(paragraph_id_list))
paragraph_list = self.search_paragraph(paragraph_id_list)
return [self.reset_chat_record(chat_record, paragraph_list) for chat_record in chat_record_list]
@staticmethod
def search_paragraph(paragraph_id_list: List[str]):
def reset_chat_record(chat_record):
dataset_list = []
paragraph_list = []
if len(paragraph_id_list) > 0:
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
return paragraph_list
if 'search_step' in chat_record.details and chat_record.details.get('search_step').get(
'paragraph_list') is not None:
paragraph_list = chat_record.details.get('search_step').get(
'paragraph_list')
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in
reduce(lambda x, y: {**x, **y},
[{row.get('dataset_id'): row.get('dataset_name')} for row in paragraph_list],
{}).items()]
@staticmethod
def reset_chat_record(chat_record, all_paragraph_list):
paragraph_list = list(
filter(
lambda paragraph: chat_record.paragraph_id_list.__contains__(uuid.UUID(str(paragraph.get('id')))),
all_paragraph_list))
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
[{row.get(
'dataset_id'): row.get(
"dataset_name")} for
row in
paragraph_list],
{}).items()]
return {
**ChatRecordSerializerModel(chat_record).data,
'padding_problem_text': chat_record.details.get('problem_padding').get(
@@ -270,9 +234,7 @@ class ChatRecordSerializer(serializers.Serializer):
self.is_valid(raise_exception=True)
page = page_search(current_page, page_size,
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
post_records_handler=lambda chat_record: chat_record)
records = page.get('records')
page['records'] = self.reset_chat_record_list(records)
post_records_handler=lambda chat_record: self.reset_chat_record(chat_record))
return page
class Vote(serializers.Serializer):

View File

@@ -1,6 +1,8 @@
SELECT
paragraph.*,
dataset."name" AS "dataset_name"
dataset."name" AS "dataset_name",
"document"."name" AS "document_name"
FROM
paragraph paragraph
LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id
LEFT JOIN "document" "document" ON "document"."id" =paragraph.document_id