feat: Generate problem support for generating unfinished paragraphs #2174 (#2299)

This commit is contained in:
shaohuzhang1 2025-02-17 15:49:41 +08:00 committed by GitHub
parent f45855c34b
commit 83cd69e5b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 88 additions and 16 deletions

View File

@ -243,13 +243,16 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
self.is_valid(raise_exception=True)
language = get_language()
if self.data.get('type') == 'csv':
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'), "rb")
file = open(
os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'),
"rb")
content = file.read()
file.close()
return HttpResponse(content, status=200, headers={'Content-Type': 'text/cxv',
'Content-Disposition': 'attachment; filename="csv_template.csv"'})
elif self.data.get('type') == 'excel':
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'excel_template_{to_locale(language)}.xlsx'), "rb")
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
f'excel_template_{to_locale(language)}.xlsx'), "rb")
content = file.read()
file.close()
return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
@ -261,7 +264,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
language = get_language()
if self.data.get('type') == 'csv':
file = open(
os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'table_template_{to_locale(language)}.csv'),
os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
f'table_template_{to_locale(language)}.csv'),
"rb")
content = file.read()
file.close()
@ -1180,7 +1184,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if not QuerySet(Document).filter(id=document_id).exists():
raise AppApiException(500, _('document id not exist'))
def generate_related(self, model_id, prompt, with_valid=True):
def generate_related(self, model_id, prompt, state_list=None, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get('document_id')
@ -1192,7 +1196,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
State.PENDING)
ListenerManagement.get_aggregation_document_status(document_id)()
try:
generate_related_by_document_id.delay(document_id, model_id, prompt)
generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
except AlreadyQueued as e:
raise AppApiException(500, _('The task is being executed, please do not send it again.'))
@ -1205,17 +1209,23 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_id_list = instance.get("document_id_list")
model_id = instance.get("model_id")
prompt = instance.get("prompt")
state_list = instance.get('state_list')
ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
TaskType.GENERATE_PROBLEM,
State.PENDING)
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id__in=document_id_list),
TaskType.GENERATE_PROBLEM,
ListenerManagement.update_status(QuerySet(Paragraph).annotate(
reversed_status=Reverse('status'),
task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
1),
).filter(task_type_status__in=state_list, document_id__in=document_id_list)
.values('id'),
TaskType.EMBEDDING,
State.PENDING)
ListenerManagement.get_aggregation_document_status_by_query_set(
QuerySet(Document).filter(id__in=document_id_list))()
try:
for document_id in document_id_list:
generate_related_by_document_id.delay(document_id, model_id, prompt)
generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
except AlreadyQueued as e:
pass

View File

@ -3,11 +3,12 @@ import traceback
from celery_once import QueueOnce
from django.db.models import QuerySet
from django.db.models.functions import Reverse, Substr
from langchain_core.messages import HumanMessage
from common.config.embedding_config import ModelManage
from common.event import ListenerManagement
from common.util.page_utils import page
from common.util.page_utils import page, page_desc
from dataset.models import Paragraph, Document, Status, TaskType, State
from dataset.task.tools import save_problem
from ops import celery_app
@ -64,7 +65,11 @@ def get_is_the_task_interrupted(document_id):
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
name='celery:generate_related_by_document')
def generate_related_by_document_id(document_id, model_id, prompt):
def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
if state_list is None:
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
State.REVOKE.value,
State.REVOKED.value, State.IGNORED.value]
try:
is_the_task_interrupted = get_is_the_task_interrupted(document_id)
if is_the_task_interrupted():
@ -78,7 +83,12 @@ def generate_related_by_document_id(document_id, model_id, prompt):
generate_problem = get_generate_problem(llm_model, prompt,
ListenerManagement.get_aggregation_document_status(
document_id), is_the_task_interrupted)
page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted)
query_set = QuerySet(Paragraph).annotate(
reversed_status=Reverse('status'),
task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
1),
).filter(task_type_status__in=state_list, document_id=document_id)
page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
except Exception as e:
max_kb_error.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}')
max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(

View File

@ -1,5 +1,8 @@
from typing import Dict
from typing import Dict, List
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
@ -18,3 +21,15 @@ class XinferenceImage(MaxKBBaseModel, BaseChatOpenAI):
stream_usage=True,
**optional_params,
)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
if self.usage_metadata is None or self.usage_metadata == {}:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
return self.usage_metadata.get('input_tokens', 0)
def get_num_tokens(self, text: str) -> int:
if self.usage_metadata is None or self.usage_metadata == {}:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
return self.get_last_generation_info().get('output_tokens', 0)

View File

@ -1,8 +1,11 @@
# coding=utf-8
from typing import Dict
from typing import Dict, List
from urllib.parse import urlparse, ParseResult
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
@ -33,3 +36,15 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
openai_api_key=model_credential.get('api_key'),
**optional_params
)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
if self.usage_metadata is None or self.usage_metadata == {}:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
return self.usage_metadata.get('input_tokens', 0)
def get_num_tokens(self, text: str) -> int:
if self.usage_metadata is None or self.usage_metadata == {}:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
return self.get_last_generation_info().get('output_tokens', 0)

View File

@ -48,6 +48,16 @@
type="textarea"
/>
</el-form-item>
<el-form-item :label="$t('views.problem.relateParagraph.selectParagraph')" prop="state">
<el-radio-group v-model="state" class="radio-block">
<el-radio value="error" size="large" class="mb-16">{{
$t('views.document.form.selectVectorization.error')
}}</el-radio>
<el-radio value="all" size="large">{{
$t('views.document.form.selectVectorization.all')
}}</el-radio>
</el-radio-group>
</el-form-item>
</el-form>
</div>
<template #footer>
@ -87,7 +97,11 @@ const dialogVisible = ref<boolean>(false)
const modelOptions = ref<any>(null)
const idList = ref<string[]>([])
const apiType = ref('') // documentparagraph
const state = ref<'all' | 'error'>('error')
const stateMap = {
all: ['0', '1', '2', '3', '4', '5', 'n'],
error: ['0', '1', '3', '4', '5', 'n']
}
const FormRef = ref()
const userId = user.userInfo?.id as string
const form = ref(prompt.get(userId))
@ -131,14 +145,22 @@ const submitHandle = async (formEl: FormInstance) => {
//
prompt.save(user.userInfo?.id as string, form.value)
if (apiType.value === 'paragraph') {
const data = { ...form.value, paragraph_id_list: idList.value }
const data = {
...form.value,
paragraph_id_list: idList.value,
state_list: stateMap[state.value]
}
paragraphApi.batchGenerateRelated(id, documentId, data, loading).then(() => {
MsgSuccess(t('views.document.generateQuestion.successMessage'))
emit('refresh')
dialogVisible.value = false
})
} else if (apiType.value === 'document') {
const data = { ...form.value, document_id_list: idList.value }
const data = {
...form.value,
document_id_list: idList.value,
state_list: stateMap[state.value]
}
documentApi.batchGenerateRelated(id, data, loading).then(() => {
MsgSuccess(t('views.document.generateQuestion.successMessage'))
emit('refresh')