feat: add problem management APIs, keeping compatibility with earlier versions

shaohuzhang1 2024-03-11 17:28:05 +08:00
parent 1691e56da5
commit b470b1b6e5
16 changed files with 768 additions and 239 deletions

View File

@ -21,7 +21,7 @@ from common.event.common import poxy, embedding_poxy
from common.util.file_util import get_file_content
from common.util.fork import ForkManage, Fork
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
from embedding.models import SourceType
from smartdoc.conf import PROJECT_DIR
@ -44,6 +44,12 @@ class SyncWebDocumentArgs:
self.handler = handler
class UpdateProblemArgs:
def __init__(self, problem_id: str, problem_content: str):
self.problem_id = problem_id
self.problem_content = problem_content
class ListenerManagement:
embedding_by_problem_signal = signal("embedding_by_problem")
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
@ -59,6 +65,8 @@ class ListenerManagement:
init_embedding_model_signal = signal('init_embedding_model')
sync_web_dataset_signal = signal('sync_web_dataset')
sync_web_document_signal = signal('sync_web_document')
update_problem_signal = signal('update_problem')
delete_embedding_by_source_ids_signal = signal('delete_embedding_by_source_ids')
@staticmethod
def embedding_by_problem(args):
@ -76,8 +84,8 @@ class ListenerManagement:
status = Status.success
try:
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
**{'problem.paragraph_id': paragraph_id}),
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
**{'paragraph.id': paragraph_id}),
'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
@ -104,8 +112,9 @@ class ListenerManagement:
status = Status.success
try:
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
**{'problem.document_id': document_id}),
{'problem': QuerySet(
get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter(
**{'paragraph.document_id': document_id}),
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
@ -188,6 +197,17 @@ class ListenerManagement:
finally:
un_lock('sync_web_dataset' + args.lock_key)
@staticmethod
def update_problem(args: UpdateProblemArgs):
problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id)
embed_value = VectorStore.get_embedding_vector().embed_query(args.problem_content)
VectorStore.get_embedding_vector().update_by_source_ids([v.id for v in problem_paragraph_mapping_list],
{'embedding': embed_value})
@staticmethod
def delete_embedding_by_source_ids(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM)
@staticmethod
@poxy
def init_embedding_model(ags):
@ -225,3 +245,6 @@ class ListenerManagement:
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
# 同步web站点 文档
ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
# 更新问题向量
ListenerManagement.update_problem_signal.connect(self.update_problem)
ListenerManagement.delete_embedding_by_source_ids_signal.connect(self.delete_embedding_by_source_ids)
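
A minimal caller-side sketch of the two new signals wired up above (refresh_problem_vector and drop_problem_vectors are illustrative helper names, not part of this commit; the payload shapes follow the handlers in this file):

from common.event import ListenerManagement, UpdateProblemArgs

def refresh_problem_vector(problem_id: str, new_content: str):
    # Re-embed every ProblemParagraphMapping row that references this problem.
    ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, new_content))

def drop_problem_vectors(mapping_ids: list):
    # Delete the PROBLEM-type embeddings whose source_id is a mapping id.
    ListenerManagement.delete_embedding_by_source_ids_signal.send(mapping_ids)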

View File

@ -1,14 +1,15 @@
SELECT
problem."id" AS "source_id",
problem.document_id AS document_id,
problem.paragraph_id AS paragraph_id,
problem_paragraph_mapping."id" AS "source_id",
paragraph.document_id AS document_id,
paragraph."id" AS paragraph_id,
problem.dataset_id AS dataset_id,
0 AS source_type,
problem."content" AS "text",
paragraph.is_active AS is_active
FROM
problem problem
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem_paragraph_mapping.problem_id=problem."id"
LEFT JOIN paragraph paragraph ON paragraph."id" = problem_paragraph_mapping.paragraph_id
${problem}
UNION

View File

@ -0,0 +1,59 @@
# Generated by Django 4.1.10 on 2024-03-08 18:29
from django.db import migrations, models
import django.db.models.deletion
import uuid
from embedding.models import SourceType
def delete_problem_embedding(apps, schema_editor):
Embedding = apps.get_model('embedding', 'Embedding')
Embedding.objects.filter(source_type=SourceType.PROBLEM).delete()
class Migration(migrations.Migration):
dependencies = [
('dataset', '0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more'),
]
operations = [
migrations.RemoveField(
model_name='problem',
name='document',
),
migrations.RemoveField(
model_name='problem',
name='paragraph',
),
migrations.AddField(
model_name='paragraph',
name='hit_num',
field=models.IntegerField(default=0, verbose_name='命中次数'),
),
migrations.AddField(
model_name='problem',
name='hit_num',
field=models.IntegerField(default=0, verbose_name='命中次数'),
),
migrations.CreateModel(
name='ProblemParagraphMapping',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False,
verbose_name='主键id')),
('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING,
to='dataset.dataset')),
('document', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING,
to='dataset.paragraph')),
('problem', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING,
to='dataset.problem')),
],
options={
'db_table': 'problem_paragraph_mapping',
},
),
migrations.RunPython(delete_problem_embedding)
]

View File

@ -76,6 +76,7 @@ class Paragraph(AppModelMixin):
title = models.CharField(max_length=256, verbose_name="标题", default="")
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.embedding)
hit_num = models.IntegerField(verbose_name="命中次数", default=0)
is_active = models.BooleanField(default=True)
class Meta:
@ -87,10 +88,20 @@ class Problem(AppModelMixin):
问题表
"""
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
content = models.CharField(max_length=256, verbose_name="问题内容")
hit_num = models.IntegerField(verbose_name="命中次数", default=0)
class Meta:
db_table = "problem"
class ProblemParagraphMapping(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING)
problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING, db_constraint=False)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
class Meta:
db_table = "problem_paragraph_mapping"

View File

@ -34,7 +34,7 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from setting.models import AuthOperate
@ -303,6 +303,7 @@ class DataSetSerializers(serializers.ModelSerializer):
document_model_list = []
paragraph_model_list = []
problem_model_list = []
problem_paragraph_mapping_list = []
# 插入文档
for document in instance.get('documents') if 'documents' in instance else []:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
@ -312,6 +313,8 @@ class DataSetSerializers(serializers.ModelSerializer):
paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'):
problem_model_list.append(problem)
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
# 插入知识库
dataset.save()
@ -321,6 +324,9 @@ class DataSetSerializers(serializers.ModelSerializer):
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# 批量插入关联问题
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
problem_paragraph_mapping_list) > 0 else None
# 响应数据
return {**DataSetSerializers(dataset).data,

View File

@ -28,7 +28,7 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
@ -179,7 +179,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
# 删除段落
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
# 删除问题
QuerySet(model=Problem).filter(document_id=document_id).delete()
QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
# 删除向量库
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
paragraphs = get_split_model('web.md').parse(result.content)
@ -191,10 +191,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
problem_model_list = document_paragraph_model.get('problem_model_list')
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
# 批量插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# 插入关联问题
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
problem_paragraph_mapping_list) > 0 else None
# 向量化
if with_embedding:
ListenerManagement.embedding_by_document_signal.send(document_id)
@ -273,7 +277,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
# 删除段落
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
# 删除问题
QuerySet(model=Problem).filter(document_id=document_id).delete()
QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
# 删除向量库
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
return True
@ -344,12 +348,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_model = document_paragraph_model.get('document')
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
problem_model_list = document_paragraph_model.get('problem_model_list')
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
# 插入文档
document_model.save()
# 批量插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# 批量插入关联问题
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
problem_paragraph_mapping_list) > 0 else None
document_id = str(document_model.id)
return DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
@ -396,14 +405,18 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
paragraph_model_list = []
problem_model_list = []
problem_paragraph_mapping_list = []
for paragraphs in paragraph_model_dict_list:
paragraph = paragraphs.get('paragraph')
for problem_model in paragraphs.get('problem_model_list'):
problem_model_list.append(problem_model)
for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
paragraph_model_list.append(paragraph)
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
'problem_model_list': problem_model_list}
'problem_model_list': problem_model_list,
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
@staticmethod
def get_document_paragraph_model(dataset_id, instance: Dict):
@ -523,6 +536,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_model_list = []
paragraph_model_list = []
problem_model_list = []
problem_paragraph_mapping_list = []
# 插入文档
for document in instance_list:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
@ -532,6 +546,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'):
problem_model_list.append(problem)
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
# 插入文档
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
@ -539,6 +555,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# 批量插入关联问题
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
problem_paragraph_mapping_list) > 0 else None
# 查询文档
query_set = QuerySet(model=Document)
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})

View File

@ -20,9 +20,10 @@ from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
from dataset.serializers.common_serializers import update_document_char_length
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType
class ParagraphSerializer(serializers.ModelSerializer):
@ -84,6 +85,193 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
content = serializers.CharField(required=True, max_length=4096, error_messages=ErrMessage.char(
"分段内容"))
class Problem(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
raise AppApiException(500, "段落id不存在")
def list(self, with_valid=False):
"""
获取问题列表
:param with_valid: 是否校验
:return: 问题列表
"""
if with_valid:
self.is_valid(raise_exception=True)
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("dataset_id"),
paragraph_id=self.data.get(
'paragraph_id'))
return [ProblemSerializer(row).data for row in
QuerySet(Problem).filter(id__in=[row.problem_id for row in problem_paragraph_mapping])]
@transaction.atomic
def save(self, instance: Dict, with_valid=True, with_embedding=True):
if with_valid:
self.is_valid()
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
problem = QuerySet(Problem).filter(dataset_id=self.data.get('dataset_id'),
content=instance.get('content')).first()
if problem is None:
problem = Problem(id=uuid.uuid1(), dataset_id=self.data.get('dataset_id'),
content=instance.get('content'))
problem.save()
if QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get('dataset_id'), problem_id=problem.id,
paragraph_id=self.data.get('paragraph_id')).exists():
raise AppApiException(500, "已经关联,请勿重复关联")
problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
problem_id=problem.id,
document_id=self.data.get('document_id'),
paragraph_id=self.data.get('paragraph_id'),
dataset_id=self.data.get('dataset_id'))
problem_paragraph_mapping.save()
if with_embedding:
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
'source_id': problem_paragraph_mapping.id,
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
})
return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'),
'problem_id': problem.id}).one(with_valid=True)
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id'),
openapi.Parameter(name='paragraph_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='段落id')]
@staticmethod
def get_request_body_api():
return openapi.Schema(type=openapi.TYPE_OBJECT,
required=["content"],
properties={
'content': openapi.Schema(
type=openapi.TYPE_STRING, title="内容")
})
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
description="问题内容", default='问题内容'),
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
default=1),
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
description="知识库id", default='xxx'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
)
}
)
class Association(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
paragraph_id = self.data.get('paragraph_id')
problem_id = self.data.get("problem_id")
if not QuerySet(Paragraph).filter(dataset_id=dataset_id, id=paragraph_id).exists():
raise AppApiException(500, "段落不存在")
if not QuerySet(Problem).filter(dataset_id=dataset_id, id=problem_id).exists():
raise AppApiException(500, "问题不存在")
def association(self, with_valid=True, with_embedding=True):
if with_valid:
self.is_valid(raise_exception=True)
problem = QuerySet(Problem).filter(id=self.data.get("problem_id")).first()
problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
document_id=self.data.get('document_id'),
paragraph_id=self.data.get('paragraph_id'),
dataset_id=self.data.get('dataset_id'),
problem_id=problem.id)
problem_paragraph_mapping.save()
if with_embedding:
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
'source_id': problem_paragraph_mapping.id,
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
})
def un_association(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
paragraph_id=self.data.get('paragraph_id'),
dataset_id=self.data.get('dataset_id'),
problem_id=self.data.get(
'problem_id')).first()
problem_paragraph_mapping_id = problem_paragraph_mapping.id
problem_paragraph_mapping.delete()
ListenerManagement.delete_embedding_by_source_signal.send(problem_paragraph_mapping_id)
return True
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id')
, openapi.Parameter(name='paragraph_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='段落id'),
openapi.Parameter(name='problem_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='问题id')
]
class Operate(ApiMixin, serializers.Serializer):
# 段落id
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
@ -158,8 +346,13 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
return self.one(), instance
def get_problem_list(self):
return [ProblemSerializer(problem).data for problem in
QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))]
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
paragraph_id=self.data.get("paragraph_id"))
if len(problem_paragraph_mapping) > 0:
return [ProblemSerializer(problem).data for problem in
QuerySet(Problem).filter(id__in=[ppm.problem_id for ppm in problem_paragraph_mapping])]
return []
def one(self, with_valid=False):
if with_valid:
@ -172,7 +365,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
self.is_valid(raise_exception=True)
paragraph_id = self.data.get('paragraph_id')
QuerySet(Paragraph).filter(id=paragraph_id).delete()
QuerySet(Problem).filter(paragraph_id=paragraph_id).delete()
QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)
@staticmethod
@ -210,10 +403,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
paragraph = paragraph_problem_model.get('paragraph')
problem_model_list = paragraph_problem_model.get('problem_model_list')
problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list')
# 插入段落
paragraph_problem_model.get('paragraph').save()
# 插入問題
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# 插入问题关联关系
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
problem_paragraph_mapping_list) > 0 else None
# 修改长度
update_document_char_length(document_id)
if with_embedding:
@ -229,12 +426,35 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
content=instance.get("content"),
dataset_id=dataset_id,
title=instance.get("title") if 'title' in instance else '')
problem_list = instance.get('problem_list')
exists_problem_list = []
if 'problem_list' in instance and len(problem_list) > 0:
exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id,
content__in=[p.get('content') for p in
problem_list]).all()
problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
document_id=document_id, dataset_id=dataset_id) for problem in (
instance.get('problem_list') if 'problem_list' in instance else [])]
problem_model_list = [
ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for
problem in (
instance.get('problem_list') if 'problem_list' in instance else [])]
return {'paragraph': paragraph, 'problem_model_list': problem_model_list}
problem_paragraph_mapping_list = [
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
paragraph_id=paragraph.id,
dataset_id=dataset_id) for
problem_model in problem_model_list]
return {'paragraph': paragraph,
'problem_model_list': [problem_model for problem_model in problem_model_list if
not list(exists_problem_list).__contains__(problem_model)],
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
@staticmethod
def or_get(exists_problem_list, content, dataset_id):
exists = [row for row in exists_problem_list if row.content == content]
if len(exists) > 0:
return exists[0]
else:
return Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
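A small usage sketch for or_get above (problems_for is a hypothetical helper): rows whose content already exists in the dataset are reused, anything else becomes a new, unsaved Problem that the caller bulk-creates together with the mapping rows.

from django.db.models import QuerySet
from dataset.models import Problem
from dataset.serializers.paragraph_serializers import ParagraphSerializers

def problems_for(dataset_id, contents):
    existing = QuerySet(Problem).filter(dataset_id=dataset_id, content__in=contents).all()
    return [ParagraphSerializers.Create.or_get(existing, c, dataset_id) for c in contents]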
@staticmethod
def get_request_body_api():

View File

@ -6,6 +6,7 @@
@date: 2023/10/23 13:55
@desc:
"""
import os
import uuid
from typing import Dict
@ -14,19 +15,19 @@ from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.db.search import native_search, native_page_search
from common.event import ListenerManagement, UpdateProblemArgs
from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from dataset.models import Problem, Paragraph
from embedding.models import SourceType
from embedding.vector.pg_vector import PGVector
from common.util.file_util import get_file_content
from dataset.models import Problem, Paragraph, ProblemParagraphMapping
from smartdoc.conf import PROJECT_DIR
class ProblemSerializer(serializers.ModelSerializer):
class Meta:
model = Problem
fields = ['id', 'content', 'dataset_id', 'document_id',
fields = ['id', 'content', 'dataset_id',
'create_time', 'update_time']
@ -49,186 +50,92 @@ class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
class ProblemSerializers(ApiMixin, serializers.Serializer):
class Create(ApiMixin, serializers.Serializer):
class Create(serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
problem_list = serializers.ListField(required=True, error_messages=ErrMessage.list("问题列表"),
child=serializers.CharField(required=True,
error_messages=ErrMessage.char("问题")))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id'),
document_id=self.data.get('document_id'),
dataset_id=self.data.get('dataset_id')).exists():
raise AppApiException(500, "段落id不正确")
@transaction.atomic
def save(self, instance: Dict, with_valid=True, with_embedding=True):
if with_valid:
self.is_valid()
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
problem = Problem(id=uuid.uuid1(), paragraph_id=self.data.get('paragraph_id'),
document_id=self.data.get('document_id'), dataset_id=self.data.get('dataset_id'),
content=instance.get('content'))
problem.save()
if with_embedding:
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
'source_id': problem.id,
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
})
return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'), 'problem_id': problem.id}).one(with_valid=True)
@staticmethod
def get_request_body_api():
return ProblemInstanceSerializer.get_request_body_api()
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id'),
openapi.Parameter(name='paragraph_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='段落id')]
class Query(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
raise AppApiException(500, "段落id不存在")
def get_query_set(self):
dataset_id = self.data.get('dataset_id')
document_id = self.data.get('document_id')
paragraph_id = self.data.get("paragraph_id")
return QuerySet(Problem).filter(
**{'paragraph_id': paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
def list(self, with_valid=False):
"""
获取问题列表
:param with_valid: 是否校验
:return: 问题列表
"""
def batch(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
query_set = self.get_query_set()
return [ProblemSerializer(p).data for p in query_set]
problem_list = self.data.get('problem_list')
dataset_id = self.data.get('dataset_id')
exists_problem_content_list = [problem.content for problem in
QuerySet(Problem).filter(dataset_id=dataset_id,
content__in=problem_list)]
problem_instance_list = [Problem(id=uuid.uuid1(), dataset_id=dataset_id, content=problem_content) for
problem_content in
self.data.get('problem_list') if
(not exists_problem_content_list.__contains__(problem_content) if
len(exists_problem_content_list) > 0 else True)]
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id')
, openapi.Parameter(name='paragraph_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='段落id')]
QuerySet(Problem).bulk_create(problem_instance_list) if len(problem_instance_list) > 0 else None
return [ProblemSerializer(problem_instance).data for problem_instance in problem_instance_list]
class Operate(ApiMixin, serializers.Serializer):
class Query(serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
content = serializers.CharField(required=False, error_messages=ErrMessage.char("问题"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
def get_query_set(self):
query_set = QuerySet(model=Problem)
query_set = query_set.filter(
**{'dataset_id': self.data.get('dataset_id')})
if 'content' in self.data:
query_set = query_set.filter(**{'content__contains': self.data.get('content')})
return query_set
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
def list(self):
query_set = self.get_query_set()
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql')))
def page(self, current_page, page_size):
query_set = self.get_query_set()
return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql')))
class Operate(serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))
def delete(self, with_valid=False):
def list_paragraph(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
QuerySet(Problem).filter(**{'id': self.data.get('problem_id')}).delete()
PGVector().delete_by_source_id(self.data.get('problem_id'), SourceType.PROBLEM)
ListenerManagement.delete_embedding_by_source_signal.send(self.data.get('problem_id'))
return True
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("dataset_id"),
problem_id=self.data.get("problem_id"))
return native_search(
QuerySet(Paragraph).filter(id__in=[row.paragraph_id for row in problem_paragraph_mapping]),
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))
def one(self, with_valid=False):
def one(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id')
, openapi.Parameter(name='paragraph_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='段落id'),
openapi.Parameter(name='problem_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='问题id')
]
@transaction.atomic
def delete(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
dataset_id=self.data.get('dataset_id'),
problem_id=self.data.get('problem_id'))
source_ids = [row.id for row in problem_paragraph_mapping_list]
QuerySet(Problem).filter(id=self.data.get('problem_id')).delete()
ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids)
return True
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id',
'document_id',
'create_time', 'update_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
description="问题内容", default='问题内容'),
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
default=1),
'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
description="点赞数量", default=1),
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
description="点踩数", default=1),
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
description="文档id", default='xxx'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
)
}
)
@transaction.atomic
def edit(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
problem_id = self.data.get('problem_id')
dataset_id = self.data.get('dataset_id')
content = instance.get('content')
problem = QuerySet(Problem).filter(id=problem_id,
dataset_id=dataset_id).first()
problem.content = content
problem.save()
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content))
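
A short usage sketch of the new problem-level operations, mirroring how the views in dataset/views/problem.py invoke them (rename_problem and remove_problem are illustrative wrappers; the ids are placeholders):

from dataset.serializers.problem_serializers import ProblemSerializers

def rename_problem(dataset_id: str, problem_id: str, content: str):
    # Persists the new content and emits update_problem_signal so the related embeddings are refreshed.
    ProblemSerializers.Operate(data={'dataset_id': dataset_id,
                                     'problem_id': problem_id}).edit({'content': content})

def remove_problem(dataset_id: str, problem_id: str):
    # Deletes the problem row and the embeddings tied to its mapping ids.
    return ProblemSerializers.Operate(data={'dataset_id': dataset_id,
                                            'problem_id': problem_id}).delete()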

View File

@ -1,10 +1,5 @@
SELECT
problem."id",
problem."content",
problem_paragraph_mapping.hit_num,
problem_paragraph_mapping.star_num,
problem_paragraph_mapping.trample_num,
problem_paragraph_mapping.paragraph_id
problem.*,
(SELECT "count"("id") FROM "problem_paragraph_mapping" WHERE problem_id="problem"."id") as "paragraph_count"
FROM
problem problem
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem."id" = problem_paragraph_mapping.problem_id

View File

@ -0,0 +1,127 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file problem_api.py
@date: 2024/3/11 10:49
@desc:
"""
from drf_yasg import openapi
from common.mixins.api_mixin import ApiMixin
class ProblemApi(ApiMixin):
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
description="问题内容", default='问题内容'),
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
default=1),
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
description="知识库id", default='xxx'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
)
}
)
class Operate(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='problem_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='问题id')]
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['content'],
properties={
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
description="问题内容"),
}
)
class Paragraph(ApiMixin):
@staticmethod
def get_request_params_api():
return ProblemApi.Operate.get_request_params_api()
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['content'],
properties={
'content': openapi.Schema(type=openapi.TYPE_STRING, max_length=4096, title="分段内容",
description="分段内容"),
'title': openapi.Schema(type=openapi.TYPE_STRING, max_length=256, title="分段标题",
description="分段标题"),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
'hit_num': openapi.Schema(type=openapi.TYPE_NUMBER, title="命中次数", description="命中次数"),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
),
}
)
class Query(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='content',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='问题')]
class BatchCreate(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(type=openapi.TYPE_ARRAY,
items=ProblemApi.Create.get_request_body_api())
@staticmethod
def get_request_params_api():
return ProblemApi.Create.get_request_params_api()
class Create(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(type=openapi.TYPE_STRING, description="问题文本")
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id')]

View File

@ -28,7 +28,16 @@ urlpatterns = [
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
views.Paragraph.Operate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',
views.Problem.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>',
views.Problem.Operate.as_view())
views.Paragraph.Problem.as_view()),
path(
'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>/un_association',
views.Paragraph.Problem.UnAssociation.as_view()),
path(
'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>/association',
views.Paragraph.Problem.Association.as_view()),
path('dataset/<str:dataset_id>/problem', views.Problem.as_view()),
path('dataset/<str:dataset_id>/problem/<int:current_page>/<int:page_size>', views.Problem.Page.as_view()),
path('dataset/<str:dataset_id>/problem/<str:problem_id>', views.Problem.Operate.as_view()),
path('dataset/<str:dataset_id>/problem/<str:problem_id>/paragraph', views.Problem.Paragraph.as_view()),
]
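
A hypothetical client-side sketch of the new routes (the /api prefix, host, and token are assumptions, not part of this diff):

import requests

BASE = 'http://localhost:8080/api'
HEADERS = {'AUTHORIZATION': '<token>'}

# Page through the problems of a dataset.
requests.get(f'{BASE}/dataset/<dataset_id>/problem/1/20', headers=HEADERS)

# Edit a problem; its embeddings are refreshed through update_problem_signal.
requests.put(f'{BASE}/dataset/<dataset_id>/problem/<problem_id>',
             json={'content': 'new content'}, headers=HEADERS)

# List the paragraphs associated with a problem.
requests.get(f'{BASE}/dataset/<dataset_id>/problem/<problem_id>/paragraph', headers=HEADERS)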

View File

@ -52,6 +52,73 @@ class Paragraph(APIView):
return result.success(
ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(request.data))
class Problem(APIView):
authentication_classes = [TokenAuth]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="添加关联问题",
operation_id="添加段落关联问题",
manual_parameters=ParagraphSerializers.Problem.get_request_params_api(),
request_body=ParagraphSerializers.Problem.get_request_body_api(),
responses=result.get_api_response(ParagraphSerializers.Problem.get_response_body_api()),
tags=["知识库/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
return result.success(ParagraphSerializers.Problem(
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(
request.data, with_valid=True))
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取段落问题列表",
operation_id="获取段落问题列表",
manual_parameters=ParagraphSerializers.Problem.get_request_params_api(),
responses=result.get_api_array_response(
ParagraphSerializers.Problem.get_response_body_api()),
tags=["知识库/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
return result.success(ParagraphSerializers.Problem(
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list(
with_valid=True))
class UnAssociation(APIView):
authentication_classes = [TokenAuth]
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="解除关联问题",
operation_id="解除关联问题",
manual_parameters=ParagraphSerializers.Association.get_request_params_api(),
responses=result.get_default_response(),
tags=["知识库/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
return result.success(ParagraphSerializers.Association(
data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
'problem_id': problem_id}).un_association())
class Association(APIView):
authentication_classes = [TokenAuth]
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="关联问题",
operation_id="关联问题",
manual_parameters=ParagraphSerializers.Association.get_request_params_api(),
responses=result.get_default_response(),
tags=["知识库/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
return result.success(ParagraphSerializers.Association(
data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
'problem_id': problem_id}).association())
class Operate(APIView):
authentication_classes = [TokenAuth]
@ -61,7 +128,7 @@ class Paragraph(APIView):
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
request_body=ParagraphSerializers.Operate.get_request_body_api(),
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api())
,tags=["知识库/文档/段落"])
, tags=["知识库/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))

View File

@ -8,61 +8,115 @@
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.views import APIView
from rest_framework.views import Request
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import Permission, Group, Operate
from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.problem_serializers import ProblemSerializers
from dataset.swagger_api.problem_api import ProblemApi
class Problem(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="问题列表",
operation_id="问题列表",
manual_parameters=ProblemApi.Query.get_request_params_api(),
responses=result.get_api_array_response(ProblemApi.get_response_body_api()),
tags=["知识库/文档/段落/问题"]
)
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
def get(self, request: Request, dataset_id: str):
q = ProblemSerializers.Query(
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
q.is_valid(raise_exception=True)
return result.success(q.list())
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="添加关联问题",
operation_id="添加段落关联问题",
manual_parameters=ProblemSerializers.Create.get_request_params_api(),
request_body=ProblemSerializers.Create.get_request_body_api(),
responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()),
@swagger_auto_schema(operation_summary="创建问题",
operation_id="创建问题",
manual_parameters=ProblemApi.BatchCreate.get_request_params_api(),
request_body=ProblemApi.BatchCreate.get_request_body_api(),
responses=result.get_api_response(ProblemApi.Query.get_response_body_api()),
tags=["知识库/文档/段落/问题"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
return result.success(ProblemSerializers.Create(
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(
request.data, with_valid=True))
def post(self, request: Request, dataset_id: str):
return result.success(
ProblemSerializers.Create(
data={'dataset_id': dataset_id, 'problem_list': request.query_params.get('problem_list')}).save())
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取段落问题列表",
operation_id="获取段落问题列表",
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()),
tags=["知识库/文档/段落/问题"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
return result.success(ProblemSerializers.Query(
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list(
with_valid=True))
class Paragraph(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取关联段落列表",
operation_id="获取关联段落列表",
manual_parameters=ProblemApi.Paragraph.get_request_params_api(),
responses=result.get_api_array_response(ProblemApi.Paragraph.get_response_body_api()),
tags=["知识库/文档/段落/问题"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
def get(self, request: Request, dataset_id: str, problem_id: str):
return result.success(ProblemSerializers.Operate(
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
'problem_id': problem_id}).list_paragraph())
class Operate(APIView):
authentication_classes = [TokenAuth]
@action(methods=['DELETE'], detail=False)
@swagger_auto_schema(operation_summary="删除段落问题",
operation_id="删除段落问题",
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
@swagger_auto_schema(operation_summary="删除问题",
operation_id="删除问题",
manual_parameters=ProblemApi.Operate.get_request_params_api(),
responses=result.get_default_response(),
tags=["知识库/文档/段落/问题"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
o = ProblemSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
'problem_id': problem_id})
return result.success(o.delete(with_valid=True))
def delete(self, request: Request, dataset_id: str, problem_id: str):
return result.success(ProblemSerializers.Operate(
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
'problem_id': problem_id}).delete())
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="修改问题",
operation_id="修改问题",
manual_parameters=ProblemApi.Operate.get_request_params_api(),
request_body=ProblemApi.Operate.get_request_body_api(),
responses=result.get_api_response(ProblemApi.get_response_body_api()),
tags=["知识库/文档/段落/问题"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def put(self, request: Request, dataset_id: str, problem_id: str):
return result.success(ProblemSerializers.Operate(
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
'problem_id': problem_id}).edit(request.data))
class Page(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="分页获取问题列表",
operation_id="分页获取问题列表",
manual_parameters=result.get_page_request_params(
ProblemApi.Query.get_request_params_api()),
responses=result.get_page_api_response(ProblemApi.get_response_body_api()),
tags=["知识库/文档/段落/问题"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=k.get('dataset_id')))
def get(self, request: Request, dataset_id: str, current_page, page_size):
d = ProblemSerializers.Query(
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
d.is_valid(raise_exception=True)
return result.success(d.page(current_page, page_size))

View File

@ -131,6 +131,18 @@ class BaseVectorStore(ABC):
def update_by_source_id(self, source_id: str, instance: Dict):
pass
@abstractmethod
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
pass
@abstractmethod
def embed_documents(self, text_list: List[str]):
pass
@abstractmethod
def embed_query(self, text: str):
pass
@abstractmethod
def delete_by_dataset_id(self, dataset_id: str):
pass
@ -147,6 +159,10 @@ class BaseVectorStore(ABC):
def delete_by_source_id(self, source_id: str, source_type: str):
pass
@abstractmethod
def delete_by_source_ids(self, source_ids: List[str], source_type: str):
pass
@abstractmethod
def delete_by_paragraph_id(self, paragraph_id: str):
pass

View File

@ -14,6 +14,7 @@ from typing import Dict, List
from django.db.models import QuerySet
from langchain_community.embeddings import HuggingFaceEmbeddings
from common.config.embedding_config import EmbeddingModel
from common.db.search import native_search, generate_sql_by_query_dict
from common.db.sql_execute import select_one, select_list
from common.util.file_util import get_file_content
@ -24,6 +25,20 @@ from smartdoc.conf import PROJECT_DIR
class PGVector(BaseVectorStore):
def delete_by_source_ids(self, source_ids: List[str], source_type: str):
QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
def embed_documents(self, text_list: List[str]):
embedding = EmbeddingModel.get_embedding_model()
return embedding.embed_documents(text_list)
def embed_query(self, text: str):
embedding = EmbeddingModel.get_embedding_model()
return embedding.embed_query(text)
def vector_is_create(self) -> bool:
# 项目启动默认是创建好的 不需要再创建
return True
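
A sketch exercising the new PGVector helpers (in the listener they are reached through VectorStore.get_embedding_vector(); here a PGVector instance is used directly, and mapping_ids stands for the ProblemParagraphMapping ids that act as source_id for PROBLEM embeddings):

from typing import List
from embedding.models import SourceType
from embedding.vector.pg_vector import PGVector

def reembed(mapping_ids: List[str], new_content: str):
    store = PGVector()
    vector = store.embed_query(new_content)  # one vector for the problem text
    store.update_by_source_ids(mapping_ids, {'embedding': vector})

def purge(mapping_ids: List[str]):
    PGVector().delete_by_source_ids(mapping_ids, SourceType.PROBLEM)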

View File

@ -137,7 +137,7 @@ const postProblem: (
)
}
/**
*
*
* @param dataset_id, document_id, paragraph_id,problem_id
*/
const delProblem: (
@ -146,8 +146,8 @@ const delProblem: (
paragraph_id: string,
problem_id: string
) => Promise<Result<boolean>> = (dataset_id, document_id, paragraph_id, problem_id) => {
return del(
`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}`
return put(
`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}/un_association`
)
}