From 83cd69e5b71b7d07ce11b8478a3c9cde545641a1 Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Mon, 17 Feb 2025 15:49:41 +0800
Subject: [PATCH] feat: Support generating problems for unfinished
 paragraphs #2174 (#2299)
---
.../serializers/document_serializers.py | 26 +++++++++++------
apps/dataset/task/generate.py | 16 +++++++++--
.../xinference_model_provider/model/image.py | 17 ++++++++++-
.../xinference_model_provider/model/llm.py | 17 ++++++++++-
.../generate-related-dialog/index.vue | 28 +++++++++++++++++--
5 files changed, 88 insertions(+), 16 deletions(-)
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index 811c55624..265903c33 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -243,13 +243,16 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
self.is_valid(raise_exception=True)
language = get_language()
if self.data.get('type') == 'csv':
- file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'), "rb")
+ file = open(
+ os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'),
+ "rb")
content = file.read()
file.close()
return HttpResponse(content, status=200, headers={'Content-Type': 'text/csv',
'Content-Disposition': 'attachment; filename="csv_template.csv"'})
elif self.data.get('type') == 'excel':
- file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'excel_template_{to_locale(language)}.xlsx'), "rb")
+ file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
+ f'excel_template_{to_locale(language)}.xlsx'), "rb")
content = file.read()
file.close()
return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
@@ -261,7 +264,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
language = get_language()
if self.data.get('type') == 'csv':
file = open(
- os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'table_template_{to_locale(language)}.csv'),
+ os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
+ f'table_template_{to_locale(language)}.csv'),
"rb")
content = file.read()
file.close()
@@ -1180,7 +1184,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if not QuerySet(Document).filter(id=document_id).exists():
raise AppApiException(500, _('document id not exist'))
- def generate_related(self, model_id, prompt, with_valid=True):
+ def generate_related(self, model_id, prompt, state_list=None, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get('document_id')
@@ -1192,7 +1196,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
State.PENDING)
ListenerManagement.get_aggregation_document_status(document_id)()
try:
- generate_related_by_document_id.delay(document_id, model_id, prompt)
+ generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
except AlreadyQueued as e:
raise AppApiException(500, _('The task is being executed, please do not send it again.'))
@@ -1205,17 +1209,23 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_id_list = instance.get("document_id_list")
model_id = instance.get("model_id")
prompt = instance.get("prompt")
+ state_list = instance.get('state_list')
ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
TaskType.GENERATE_PROBLEM,
State.PENDING)
- ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id__in=document_id_list),
- TaskType.GENERATE_PROBLEM,
+        ListenerManagement.update_status(QuerySet(Paragraph).annotate(
+            reversed_status=Reverse('status'),
+            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
+                                    1),
+        ).filter(task_type_status__in=state_list, document_id__in=document_id_list)
+                                         .values('id'),
+                                         TaskType.GENERATE_PROBLEM,
                                         State.PENDING)
ListenerManagement.get_aggregation_document_status_by_query_set(
QuerySet(Document).filter(id__in=document_id_list))()
try:
for document_id in document_id_list:
- generate_related_by_document_id.delay(document_id, model_id, prompt)
+ generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
except AlreadyQueued as e:
pass
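
Note: the annotate/filter chain above is the core of this change — it reads one task's state flag out of the packed `status` string so that only paragraphs in the caller-selected states are re-queued. A minimal pure-Python sketch of what `Reverse` plus `Substr` compute in SQL, assuming (as the values used here suggest) that `status` stores one state character per task type, counted from the right-hand end:

```python
# Sketch only: mirrors the SQL annotation
#   reversed_status = Reverse('status')
#   task_type_status = Substr('reversed_status', TaskType.GENERATE_PROBLEM.value, 1)
# SQL Substr is 1-based, hence the -1 below.
def task_type_status(status: str, task_type_value: int) -> str:
    return status[::-1][task_type_value - 1]

# If TaskType.GENERATE_PROBLEM.value were 2 and a paragraph's status were '32',
# its generate-problem flag would be '3' (second character from the right),
# so a state_list of ['0', '1', '3', '4', '5', 'n'] would match it.
```
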
diff --git a/apps/dataset/task/generate.py b/apps/dataset/task/generate.py
index 5ffcd1bec..bf9e53869 100644
--- a/apps/dataset/task/generate.py
+++ b/apps/dataset/task/generate.py
@@ -3,11 +3,12 @@ import traceback
from celery_once import QueueOnce
from django.db.models import QuerySet
+from django.db.models.functions import Reverse, Substr
from langchain_core.messages import HumanMessage
from common.config.embedding_config import ModelManage
from common.event import ListenerManagement
-from common.util.page_utils import page
+from common.util.page_utils import page, page_desc
from dataset.models import Paragraph, Document, Status, TaskType, State
from dataset.task.tools import save_problem
from ops import celery_app
@@ -64,7 +65,11 @@ def get_is_the_task_interrupted(document_id):
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
name='celery:generate_related_by_document')
-def generate_related_by_document_id(document_id, model_id, prompt):
+def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
+ if state_list is None:
+ state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
+ State.REVOKE.value,
+ State.REVOKED.value, State.IGNORED.value]
try:
is_the_task_interrupted = get_is_the_task_interrupted(document_id)
if is_the_task_interrupted():
@@ -78,7 +83,12 @@ def generate_related_by_document_id(document_id, model_id, prompt):
generate_problem = get_generate_problem(llm_model, prompt,
ListenerManagement.get_aggregation_document_status(
document_id), is_the_task_interrupted)
- page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted)
+ query_set = QuerySet(Paragraph).annotate(
+ reversed_status=Reverse('status'),
+ task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
+ 1),
+ ).filter(task_type_status__in=state_list, document_id=document_id)
+ page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
except Exception as e:
max_kb_error.error(f'Generate problems based on document: {document_id} error {str(e)}{traceback.format_exc()}')
max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(
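
Note: `page_desc` is imported here but its body is not part of this patch. A plausible reading of the switch away from `page`: because `get_generate_problem` rewrites each paragraph's status, processed rows drop out of the filtered queryset mid-traversal, and walking pages in descending order avoids skipping rows the way forward pagination would. A hypothetical sketch under that assumption:

```python
# Hypothetical sketch of page_desc; the real implementation lives in
# common/util/page_utils and is not shown in this patch.
def page_desc(query_set, page_size, handler, is_the_task_interrupted):
    total = query_set.count()
    last_page = (total + page_size - 1) // page_size
    # Walk pages from last to first: paragraphs whose status the handler
    # rewrites fall out of the filter without shifting the pages still ahead.
    for page_number in range(last_page, 0, -1):
        if is_the_task_interrupted():
            return
        offset = (page_number - 1) * page_size
        handler(list(query_set.order_by('id')[offset:offset + page_size]))
```
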
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/image.py b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py
index f51a64ec4..a195b8649 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/model/image.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py
@@ -1,5 +1,8 @@
-from typing import Dict
+from typing import Dict, List
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
@@ -18,3 +21,15 @@ class XinferenceImage(MaxKBBaseModel, BaseChatOpenAI):
stream_usage=True,
**optional_params,
)
+
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+ if self.usage_metadata is None or self.usage_metadata == {}:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+ return self.usage_metadata.get('input_tokens', 0)
+
+ def get_num_tokens(self, text: str) -> int:
+ if self.usage_metadata is None or self.usage_metadata == {}:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return len(tokenizer.encode(text))
+ return self.get_last_generation_info().get('output_tokens', 0)
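
Note: both new methods fall back to a locally managed tokenizer until the model has reported `usage_metadata`, after which the provider-reported counts win. A usage sketch — the model name, endpoint, credential keys, and the `new_instance` entry point shown here are illustrative assumptions, not part of this patch:

```python
from langchain_core.messages import HumanMessage

# Illustrative values only.
model = XinferenceImage.new_instance(
    'IMAGE', 'qwen-vl-chat',
    {'api_base': 'http://127.0.0.1:9997/v1', 'api_key': 'sk-xxx'})

# No generation has run yet, so usage_metadata is empty and the local
# tokenizer supplies the estimate.
estimate = model.get_num_tokens_from_messages([HumanMessage('describe this image')])

# After a call populates usage_metadata, the provider-reported
# 'input_tokens' value is returned instead.
```
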
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
index 42e098aa3..d76979bd3 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
@@ -1,8 +1,11 @@
# coding=utf-8
-from typing import Dict
+from typing import Dict, List
from urllib.parse import urlparse, ParseResult
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
@@ -33,3 +36,15 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
openai_api_key=model_credential.get('api_key'),
**optional_params
)
+
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+ if self.usage_metadata is None or self.usage_metadata == {}:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+ return self.usage_metadata.get('input_tokens', 0)
+
+ def get_num_tokens(self, text: str) -> int:
+ if self.usage_metadata is None or self.usage_metadata == {}:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return len(tokenizer.encode(text))
+ return self.get_last_generation_info().get('output_tokens', 0)
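
Note: `XinferenceChatModel` gains the same two methods verbatim. A possible consolidation, not part of this patch, would hoist the fallback into a shared mixin:

```python
# Sketch of a shared mixin; not part of this patch.
from typing import List

from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage


class TokenUsageFallbackMixin:
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Empty dict and None are both falsy, matching the original check.
        if not self.usage_metadata:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
        return self.usage_metadata.get('input_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        if not self.usage_metadata:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
        return self.get_last_generation_info().get('output_tokens', 0)
```
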
diff --git a/ui/src/components/generate-related-dialog/index.vue b/ui/src/components/generate-related-dialog/index.vue
index 8c3a3e1d3..4d485eae0 100644
--- a/ui/src/components/generate-related-dialog/index.vue
+++ b/ui/src/components/generate-related-dialog/index.vue
@@ -48,6 +48,16 @@
type="textarea"
/>
+
+          <el-radio-group v-model="state">
+            <el-radio value="error">{{
+              $t('views.document.form.selectVectorization.error')
+            }}</el-radio>
+            <el-radio value="all">{{
+              $t('views.document.form.selectVectorization.all')
+            }}</el-radio>
+          </el-radio-group>
+
@@ -87,7 +97,11 @@ const dialogVisible = ref(false)
const modelOptions = ref(null)
const idList = ref([])
const apiType = ref('') // 'document' or 'paragraph'
-
+const state = ref<'all' | 'error'>('error')
+const stateMap = {
+ all: ['0', '1', '2', '3', '4', '5', 'n'],
+ error: ['0', '1', '3', '4', '5', 'n']
+}
const FormRef = ref()
const userId = user.userInfo?.id as string
const form = ref(prompt.get(userId))
@@ -131,14 +145,22 @@ const submitHandle = async (formEl: FormInstance) => {
// Save the prompt
prompt.save(user.userInfo?.id as string, form.value)
if (apiType.value === 'paragraph') {
- const data = { ...form.value, paragraph_id_list: idList.value }
+ const data = {
+ ...form.value,
+ paragraph_id_list: idList.value,
+ state_list: stateMap[state.value]
+ }
paragraphApi.batchGenerateRelated(id, documentId, data, loading).then(() => {
MsgSuccess(t('views.document.generateQuestion.successMessage'))
emit('refresh')
dialogVisible.value = false
})
} else if (apiType.value === 'document') {
- const data = { ...form.value, document_id_list: idList.value }
+ const data = {
+ ...form.value,
+ document_id_list: idList.value,
+ state_list: stateMap[state.value]
+ }
documentApi.batchGenerateRelated(id, data, loading).then(() => {
MsgSuccess(t('views.document.generateQuestion.successMessage'))
emit('refresh')
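
Note: the UI's `stateMap` mirrors the backend default in `generate_related_by_document_id`. Reading the two together — an inference from this patch, not stated in it — the state characters line up with the `State` enum in the order the default list names them, and the `error` option simply drops `'2'` so paragraphs whose problem generation already succeeded are left alone:

```python
# Inferred correspondence between the UI's stateMap and dataset.models.State;
# the characters follow the default list in generate_related_by_document_id.
ALL_STATES = ['0', '1', '2', '3', '4', '5', 'n']   # every State value
ERROR_STATES = ['0', '1', '3', '4', '5', 'n']      # all except '2' (SUCCESS):
                                                   # finished paragraphs are skipped
```
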