mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 添加问题管理相关接口,兼容历史版本
This commit is contained in:
parent
1691e56da5
commit
b470b1b6e5
|
|
@ -21,7 +21,7 @@ from common.event.common import poxy, embedding_poxy
|
|||
from common.util.file_util import get_file_content
|
||||
from common.util.fork import ForkManage, Fork
|
||||
from common.util.lock import try_lock, un_lock
|
||||
from dataset.models import Paragraph, Status, Document
|
||||
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
|
||||
from embedding.models import SourceType
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -44,6 +44,12 @@ class SyncWebDocumentArgs:
|
|||
self.handler = handler
|
||||
|
||||
|
||||
class UpdateProblemArgs:
|
||||
def __init__(self, problem_id: str, problem_content: str):
|
||||
self.problem_id = problem_id
|
||||
self.problem_content = problem_content
|
||||
|
||||
|
||||
class ListenerManagement:
|
||||
embedding_by_problem_signal = signal("embedding_by_problem")
|
||||
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
|
||||
|
|
@ -59,6 +65,8 @@ class ListenerManagement:
|
|||
init_embedding_model_signal = signal('init_embedding_model')
|
||||
sync_web_dataset_signal = signal('sync_web_dataset')
|
||||
sync_web_document_signal = signal('sync_web_document')
|
||||
update_problem_signal = signal('update_problem')
|
||||
delete_embedding_by_source_ids_signal = signal('delete_embedding_by_source_ids')
|
||||
|
||||
@staticmethod
|
||||
def embedding_by_problem(args):
|
||||
|
|
@ -76,8 +84,8 @@ class ListenerManagement:
|
|||
status = Status.success
|
||||
try:
|
||||
data_list = native_search(
|
||||
{'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
|
||||
**{'problem.paragraph_id': paragraph_id}),
|
||||
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
|
||||
**{'paragraph.id': paragraph_id}),
|
||||
'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
|
||||
|
|
@ -104,8 +112,9 @@ class ListenerManagement:
|
|||
status = Status.success
|
||||
try:
|
||||
data_list = native_search(
|
||||
{'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
|
||||
**{'problem.document_id': document_id}),
|
||||
{'problem': QuerySet(
|
||||
get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter(
|
||||
**{'paragraph.document_id': document_id}),
|
||||
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
|
||||
|
|
@ -188,6 +197,17 @@ class ListenerManagement:
|
|||
finally:
|
||||
un_lock('sync_web_dataset' + args.lock_key)
|
||||
|
||||
@staticmethod
|
||||
def update_problem(args: UpdateProblemArgs):
|
||||
problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id)
|
||||
embed_value = VectorStore.get_embedding_vector().embed_query(args.problem_content)
|
||||
VectorStore.get_embedding_vector().update_by_source_ids([v.id for v in problem_paragraph_mapping_list],
|
||||
{'embedding': embed_value})
|
||||
|
||||
@staticmethod
|
||||
def delete_embedding_by_source_ids(source_ids: List[str]):
|
||||
VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM)
|
||||
|
||||
@staticmethod
|
||||
@poxy
|
||||
def init_embedding_model(ags):
|
||||
|
|
@ -225,3 +245,6 @@ class ListenerManagement:
|
|||
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
|
||||
# 同步web站点 文档
|
||||
ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
|
||||
# 更新问题向量
|
||||
ListenerManagement.update_problem_signal.connect(self.update_problem)
|
||||
ListenerManagement.delete_embedding_by_source_ids_signal.connect(self.delete_embedding_by_source_ids)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
SELECT
|
||||
problem."id" AS "source_id",
|
||||
problem.document_id AS document_id,
|
||||
problem.paragraph_id AS paragraph_id,
|
||||
problem_paragraph_mapping."id" AS "source_id",
|
||||
paragraph.document_id AS document_id,
|
||||
paragraph."id" AS paragraph_id,
|
||||
problem.dataset_id AS dataset_id,
|
||||
0 AS source_type,
|
||||
problem."content" AS "text",
|
||||
paragraph.is_active AS is_active
|
||||
FROM
|
||||
problem problem
|
||||
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
|
||||
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem_paragraph_mapping.problem_id=problem."id"
|
||||
LEFT JOIN paragraph paragraph ON paragraph."id" = problem_paragraph_mapping.paragraph_id
|
||||
${problem}
|
||||
|
||||
UNION
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
# Generated by Django 4.1.10 on 2024-03-08 18:29
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
import uuid
|
||||
|
||||
from embedding.models import SourceType
|
||||
|
||||
|
||||
def delete_problem_embedding(apps, schema_editor):
|
||||
Embedding = apps.get_model('embedding', 'Embedding')
|
||||
Embedding.objects.filter(source_type=SourceType.PROBLEM).delete()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
('dataset', '0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name='problem',
|
||||
name='document',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='problem',
|
||||
name='paragraph',
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='paragraph',
|
||||
name='hit_num',
|
||||
field=models.IntegerField(default=0, verbose_name='命中次数'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='problem',
|
||||
name='hit_num',
|
||||
field=models.IntegerField(default=0, verbose_name='命中次数'),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ProblemParagraphMapping',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False,
|
||||
verbose_name='主键id')),
|
||||
('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING,
|
||||
to='dataset.dataset')),
|
||||
('document', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')),
|
||||
('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING,
|
||||
to='dataset.paragraph')),
|
||||
('problem', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING,
|
||||
to='dataset.problem')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'problem_paragraph_mapping',
|
||||
},
|
||||
),
|
||||
migrations.RunPython(delete_problem_embedding)
|
||||
]
|
||||
|
|
@ -76,6 +76,7 @@ class Paragraph(AppModelMixin):
|
|||
title = models.CharField(max_length=256, verbose_name="标题", default="")
|
||||
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
|
||||
default=Status.embedding)
|
||||
hit_num = models.IntegerField(verbose_name="命中次数", default=0)
|
||||
is_active = models.BooleanField(default=True)
|
||||
|
||||
class Meta:
|
||||
|
|
@ -87,10 +88,20 @@ class Problem(AppModelMixin):
|
|||
问题表
|
||||
"""
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
content = models.CharField(max_length=256, verbose_name="问题内容")
|
||||
hit_num = models.IntegerField(verbose_name="命中次数", default=0)
|
||||
|
||||
class Meta:
|
||||
db_table = "problem"
|
||||
|
||||
|
||||
class ProblemParagraphMapping(AppModelMixin):
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
document = models.ForeignKey(Document, on_delete=models.DO_NOTHING)
|
||||
problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
|
||||
class Meta:
|
||||
db_table = "problem_paragraph_mapping"
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from common.util.field_message import ErrMessage
|
|||
from common.util.file_util import get_file_content
|
||||
from common.util.fork import ChildLink, Fork
|
||||
from common.util.split_model import get_split_model
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
|
||||
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
|
||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
||||
from setting.models import AuthOperate
|
||||
|
|
@ -303,6 +303,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
document_model_list = []
|
||||
paragraph_model_list = []
|
||||
problem_model_list = []
|
||||
problem_paragraph_mapping_list = []
|
||||
# 插入文档
|
||||
for document in instance.get('documents') if 'documents' in instance else []:
|
||||
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
||||
|
|
@ -312,6 +313,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
paragraph_model_list.append(paragraph)
|
||||
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
||||
problem_model_list.append(problem)
|
||||
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
|
||||
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
|
||||
|
||||
# 插入知识库
|
||||
dataset.save()
|
||||
|
|
@ -321,6 +324,9 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||
# 批量插入问题
|
||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||
# 批量插入关联问题
|
||||
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||
problem_paragraph_mapping_list) > 0 else None
|
||||
|
||||
# 响应数据
|
||||
return {**DataSetSerializers(dataset).data,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from common.util.field_message import ErrMessage
|
|||
from common.util.file_util import get_file_content
|
||||
from common.util.fork import Fork
|
||||
from common.util.split_model import SplitModel, get_split_model
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping
|
||||
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
|
@ -179,7 +179,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
# 删除段落
|
||||
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
|
||||
# 删除问题
|
||||
QuerySet(model=Problem).filter(document_id=document_id).delete()
|
||||
QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
|
||||
# 删除向量库
|
||||
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
|
||||
paragraphs = get_split_model('web.md').parse(result.content)
|
||||
|
|
@ -191,10 +191,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
|
||||
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
||||
problem_model_list = document_paragraph_model.get('problem_model_list')
|
||||
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
|
||||
# 批量插入段落
|
||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||
# 批量插入问题
|
||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||
# 插入关联问题
|
||||
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||
problem_paragraph_mapping_list) > 0 else None
|
||||
# 向量化
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
||||
|
|
@ -273,7 +277,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
# 删除段落
|
||||
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
|
||||
# 删除问题
|
||||
QuerySet(model=Problem).filter(document_id=document_id).delete()
|
||||
QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
|
||||
# 删除向量库
|
||||
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
|
||||
return True
|
||||
|
|
@ -344,12 +348,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
document_model = document_paragraph_model.get('document')
|
||||
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
||||
problem_model_list = document_paragraph_model.get('problem_model_list')
|
||||
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
|
||||
|
||||
# 插入文档
|
||||
document_model.save()
|
||||
# 批量插入段落
|
||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||
# 批量插入问题
|
||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||
# 批量插入关联问题
|
||||
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||
problem_paragraph_mapping_list) > 0 else None
|
||||
document_id = str(document_model.id)
|
||||
return DocumentSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||
|
|
@ -396,14 +405,18 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
|
||||
paragraph_model_list = []
|
||||
problem_model_list = []
|
||||
problem_paragraph_mapping_list = []
|
||||
for paragraphs in paragraph_model_dict_list:
|
||||
paragraph = paragraphs.get('paragraph')
|
||||
for problem_model in paragraphs.get('problem_model_list'):
|
||||
problem_model_list.append(problem_model)
|
||||
for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'):
|
||||
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
|
||||
paragraph_model_list.append(paragraph)
|
||||
|
||||
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
|
||||
'problem_model_list': problem_model_list}
|
||||
'problem_model_list': problem_model_list,
|
||||
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
|
||||
|
||||
@staticmethod
|
||||
def get_document_paragraph_model(dataset_id, instance: Dict):
|
||||
|
|
@ -523,6 +536,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
document_model_list = []
|
||||
paragraph_model_list = []
|
||||
problem_model_list = []
|
||||
problem_paragraph_mapping_list = []
|
||||
# 插入文档
|
||||
for document in instance_list:
|
||||
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
||||
|
|
@ -532,6 +546,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
paragraph_model_list.append(paragraph)
|
||||
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
||||
problem_model_list.append(problem)
|
||||
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
|
||||
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
|
||||
|
||||
# 插入文档
|
||||
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
|
||||
|
|
@ -539,6 +555,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||
# 批量插入问题
|
||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||
# 批量插入关联问题
|
||||
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||
problem_paragraph_mapping_list) > 0 else None
|
||||
# 查询文档
|
||||
query_set = QuerySet(model=Document)
|
||||
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
|
||||
|
|
|
|||
|
|
@ -20,9 +20,10 @@ from common.exception.app_exception import AppApiException
|
|||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.common import post
|
||||
from common.util.field_message import ErrMessage
|
||||
from dataset.models import Paragraph, Problem, Document
|
||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
||||
from dataset.serializers.common_serializers import update_document_char_length
|
||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
|
||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
||||
from embedding.models import SourceType
|
||||
|
||||
|
||||
class ParagraphSerializer(serializers.ModelSerializer):
|
||||
|
|
@ -84,6 +85,193 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
content = serializers.CharField(required=True, max_length=4096, error_messages=ErrMessage.char(
|
||||
"分段内容"))
|
||||
|
||||
class Problem(ApiMixin, serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||
|
||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
||||
|
||||
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
|
||||
raise AppApiException(500, "段落id不存在")
|
||||
|
||||
def list(self, with_valid=False):
|
||||
"""
|
||||
获取问题列表
|
||||
:param with_valid: 是否校验
|
||||
:return: 问题列表
|
||||
"""
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("dataset_id"),
|
||||
paragraph_id=self.data.get(
|
||||
'paragraph_id'))
|
||||
return [ProblemSerializer(row).data for row in
|
||||
QuerySet(Problem).filter(id__in=[row.problem_id for row in problem_paragraph_mapping])]
|
||||
|
||||
@transaction.atomic
|
||||
def save(self, instance: Dict, with_valid=True, with_embedding=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
||||
problem = QuerySet(Problem).filter(dataset_id=self.data.get('dataset_id'),
|
||||
content=instance.get('content')).first()
|
||||
if problem is None:
|
||||
problem = Problem(id=uuid.uuid1(), dataset_id=self.data.get('dataset_id'),
|
||||
content=instance.get('content'))
|
||||
problem.save()
|
||||
if QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get('dataset_id'), problem_id=problem.id,
|
||||
paragraph_id=self.data.get('paragraph_id')).exists():
|
||||
raise AppApiException(500, "已经关联,请勿重复关联")
|
||||
problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
|
||||
problem_id=problem.id,
|
||||
document_id=self.data.get('document_id'),
|
||||
paragraph_id=self.data.get('paragraph_id'),
|
||||
dataset_id=self.data.get('dataset_id'))
|
||||
problem_paragraph_mapping.save()
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||
'is_active': True,
|
||||
'source_type': SourceType.PROBLEM,
|
||||
'source_id': problem_paragraph_mapping.id,
|
||||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id'),
|
||||
})
|
||||
|
||||
return ProblemSerializers.Operate(
|
||||
data={'dataset_id': self.data.get('dataset_id'),
|
||||
'problem_id': problem.id}).one(with_valid=True)
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='知识库id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id'),
|
||||
openapi.Parameter(name='paragraph_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='段落id')]
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(type=openapi.TYPE_OBJECT,
|
||||
required=["content"],
|
||||
properties={
|
||||
'content': openapi.Schema(
|
||||
type=openapi.TYPE_STRING, title="内容")
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'],
|
||||
properties={
|
||||
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
||||
description="id", default="xx"),
|
||||
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
|
||||
description="问题内容", default='问题内容'),
|
||||
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
|
||||
default=1),
|
||||
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
|
||||
description="知识库id", default='xxx'),
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
||||
description="修改时间",
|
||||
default="1970-01-01 00:00:00"),
|
||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
||||
description="创建时间",
|
||||
default="1970-01-01 00:00:00"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
class Association(ApiMixin, serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||
|
||||
problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))
|
||||
|
||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
||||
|
||||
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
|
||||
|
||||
def is_valid(self, *, raise_exception=True):
|
||||
super().is_valid(raise_exception=True)
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
paragraph_id = self.data.get('paragraph_id')
|
||||
problem_id = self.data.get("problem_id")
|
||||
if not QuerySet(Paragraph).filter(dataset_id=dataset_id, id=paragraph_id).exists():
|
||||
raise AppApiException(500, "段落不存在")
|
||||
if not QuerySet(Problem).filter(dataset_id=dataset_id, id=problem_id).exists():
|
||||
raise AppApiException(500, "问题不存在")
|
||||
|
||||
def association(self, with_valid=True, with_embedding=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
problem = QuerySet(Problem).filter(id=self.data.get("problem_id"))
|
||||
problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
|
||||
document_id=self.data.get('document_id'),
|
||||
paragraph_id=self.data.get('paragraph_id'),
|
||||
dataset_id=self.data.get('dataset_id'),
|
||||
problem_id=problem.id)
|
||||
problem_paragraph_mapping.save()
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||
'is_active': True,
|
||||
'source_type': SourceType.PROBLEM,
|
||||
'source_id': problem_paragraph_mapping.id,
|
||||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id'),
|
||||
})
|
||||
|
||||
def un_association(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
|
||||
paragraph_id=self.data.get('paragraph_id'),
|
||||
dataset_id=self.data.get('dataset_id'),
|
||||
problem_id=self.data.get(
|
||||
'problem_id')).first()
|
||||
problem_paragraph_mapping_id = problem_paragraph_mapping.id
|
||||
problem_paragraph_mapping.delete()
|
||||
ListenerManagement.delete_embedding_by_source_signal.send(problem_paragraph_mapping_id)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='知识库id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id')
|
||||
, openapi.Parameter(name='paragraph_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='段落id'),
|
||||
openapi.Parameter(name='problem_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='问题id')
|
||||
]
|
||||
|
||||
class Operate(ApiMixin, serializers.Serializer):
|
||||
# 段落id
|
||||
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
||||
|
|
@ -158,8 +346,13 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
return self.one(), instance
|
||||
|
||||
def get_problem_list(self):
|
||||
return [ProblemSerializer(problem).data for problem in
|
||||
QuerySet(Problem).filter(paragraph_id=self.data.get("paragraph_id"))]
|
||||
ProblemParagraphMapping(ProblemParagraphMapping)
|
||||
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
|
||||
paragraph_id=self.data.get("paragraph_id"))
|
||||
if len(problem_paragraph_mapping) > 0:
|
||||
return [ProblemSerializer(problem).data for problem in
|
||||
QuerySet(Problem).filter(id__in=[ppm.problem_id for ppm in problem_paragraph_mapping])]
|
||||
return []
|
||||
|
||||
def one(self, with_valid=False):
|
||||
if with_valid:
|
||||
|
|
@ -172,7 +365,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
self.is_valid(raise_exception=True)
|
||||
paragraph_id = self.data.get('paragraph_id')
|
||||
QuerySet(Paragraph).filter(id=paragraph_id).delete()
|
||||
QuerySet(Problem).filter(paragraph_id=paragraph_id).delete()
|
||||
QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
|
||||
ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -210,10 +403,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
|
||||
paragraph = paragraph_problem_model.get('paragraph')
|
||||
problem_model_list = paragraph_problem_model.get('problem_model_list')
|
||||
problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list')
|
||||
# 插入段落
|
||||
paragraph_problem_model.get('paragraph').save()
|
||||
# 插入問題
|
||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||
# 插入问题关联关系
|
||||
QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len(
|
||||
problem_paragraph_mapping_list) > 0 else None
|
||||
# 修改长度
|
||||
update_document_char_length(document_id)
|
||||
if with_embedding:
|
||||
|
|
@ -229,12 +426,35 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
content=instance.get("content"),
|
||||
dataset_id=dataset_id,
|
||||
title=instance.get("title") if 'title' in instance else '')
|
||||
problem_list = instance.get('problem_list')
|
||||
exists_problem_list = []
|
||||
if 'problem_list' in instance and len(problem_list) > 0:
|
||||
exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id,
|
||||
content__in=[p.get('content') for p in
|
||||
problem_list]).all()
|
||||
|
||||
problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
|
||||
document_id=document_id, dataset_id=dataset_id) for problem in (
|
||||
instance.get('problem_list') if 'problem_list' in instance else [])]
|
||||
problem_model_list = [
|
||||
ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for
|
||||
problem in (
|
||||
instance.get('problem_list') if 'problem_list' in instance else [])]
|
||||
|
||||
return {'paragraph': paragraph, 'problem_model_list': problem_model_list}
|
||||
problem_paragraph_mapping_list = [
|
||||
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
|
||||
paragraph_id=paragraph.id,
|
||||
dataset_id=dataset_id) for
|
||||
problem_model in problem_model_list]
|
||||
return {'paragraph': paragraph,
|
||||
'problem_model_list': [problem_model for problem_model in problem_model_list if
|
||||
not list(exists_problem_list).__contains__(problem_model)],
|
||||
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
|
||||
|
||||
@staticmethod
|
||||
def or_get(exists_problem_list, content, dataset_id):
|
||||
exists = [row for row in exists_problem_list if row.content == content]
|
||||
if len(exists) > 0:
|
||||
return exists[0]
|
||||
else:
|
||||
return Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
@date:2023/10/23 13:55
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
|
|
@ -14,19 +15,19 @@ from django.db.models import QuerySet
|
|||
from drf_yasg import openapi
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.event.listener_manage import ListenerManagement
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.db.search import native_search, native_page_search
|
||||
from common.event import ListenerManagement, UpdateProblemArgs
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.field_message import ErrMessage
|
||||
from dataset.models import Problem, Paragraph
|
||||
from embedding.models import SourceType
|
||||
from embedding.vector.pg_vector import PGVector
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Problem, Paragraph, ProblemParagraphMapping
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
class ProblemSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Problem
|
||||
fields = ['id', 'content', 'dataset_id', 'document_id',
|
||||
fields = ['id', 'content', 'dataset_id',
|
||||
'create_time', 'update_time']
|
||||
|
||||
|
||||
|
|
@ -49,186 +50,92 @@ class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
|
|||
|
||||
|
||||
class ProblemSerializers(ApiMixin, serializers.Serializer):
|
||||
class Create(ApiMixin, serializers.Serializer):
|
||||
class Create(serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||
problem_list = serializers.ListField(required=True, error_messages=ErrMessage.list("问题列表"),
|
||||
child=serializers.CharField(required=True,
|
||||
error_messages=ErrMessage.char("问题")))
|
||||
|
||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
||||
|
||||
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id'),
|
||||
document_id=self.data.get('document_id'),
|
||||
dataset_id=self.data.get('dataset_id')).exists():
|
||||
raise AppApiException(500, "段落id不正确")
|
||||
|
||||
@transaction.atomic
|
||||
def save(self, instance: Dict, with_valid=True, with_embedding=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
||||
problem = Problem(id=uuid.uuid1(), paragraph_id=self.data.get('paragraph_id'),
|
||||
document_id=self.data.get('document_id'), dataset_id=self.data.get('dataset_id'),
|
||||
content=instance.get('content'))
|
||||
problem.save()
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||
'is_active': True,
|
||||
'source_type': SourceType.PROBLEM,
|
||||
'source_id': problem.id,
|
||||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id'),
|
||||
|
||||
})
|
||||
|
||||
return ProblemSerializers.Operate(
|
||||
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'), 'problem_id': problem.id}).one(with_valid=True)
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return ProblemInstanceSerializer.get_request_body_api()
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='知识库id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id'),
|
||||
openapi.Parameter(name='paragraph_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='段落id')]
|
||||
|
||||
class Query(ApiMixin, serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||
|
||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
||||
|
||||
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
|
||||
|
||||
def is_valid(self, *, raise_exception=True):
|
||||
super().is_valid(raise_exception=True)
|
||||
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
|
||||
raise AppApiException(500, "段落id不存在")
|
||||
|
||||
def get_query_set(self):
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
document_id = self.data.get('document_id')
|
||||
paragraph_id = self.data.get("paragraph_id")
|
||||
return QuerySet(Problem).filter(
|
||||
**{'paragraph_id': paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
|
||||
|
||||
def list(self, with_valid=False):
|
||||
"""
|
||||
获取问题列表
|
||||
:param with_valid: 是否校验
|
||||
:return: 问题列表
|
||||
"""
|
||||
def batch(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
query_set = self.get_query_set()
|
||||
return [ProblemSerializer(p).data for p in query_set]
|
||||
problem_list = self.data.get('problem_list')
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
exists_problem_content_list = [problem.content for problem in
|
||||
QuerySet(Problem).filter(dataset_id=dataset_id,
|
||||
content__in=problem_list)]
|
||||
problem_instance_list = [Problem(id=uuid.uuid1(), dataset_id=dataset_id, content=problem_content) for
|
||||
problem_content in
|
||||
self.data.get('problem_list') if
|
||||
(not exists_problem_content_list.__contains__(problem_content) if
|
||||
len(exists_problem_content_list) > 0 else True)]
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='知识库id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id')
|
||||
, openapi.Parameter(name='paragraph_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='段落id')]
|
||||
QuerySet(Problem).bulk_create(problem_instance_list) if len(problem_instance_list) > 0 else None
|
||||
return [ProblemSerializer(problem_instance).data for problem_instance in problem_instance_list]
|
||||
|
||||
class Operate(ApiMixin, serializers.Serializer):
|
||||
class Query(serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||
content = serializers.CharField(required=False, error_messages=ErrMessage.char("问题"))
|
||||
|
||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
||||
def get_query_set(self):
|
||||
query_set = QuerySet(model=Problem)
|
||||
query_set = query_set.filter(
|
||||
**{'dataset_id': self.data.get('dataset_id')})
|
||||
if 'content' in self.data:
|
||||
query_set = query_set.filter(**{'content__contains': self.data.get('content')})
|
||||
return query_set
|
||||
|
||||
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
|
||||
def list(self):
|
||||
query_set = self.get_query_set()
|
||||
return native_search(query_set, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql')))
|
||||
|
||||
def page(self, current_page, page_size):
|
||||
query_set = self.get_query_set()
|
||||
return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql')))
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||
|
||||
problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))
|
||||
|
||||
def delete(self, with_valid=False):
|
||||
def list_paragraph(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
QuerySet(Problem).filter(**{'id': self.data.get('problem_id')}).delete()
|
||||
PGVector().delete_by_source_id(self.data.get('problem_id'), SourceType.PROBLEM)
|
||||
ListenerManagement.delete_embedding_by_source_signal.send(self.data.get('problem_id'))
|
||||
return True
|
||||
problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("dataset_id"),
|
||||
problem_id=self.data.get("problem_id"))
|
||||
return native_search(
|
||||
QuerySet(Paragraph).filter(id__in=[row.paragraph_id for row in problem_paragraph_mapping]),
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))
|
||||
|
||||
def one(self, with_valid=False):
|
||||
def one(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='知识库id'),
|
||||
openapi.Parameter(name='document_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='文档id')
|
||||
, openapi.Parameter(name='paragraph_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='段落id'),
|
||||
openapi.Parameter(name='problem_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='问题id')
|
||||
]
|
||||
@transaction.atomic
|
||||
def delete(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
|
||||
dataset_id=self.data.get('dataset_id'),
|
||||
problem_id=self.data.get('problem_id'))
|
||||
source_ids = [row.id for row in problem_paragraph_mapping_list]
|
||||
QuerySet(Problem).filter(id=self.data.get('problem_id')).delete()
|
||||
ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id',
|
||||
'document_id',
|
||||
'create_time', 'update_time'],
|
||||
properties={
|
||||
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
||||
description="id", default="xx"),
|
||||
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
|
||||
description="问题内容", default='问题内容'),
|
||||
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
|
||||
default=1),
|
||||
'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
|
||||
description="点赞数量", default=1),
|
||||
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
|
||||
description="点踩数", default=1),
|
||||
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
|
||||
description="文档id", default='xxx'),
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
||||
description="修改时间",
|
||||
default="1970-01-01 00:00:00"),
|
||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
||||
description="创建时间",
|
||||
default="1970-01-01 00:00:00"
|
||||
)
|
||||
}
|
||||
)
|
||||
@transaction.atomic
|
||||
def edit(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
problem_id = self.data.get('problem_id')
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
content = instance.get('content')
|
||||
problem = QuerySet(Problem).filter(id=problem_id,
|
||||
dataset_id=dataset_id).first()
|
||||
problem.content = content
|
||||
problem.save()
|
||||
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content))
|
||||
|
|
|
|||
|
|
@ -1,10 +1,5 @@
|
|||
SELECT
|
||||
problem."id",
|
||||
problem."content",
|
||||
problem_paragraph_mapping.hit_num,
|
||||
problem_paragraph_mapping.star_num,
|
||||
problem_paragraph_mapping.trample_num,
|
||||
problem_paragraph_mapping.paragraph_id
|
||||
problem.*,
|
||||
(SELECT "count"("id") FROM "problem_paragraph_mapping" WHERE problem_id="problem"."id") as "paragraph_count"
|
||||
FROM
|
||||
problem problem
|
||||
LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem."id" = problem_paragraph_mapping.problem_id
|
||||
|
|
|
|||
|
|
@ -0,0 +1,127 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: problem_api.py
|
||||
@date:2024/3/11 10:49
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg import openapi
|
||||
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
|
||||
|
||||
class ProblemApi(ApiMixin):
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'],
|
||||
properties={
|
||||
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
|
||||
description="id", default="xx"),
|
||||
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
|
||||
description="问题内容", default='问题内容'),
|
||||
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
|
||||
default=1),
|
||||
'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
|
||||
description="知识库id", default='xxx'),
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
||||
description="修改时间",
|
||||
default="1970-01-01 00:00:00"),
|
||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
||||
description="创建时间",
|
||||
default="1970-01-01 00:00:00"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
class Operate(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='知识库id'),
|
||||
openapi.Parameter(name='problem_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='问题id')]
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['content'],
|
||||
properties={
|
||||
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
|
||||
description="问题内容"),
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
class Paragraph(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return ProblemApi.Operate.get_request_params_api()
|
||||
|
||||
@staticmethod
|
||||
def get_response_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['content'],
|
||||
properties={
|
||||
'content': openapi.Schema(type=openapi.TYPE_STRING, max_length=4096, title="分段内容",
|
||||
description="分段内容"),
|
||||
'title': openapi.Schema(type=openapi.TYPE_STRING, max_length=256, title="分段标题",
|
||||
description="分段标题"),
|
||||
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
|
||||
'hit_num': openapi.Schema(type=openapi.TYPE_NUMBER, title="命中次数", description="命中次数"),
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
|
||||
description="修改时间",
|
||||
default="1970-01-01 00:00:00"),
|
||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
|
||||
description="创建时间",
|
||||
default="1970-01-01 00:00:00"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
class Query(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='知识库id'),
|
||||
openapi.Parameter(name='content',
|
||||
in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='问题')]
|
||||
|
||||
class BatchCreate(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=ProblemApi.Create.get_request_body_api())
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return ProblemApi.Create.get_request_params_api()
|
||||
|
||||
class Create(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(type=openapi.TYPE_STRING, description="问题文本")
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='dataset_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='知识库id')]
|
||||
|
|
@ -28,7 +28,16 @@ urlpatterns = [
|
|||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
|
||||
views.Paragraph.Operate.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',
|
||||
views.Problem.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>',
|
||||
views.Problem.Operate.as_view())
|
||||
views.Paragraph.Problem.as_view()),
|
||||
path(
|
||||
'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>/un_association',
|
||||
views.Paragraph.Problem.UnAssociation.as_view()),
|
||||
path(
|
||||
'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem/<str:problem_id>/association',
|
||||
views.Paragraph.Problem.Association.as_view()),
|
||||
path('dataset/<str:dataset_id>/problem', views.Problem.as_view()),
|
||||
path('dataset/<str:dataset_id>/problem/<int:current_page>/<int:page_size>', views.Problem.Page.as_view()),
|
||||
path('dataset/<str:dataset_id>/problem/<str:problem_id>', views.Problem.Operate.as_view()),
|
||||
path('dataset/<str:dataset_id>/problem/<str:problem_id>/paragraph', views.Problem.Paragraph.as_view()),
|
||||
|
||||
]
|
||||
|
|
|
|||
|
|
@ -52,6 +52,73 @@ class Paragraph(APIView):
|
|||
return result.success(
|
||||
ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(request.data))
|
||||
|
||||
class Problem(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="添加关联问题",
|
||||
operation_id="添加段落关联问题",
|
||||
manual_parameters=ParagraphSerializers.Problem.get_request_params_api(),
|
||||
request_body=ParagraphSerializers.Problem.get_request_body_api(),
|
||||
responses=result.get_api_response(ParagraphSerializers.Problem.get_response_body_api()),
|
||||
tags=["知识库/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
||||
return result.success(ParagraphSerializers.Problem(
|
||||
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(
|
||||
request.data, with_valid=True))
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取段落问题列表",
|
||||
operation_id="获取段落问题列表",
|
||||
manual_parameters=ParagraphSerializers.Problem.get_request_params_api(),
|
||||
responses=result.get_api_array_response(
|
||||
ParagraphSerializers.Problem.get_response_body_api()),
|
||||
tags=["知识库/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
||||
return result.success(ParagraphSerializers.Problem(
|
||||
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list(
|
||||
with_valid=True))
|
||||
|
||||
class UnAssociation(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="解除关联问题",
|
||||
operation_id="解除关联问题",
|
||||
manual_parameters=ParagraphSerializers.Association.get_request_params_api(),
|
||||
responses=result.get_default_response(),
|
||||
tags=["知识库/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
|
||||
return result.success(ParagraphSerializers.Association(
|
||||
data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
|
||||
'problem_id': problem_id}).un_association())
|
||||
|
||||
class Association(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="关联问题",
|
||||
operation_id="关联问题",
|
||||
manual_parameters=ParagraphSerializers.Association.get_request_params_api(),
|
||||
responses=result.get_default_response(),
|
||||
tags=["知识库/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
|
||||
return result.success(ParagraphSerializers.Association(
|
||||
data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
|
||||
'problem_id': problem_id}).association())
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
@ -61,7 +128,7 @@ class Paragraph(APIView):
|
|||
manual_parameters=ParagraphSerializers.Operate.get_request_params_api(),
|
||||
request_body=ParagraphSerializers.Operate.get_request_body_api(),
|
||||
responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api())
|
||||
,tags=["知识库/文档/段落"])
|
||||
, tags=["知识库/文档/段落"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
|
|
|
|||
|
|
@ -8,61 +8,115 @@
|
|||
"""
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.views import Request
|
||||
|
||||
from common.auth import TokenAuth, has_permissions
|
||||
from common.constants.permission_constants import Permission, Group, Operate
|
||||
from common.response import result
|
||||
from common.util.common import query_params_to_single_dict
|
||||
from dataset.serializers.problem_serializers import ProblemSerializers
|
||||
from dataset.swagger_api.problem_api import ProblemApi
|
||||
|
||||
|
||||
class Problem(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="问题列表",
|
||||
operation_id="问题列表",
|
||||
manual_parameters=ProblemApi.Query.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ProblemApi.get_response_body_api()),
|
||||
tags=["知识库/文档/段落/问题"]
|
||||
)
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str):
|
||||
q = ProblemSerializers.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
|
||||
q.is_valid(raise_exception=True)
|
||||
return result.success(q.list())
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="添加关联问题",
|
||||
operation_id="添加段落关联问题",
|
||||
manual_parameters=ProblemSerializers.Create.get_request_params_api(),
|
||||
request_body=ProblemSerializers.Create.get_request_body_api(),
|
||||
responses=result.get_api_response(ProblemSerializers.Operate.get_response_body_api()),
|
||||
@swagger_auto_schema(operation_summary="创建问题",
|
||||
operation_id="创建问题",
|
||||
manual_parameters=ProblemApi.BatchCreate.get_request_params_api(),
|
||||
request_body=ProblemApi.BatchCreate.get_request_body_api(),
|
||||
responses=result.get_api_response(ProblemApi.Query.get_response_body_api()),
|
||||
tags=["知识库/文档/段落/问题"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
||||
return result.success(ProblemSerializers.Create(
|
||||
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save(
|
||||
request.data, with_valid=True))
|
||||
def post(self, request: Request, dataset_id: str):
|
||||
return result.success(
|
||||
ProblemSerializers.Create(
|
||||
data={'dataset_id': dataset_id, 'problem_list': request.query_params.get('problem_list')}).save())
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取段落问题列表",
|
||||
operation_id="获取段落问题列表",
|
||||
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ProblemSerializers.Operate.get_response_body_api()),
|
||||
tags=["知识库/文档/段落/问题"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str):
|
||||
return result.success(ProblemSerializers.Query(
|
||||
data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list(
|
||||
with_valid=True))
|
||||
class Paragraph(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取关联段落列表",
|
||||
operation_id="获取关联段落列表",
|
||||
manual_parameters=ProblemApi.Paragraph.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ProblemApi.Paragraph.get_response_body_api()),
|
||||
tags=["知识库/文档/段落/问题"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, problem_id: str):
|
||||
return result.success(ProblemSerializers.Operate(
|
||||
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
|
||||
'problem_id': problem_id}).list_paragraph())
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['DELETE'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="删除段落问题",
|
||||
operation_id="删除段落问题",
|
||||
manual_parameters=ProblemSerializers.Query.get_request_params_api(),
|
||||
@swagger_auto_schema(operation_summary="删除问题",
|
||||
operation_id="删除问题",
|
||||
manual_parameters=ProblemApi.Operate.get_request_params_api(),
|
||||
responses=result.get_default_response(),
|
||||
tags=["知识库/文档/段落/问题"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str):
|
||||
o = ProblemSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id,
|
||||
'problem_id': problem_id})
|
||||
return result.success(o.delete(with_valid=True))
|
||||
def delete(self, request: Request, dataset_id: str, problem_id: str):
|
||||
return result.success(ProblemSerializers.Operate(
|
||||
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
|
||||
'problem_id': problem_id}).delete())
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="修改问题",
|
||||
operation_id="修改问题",
|
||||
manual_parameters=ProblemApi.Operate.get_request_params_api(),
|
||||
request_body=ProblemApi.Operate.get_request_body_api(),
|
||||
responses=result.get_api_response(ProblemApi.get_response_body_api()),
|
||||
tags=["知识库/文档/段落/问题"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def put(self, request: Request, dataset_id: str, problem_id: str):
|
||||
return result.success(ProblemSerializers.Operate(
|
||||
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id,
|
||||
'problem_id': problem_id}).edit(request.data))
|
||||
|
||||
class Page(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="分页获取问题列表",
|
||||
operation_id="分页获取问题列表",
|
||||
manual_parameters=result.get_page_request_params(
|
||||
ProblemApi.Query.get_request_params_api()),
|
||||
responses=result.get_page_api_response(ProblemApi.get_response_body_api()),
|
||||
tags=["知识库/文档/段落/问题"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str, current_page, page_size):
|
||||
d = ProblemSerializers.Query(
|
||||
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
|
||||
d.is_valid(raise_exception=True)
|
||||
return result.success(d.page(current_page, page_size))
|
||||
|
|
|
|||
|
|
@ -131,6 +131,18 @@ class BaseVectorStore(ABC):
|
|||
def update_by_source_id(self, source_id: str, instance: Dict):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_documents(self, text_list: List[str]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_dataset_id(self, dataset_id: str):
|
||||
pass
|
||||
|
|
@ -147,6 +159,10 @@ class BaseVectorStore(ABC):
|
|||
def delete_by_source_id(self, source_id: str, source_type: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_source_ids(self, source_ids: List[str], source_type: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_paragraph_id(self, paragraph_id: str):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from typing import Dict, List
|
|||
from django.db.models import QuerySet
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
from common.config.embedding_config import EmbeddingModel
|
||||
from common.db.search import native_search, generate_sql_by_query_dict
|
||||
from common.db.sql_execute import select_one, select_list
|
||||
from common.util.file_util import get_file_content
|
||||
|
|
@ -24,6 +25,20 @@ from smartdoc.conf import PROJECT_DIR
|
|||
|
||||
class PGVector(BaseVectorStore):
|
||||
|
||||
def delete_by_source_ids(self, source_ids: List[str], source_type: str):
|
||||
QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()
|
||||
|
||||
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
||||
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
|
||||
|
||||
def embed_documents(self, text_list: List[str]):
|
||||
embedding = EmbeddingModel.get_embedding_model()
|
||||
return embedding.embed_documents(text_list)
|
||||
|
||||
def embed_query(self, text: str):
|
||||
embedding = EmbeddingModel.get_embedding_model()
|
||||
return embedding.embed_query(text)
|
||||
|
||||
def vector_is_create(self) -> bool:
|
||||
# 项目启动默认是创建好的 不需要再创建
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ const postProblem: (
|
|||
)
|
||||
}
|
||||
/**
|
||||
* 删除问题
|
||||
* 解除关联问题
|
||||
* @param 参数 dataset_id, document_id, paragraph_id,problem_id
|
||||
*/
|
||||
const delProblem: (
|
||||
|
|
@ -146,8 +146,8 @@ const delProblem: (
|
|||
paragraph_id: string,
|
||||
problem_id: string
|
||||
) => Promise<Result<boolean>> = (dataset_id, document_id, paragraph_id, problem_id) => {
|
||||
return del(
|
||||
`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}`
|
||||
return put(
|
||||
`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}/un_association`
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue