diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py
index e1e018891..9fcb0d643 100644
--- a/apps/dataset/models/data_set.py
+++ b/apps/dataset/models/data_set.py
@@ -22,6 +22,7 @@ class Status(models.TextChoices):
     success = 1, '已完成'
     error = 2, '导入失败'
     queue_up = 3, '排队中'
+    generating = 4, '生成问题中'
 
 
 class Type(models.TextChoices):
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index 91ade28fc..bc3fbef23 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -47,7 +47,7 @@ from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type,
 from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
     get_embedding_model_id_by_dataset_id
 from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
-from dataset.task import sync_web_document
+from dataset.task import sync_web_document, generate_related_by_document_id
 from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
     delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \
     embedding_by_document_list
@@ -960,6 +960,37 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
         except AlreadyQueued as e:
             raise AppApiException(500, "任务正在执行中,请勿重复下发")
 
+    class GenerateRelated(ApiMixin, serializers.Serializer):
+        document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
+
+        def is_valid(self, *, raise_exception=False):
+            super().is_valid(raise_exception=True)
+            document_id = self.data.get('document_id')
+            if not QuerySet(Document).filter(id=document_id).exists():
+                raise AppApiException(500, "文档id不存在")
+
+        def generate_related(self, model_id, prompt, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            document_id = self.data.get('document_id')
+            QuerySet(Document).filter(id=document_id).update(status=Status.queue_up)
+            generate_related_by_document_id.delay(document_id, model_id, prompt)
+
+
+    class BatchGenerateRelated(ApiMixin, serializers.Serializer):
+        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
+
+        @transaction.atomic
+        def batch_generate_related(self, instance: Dict, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            document_id_list = instance.get("document_id_list")
+            model_id = instance.get("model_id")
+            prompt = instance.get("prompt")
+            for document_id in document_id_list:
+                DocumentSerializers.GenerateRelated(data={'document_id': document_id}).generate_related(model_id, prompt)
+
 
 class FileBufferHandle:
     buffer = None
diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py
index 84423dd4b..8e2a8e60b 100644
--- a/apps/dataset/serializers/paragraph_serializers.py
+++ b/apps/dataset/serializers/paragraph_serializers.py
@@ -27,6 +27,7 @@ from embedding.models import SourceType
 from embedding.task.embedding import embedding_by_problem as embedding_by_problem_task, embedding_by_problem, \
     delete_embedding_by_source, enable_embedding_by_paragraph, disable_embedding_by_paragraph, embedding_by_paragraph, \
     delete_embedding_by_paragraph, delete_embedding_by_paragraph_ids, update_embedding_document_id
+from dataset.task import generate_related_by_paragraph_id_list
 
 
 class ParagraphSerializer(serializers.ModelSerializer):
@@ -719,3 +720,20 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
             )
         }
     )
+
+
+    class BatchGenerateRelated(ApiMixin, serializers.Serializer):
+        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
+        document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
+
+        @transaction.atomic
+        def batch_generate_related(self, instance: Dict, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            paragraph_id_list = instance.get("paragraph_id_list")
+            model_id = instance.get("model_id")
+            prompt = instance.get("prompt")
+            generate_related_by_paragraph_id_list.delay(paragraph_id_list, model_id, prompt)
+
+
diff --git a/apps/dataset/task/__init__.py b/apps/dataset/task/__init__.py
index de3f96538..7bb1839d3 100644
--- a/apps/dataset/task/__init__.py
+++ b/apps/dataset/task/__init__.py
@@ -7,3 +7,4 @@
     @desc:
 """
 from .sync import *
+from .generate import *
diff --git a/apps/dataset/task/generate.py b/apps/dataset/task/generate.py
new file mode 100644
index 000000000..860425978
--- /dev/null
+++ b/apps/dataset/task/generate.py
@@ -0,0 +1,64 @@
+import logging
+from math import ceil
+
+from celery_once import QueueOnce
+from django.db.models import QuerySet
+from langchain_core.messages import HumanMessage
+
+from common.config.embedding_config import ModelManage
+from dataset.models import Paragraph, Document, Status
+from dataset.task.tools import save_problem
+from ops import celery_app
+from setting.models import Model
+from setting.models_provider import get_model
+
+max_kb_error = logging.getLogger("max_kb_error")
+max_kb = logging.getLogger("max_kb")
+
+
+def get_llm_model(model_id):
+    model = QuerySet(Model).filter(id=model_id).first()
+    return ModelManage.get_model(model_id, lambda _id: get_model(model))
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
+                 name='celery:generate_related_by_document')
+def generate_related_by_document_id(document_id, model_id, prompt):
+    llm_model = get_llm_model(model_id)
+    offset = 0
+    page_size = 10
+    QuerySet(Document).filter(id=document_id).update(status=Status.generating)
+
+    count = QuerySet(Paragraph).filter(document_id=document_id).count()
+    for i in range(0, ceil(count / page_size)):
+        paragraph_list = QuerySet(Paragraph).filter(document_id=document_id).all()[offset:offset + page_size]
+        offset += page_size
+        for paragraph in paragraph_list:
+            res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
+            if (res.content is None) or (len(res.content) == 0):
+                continue
+            problems = res.content.split('\n')
+            for problem in problems:
+                save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
+
+    QuerySet(Document).filter(id=document_id).update(status=Status.success)
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
+                 name='celery:generate_related_by_paragraph_list')
+def generate_related_by_paragraph_id_list(paragraph_id_list, model_id, prompt):
+    llm_model = get_llm_model(model_id)
+    offset = 0
+    page_size = 10
+    count = QuerySet(Paragraph).filter(id__in=paragraph_id_list).count()
+    for i in range(0, ceil(count / page_size)):
+        paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list).all()[offset:offset + page_size]
+        offset += page_size
+        for paragraph in paragraph_list:
+            res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
+            if (res.content is None) or (len(res.content) == 0):
+                continue
+            problems = res.content.split('\n')
+            for problem in problems:
+                save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
diff --git a/apps/dataset/task/tools.py b/apps/dataset/task/tools.py
index 5eb44084e..427a691a0 100644
--- a/apps/dataset/task/tools.py
+++ b/apps/dataset/task/tools.py
@@ -8,6 +8,7 @@
 """
 import logging
+import re
 import traceback
 
 from common.util.fork import ChildLink, Fork
@@ -60,3 +61,23 @@ def get_sync_web_document_handler(dataset_id):
                      status=Status.error).save()
 
     return handler
+
+
+def save_problem(dataset_id, document_id, paragraph_id, problem):
+    from dataset.serializers.paragraph_serializers import ParagraphSerializers
+    # print(f"dataset_id: {dataset_id}")
+    # print(f"document_id: {document_id}")
+    # print(f"paragraph_id: {paragraph_id}")
+    # print(f"problem: {problem}")
+    problem = re.sub(r"^\d+\.\s*", "", problem)
+    pattern = r"<question>(.*?)</question>"
+    match = re.search(pattern, problem)
+    problem = match.group(1) if match else None
+    if problem is None or len(problem) == 0:
+        return
+    try:
+        ParagraphSerializers.Problem(
+            data={"dataset_id": dataset_id, 'document_id': document_id,
+                  'paragraph_id': paragraph_id}).save(instance={"content": problem}, with_valid=True)
+    except Exception as e:
+        max_kb_error.error(f'关联问题失败: {e}')
diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py
index 405101796..b22463556 100644
--- a/apps/dataset/urls.py
+++ b/apps/dataset/urls.py
@@ -15,6 +15,7 @@ urlpatterns = [
     path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
     path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
     path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
+    path('dataset/<str:dataset_id>/model', views.Dataset.Model.as_view()),
     path('dataset/document/template/export', views.Template.as_view()),
     path('dataset/document/table_template/export', views.TableTemplate.as_view()),
     path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
@@ -24,6 +25,7 @@ urlpatterns = [
     path('dataset/<str:dataset_id>/document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()),
     path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
     path('dataset/<str:dataset_id>/document/batch_refresh', views.Document.BatchRefresh.as_view()),
+    path('dataset/<str:dataset_id>/document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(), name="document_operate"),
     path('dataset/document/split', views.Document.Split.as_view(),
@@ -36,12 +38,14 @@ urlpatterns = [
     path('dataset/<str:dataset_id>/document/<str:document_id>/sync', views.Document.SyncWeb.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
+    path('dataset/<str:dataset_id>/document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()),
     path(
         'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/migrate/dataset/<str:target_dataset_id>/document/<str:target_document_id>',
         views.Paragraph.BatchMigrate.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/_batch', views.Paragraph.Batch.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>', views.Paragraph.Page.as_view(), name='paragraph_page'),
+    path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/batch_generate_related', views.Paragraph.BatchGenerateRelated.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>', views.Paragraph.Operate.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',
diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py
index adf10e3ff..fc3f8a8a5 100644
--- a/apps/dataset/views/dataset.py
+++ b/apps/dataset/views/dataset.py
@@ -20,6 +20,7 @@ from common.response import result
 from common.response.result import get_page_request_params, get_page_api_response, get_api_response
 from common.swagger_api.common_api import CommonApi
 from dataset.serializers.dataset_serializers import DataSetSerializers
+from setting.serializers.provider_serializers import ModelSerializer
 
 
 class Dataset(APIView):
@@ -223,3 +224,20 @@ class Dataset(APIView):
                                      'user_id': str(request.user.id)})
         d.is_valid()
         return result.success(d.page(current_page, page_size))
+
+    class Model(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=["GET"], detail=False)
+        @has_permissions(ViewPermission(
+            [RoleConstants.ADMIN, RoleConstants.USER],
+            [lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                            dynamic_tag=keywords.get('dataset_id'))],
+            compare=CompareConstants.AND))
+        def get(self, request: Request, dataset_id: str):
+            print(dataset_id)
+            return result.success(
+                ModelSerializer.Query(
+                    data={'user_id': request.user.id, 'model_type': 'LLM'}).list(
+                    with_valid=True)
+            )
diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py
index d41535b0b..d911d0de8 100644
--- a/apps/dataset/views/document.py
+++ b/apps/dataset/views/document.py
@@ -393,3 +393,14 @@ class Document(APIView):
             data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id})
         d.is_valid(raise_exception=True)
         return result.success(d.page(current_page, page_size))
+
+    class BatchGenerateRelated(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['PUT'], detail=False)
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def put(self, request: Request, dataset_id: str):
+            return result.success(DocumentSerializers.BatchGenerateRelated(data={'dataset_id': dataset_id})
+                                   .batch_generate_related(request.data))
diff --git a/apps/dataset/views/paragraph.py b/apps/dataset/views/paragraph.py
index af968b8ab..c1286c0d4 100644
--- a/apps/dataset/views/paragraph.py
+++ b/apps/dataset/views/paragraph.py
@@ -232,3 +232,15 @@ class Paragraph(APIView):
                                          'document_id': document_id})
         d.is_valid(raise_exception=True)
         return result.success(d.page(current_page, page_size))
+
+    class BatchGenerateRelated(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['PUT'], detail=False)
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def put(self, request: Request, dataset_id: str, document_id: str):
+            return result.success(
+                ParagraphSerializers.BatchGenerateRelated(data={'dataset_id': dataset_id, 'document_id': document_id})
+                .batch_generate_related(request.data))
diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts
index 8b9259861..c50e64513 100644
--- a/ui/src/api/dataset.ts
+++ b/ui/src/api/dataset.ts
@@ -202,6 +202,22 @@ const exportDataset: (
   return exportExcel(dataset_name + '.xlsx', `dataset/${dataset_id}/export`, undefined, loading)
 }
 
+/**
+ * 获取当前用户可使用的模型列表
+ * @param dataset_id
+ * @param loading
+ * @returns
+ */
+const getDatasetModel: (
+  dataset_id: string,
+  loading?: Ref<boolean>
+) => Promise<Result<Array<any>>> = (dataset_id, loading) => {
+  return get(`${prefix}/${dataset_id}/model`, loading)
+}
+
 export default {
   getDataset,
   getAllDataset,
@@ -215,5 +231,6 @@ export default {
   putSyncWebDataset,
   putReEmbeddingDataset,
   postQADataset,
-  exportDataset
+  exportDataset,
+  getDatasetModel
 }
diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts
index 371a2d4db..28954d0cc 100644
--- a/ui/src/api/document.ts
+++ b/ui/src/api/document.ts
@@ -317,6 +317,19 @@ const exportDocument: (
   )
 }
 
+const batchGenerateRelated: (
+  dataset_id: string,
+  data: any,
+  loading?: Ref<boolean>
+) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
+  return put(
+    `${prefix}/${dataset_id}/document/batch_generate_related`,
+    data,
+    undefined,
+    loading
+  )
+}
+
 export default {
   postSplitDocument,
   getDocument,
@@ -338,5 +351,6 @@ export default {
   postQADocument,
   postTableDocument,
   exportDocument,
-  batchRefresh
+  batchRefresh,
+  batchGenerateRelated
 }
diff --git a/ui/src/api/paragraph.ts b/ui/src/api/paragraph.ts
index 675fa6efa..4a7d29b8a 100644
--- a/ui/src/api/paragraph.ts
+++ b/ui/src/api/paragraph.ts
@@ -226,6 +226,21 @@ const disassociationProblem: (
   )
 }
 
+const batchGenerateRelated: (
+  dataset_id: string,
+  document_id: string,
+  data: any,
+  loading?: Ref<boolean>
+) => Promise<Result<boolean>> = (dataset_id, document_id, data, loading) => {
+  return put(
+    `${prefix}/${dataset_id}/document/${document_id}/paragraph/batch_generate_related`,
+    data,
+    undefined,
+    loading
+  )
+}
+
 export default {
   getParagraph,
   delParagraph,
@@ -236,5 +251,6 @@ export default {
   disassociationProblem,
   associationProblem,
   delMulParagraph,
-  putMigrateMulParagraph
+  putMigrateMulParagraph,
+  batchGenerateRelated
 }
diff --git a/ui/src/views/document/component/GenerateRelatedDialog.vue b/ui/src/views/document/component/GenerateRelatedDialog.vue
new file mode 100644
index 000000000..8791c4be9
--- /dev/null
+++ b/ui/src/views/document/component/GenerateRelatedDialog.vue
@@ -0,0 +1,245 @@
+<!-- 245-line new dialog component; template and script body not preserved in this excerpt -->
diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue
index 8bf361011..f9e6a8c04 100644
--- a/ui/src/views/document/index.vue
+++ b/ui/src/views/document/index.vue
@@ -26,6 +26,9 @@
           重新向量化
+
+          生成关联问题
+
           设置
@@ -138,6 +141,9 @@
           排队中
+
+          生成问题中
+
@@ -258,6 +264,10 @@
+<!-- markup-only additions; element tags not preserved in this excerpt -->
@@ -340,6 +355,7 @@ import { datetimeFormat } from '@/utils/time'
 import { hitHandlingMethod } from '@/enums/document'
 import { MsgSuccess, MsgConfirm, MsgError } from '@/utils/message'
 import useStore from '@/stores'
+import GenerateRelatedDialog from '@/views/document/component/GenerateRelatedDialog.vue'
 
 const router = useRouter()
 const route = useRoute()
 const {
@@ -554,6 +570,19 @@ function batchRefresh() {
   })
 }
 
+function batchGenerateRelated() {
+  const arr: string[] = []
+  multipleSelection.value.map((v) => {
+    if (v) {
+      arr.push(v.id)
+    }
+  })
+  documentApi.batchGenerateRelated(id, arr, loading).then(() => {
+    MsgSuccess('批量生成关联问题成功')
+    multipleTableRef.value?.clearSelection()
+  })
+}
+
 function deleteDocument(row: any) {
   MsgConfirm(
     `是否删除文档:${row.name} ?`,
@@ -643,6 +672,23 @@ function refresh() {
   getList()
 }
 
+const GenerateRelatedDialogRef = ref()
+function openGenerateDialog(row?: any) {
+  const arr: string[] = []
+  if (row) {
+    arr.push(row.id)
+  } else {
+    multipleSelection.value.map((v) => {
+      if (v) {
+        arr.push(v.id)
+      }
+    })
+  }
+
+  GenerateRelatedDialogRef.value.open(arr)
+}
+
 onMounted(() => {
   getDetail()
   if (beforePagination.value) {
diff --git a/ui/src/views/paragraph/component/GenerateRelatedDialog.vue b/ui/src/views/paragraph/component/GenerateRelatedDialog.vue
new file mode 100644
index 000000000..af59e43b7
--- /dev/null
+++ b/ui/src/views/paragraph/component/GenerateRelatedDialog.vue
@@ -0,0 +1,246 @@
+<!-- 246-line new dialog component; template and script body not preserved in this excerpt -->
diff --git a/ui/src/views/paragraph/index.vue b/ui/src/views/paragraph/index.vue
index 1a03dc5a8..cb8f68a3e 100644
--- a/ui/src/views/paragraph/index.vue
+++ b/ui/src/views/paragraph/index.vue
@@ -121,6 +121,10 @@