Mirror of https://github.com/1Panel-dev/MaxKB.git (synced 2025-12-26 01:33:05 +00:00)

Commit 83cd69e5b7 (parent f45855c34b)
@@ -243,13 +243,16 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
         self.is_valid(raise_exception=True)
         language = get_language()
         if self.data.get('type') == 'csv':
-            file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'), "rb")
+            file = open(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'),
+                "rb")
             content = file.read()
             file.close()
             return HttpResponse(content, status=200, headers={'Content-Type': 'text/cxv',
                                                               'Content-Disposition': 'attachment; filename="csv_template.csv"'})
         elif self.data.get('type') == 'excel':
-            file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'excel_template_{to_locale(language)}.xlsx'), "rb")
+            file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
+                                     f'excel_template_{to_locale(language)}.xlsx'), "rb")
             content = file.read()
             file.close()
             return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
@@ -261,7 +264,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
         language = get_language()
         if self.data.get('type') == 'csv':
             file = open(
-                os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'table_template_{to_locale(language)}.csv'),
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
+                             f'table_template_{to_locale(language)}.csv'),
                 "rb")
             content = file.read()
             file.close()
@@ -1180,7 +1184,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
         if not QuerySet(Document).filter(id=document_id).exists():
             raise AppApiException(500, _('document id not exist'))
 
-    def generate_related(self, model_id, prompt, with_valid=True):
+    def generate_related(self, model_id, prompt, state_list=None, with_valid=True):
         if with_valid:
             self.is_valid(raise_exception=True)
         document_id = self.data.get('document_id')
@@ -1192,7 +1196,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                                          State.PENDING)
         ListenerManagement.get_aggregation_document_status(document_id)()
         try:
-            generate_related_by_document_id.delay(document_id, model_id, prompt)
+            generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
         except AlreadyQueued as e:
             raise AppApiException(500, _('The task is being executed, please do not send it again.'))
 
@@ -1205,17 +1209,23 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             document_id_list = instance.get("document_id_list")
             model_id = instance.get("model_id")
             prompt = instance.get("prompt")
+            state_list = instance.get('state_list')
             ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
                                              TaskType.GENERATE_PROBLEM,
                                              State.PENDING)
-            ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id__in=document_id_list),
-                                             TaskType.GENERATE_PROBLEM,
+            ListenerManagement.update_status(QuerySet(Paragraph).annotate(
+                reversed_status=Reverse('status'),
+                task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
+                                        1),
+            ).filter(task_type_status__in=state_list, document_id__in=document_id_list)
+                                             .values('id'),
+                                             TaskType.EMBEDDING,
                                              State.PENDING)
             ListenerManagement.get_aggregation_document_status_by_query_set(
                 QuerySet(Document).filter(id__in=document_id_list))()
             try:
                 for document_id in document_id_list:
-                    generate_related_by_document_id.delay(document_id, model_id, prompt)
+                    generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
             except AlreadyQueued as e:
                 pass
 
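The annotate-and-filter pattern above is the heart of this change: each row packs its per-task states into a single `status` string, and the flag for one `TaskType` is read from the right-hand end of that string, which is exactly what `Reverse('status')` plus a one-character `Substr` at position `TaskType.GENERATE_PROBLEM.value` computes in SQL. A minimal pure-Python sketch of the same lookup; the example status values and the character layout are assumptions inferred from this diff and the front-end `stateMap` further down:

```python
# One state character per task type, read from the right end of `status`,
# so appending flags for new task types never shifts the old positions.
GENERATE_PROBLEM = 2  # assumed TaskType.GENERATE_PROBLEM.value (1-based, from the right)

def task_state(status: str, task_type: int = GENERATE_PROBLEM) -> str:
    """Pure-Python equivalent of Reverse('status') + Substr(reversed, task_type, 1)."""
    reversed_status = status[::-1]                    # Reverse('status')
    return reversed_status[task_type - 1:task_type]   # Substr(..., task_type, 1)

# Only rows whose flag is in state_list are re-queued; the UI's "error"
# scope sends every state except success ('2').
state_list = ['0', '1', '3', '4', '5', 'n']
rows = [{'id': 1, 'status': '2n'},   # generate-problem succeeded: skipped
        {'id': 2, 'status': '0n'}]   # still pending: re-queued
print([r['id'] for r in rows if task_state(r['status']) in state_list])  # [2]
```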
@@ -3,11 +3,12 @@ import traceback
 
 from celery_once import QueueOnce
 from django.db.models import QuerySet
+from django.db.models.functions import Reverse, Substr
 from langchain_core.messages import HumanMessage
 
 from common.config.embedding_config import ModelManage
 from common.event import ListenerManagement
-from common.util.page_utils import page
+from common.util.page_utils import page, page_desc
 from dataset.models import Paragraph, Document, Status, TaskType, State
 from dataset.task.tools import save_problem
 from ops import celery_app
@@ -64,7 +65,11 @@ def get_is_the_task_interrupted(document_id):
 
 @celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
                  name='celery:generate_related_by_document')
-def generate_related_by_document_id(document_id, model_id, prompt):
+def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
+    if state_list is None:
+        state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
+                      State.REVOKE.value,
+                      State.REVOKED.value, State.IGNORED.value]
     try:
         is_the_task_interrupted = get_is_the_task_interrupted(document_id)
         if is_the_task_interrupted():
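A side note on the decorator above: `QueueOnce` with `once={'keys': ['document_id']}` builds its lock from `document_id` alone, so adding the new `state_list` argument (like `model_id` and `prompt` before it) does not weaken the dedup, and a second submission for the same document raises `AlreadyQueued` no matter how the other parameters differ. A stand-in sketch of that keying rule, not celery_once's actual implementation:

```python
# Illustrative lock-per-document behaviour of QueueOnce's `keys` option.
_locks = set()

def delay(document_id, model_id, prompt, state_list=None):
    key = ('generate_related_by_document', document_id)  # only document_id is keyed
    if key in _locks:
        raise RuntimeError('AlreadyQueued')
    _locks.add(key)

delay('doc-1', 'model-a', 'prompt-1')
try:
    delay('doc-1', 'model-b', 'prompt-2', ['2'])  # same document: still rejected
except RuntimeError as err:
    print(err)  # AlreadyQueued
```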
@@ -78,7 +83,12 @@ def generate_related_by_document_id(document_id, model_id, prompt):
         generate_problem = get_generate_problem(llm_model, prompt,
                                                 ListenerManagement.get_aggregation_document_status(
                                                     document_id), is_the_task_interrupted)
-        page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted)
+        query_set = QuerySet(Paragraph).annotate(
+            reversed_status=Reverse('status'),
+            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
+                                    1),
+        ).filter(task_type_status__in=state_list, document_id=document_id)
+        page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
     except Exception as e:
-        max_kb_error.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}')
+        max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(
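The switch from `page` to `page_desc` matters because `generate_problem` flips each processed paragraph's status flag, which removes the row from the filtered queryset while it is still being paged: with ascending offset pagination the survivors shift left and alternate pages get skipped. Paging from the last page toward the first avoids that, since shrinking the tail never moves rows at lower offsets. The real helper lives in `common/util/page_utils.py` and is not shown in this diff; the following is only a sketch of the idea, with a generic fetch/count interface standing in for the queryset:

```python
from typing import Callable, List

def page_desc(fetch: Callable[[int, int], List[dict]],
              count: Callable[[], int],
              page_size: int,
              handler: Callable[[List[dict]], None],
              is_interrupted: Callable[[], bool]) -> None:
    """Walk a shrinking result set from its last page to its first.

    `handler` flips rows out of the filter, shrinking the tail; rows at
    lower offsets never move, so none are skipped.
    """
    offset = (max(count() - 1, 0) // page_size) * page_size  # start of last page
    while offset >= 0:
        if is_interrupted():
            return
        handler(fetch(offset, page_size))
        offset -= page_size

# Demo: 25 pending rows; each handled batch drops out of the filter.
rows = [{'id': i, 'done': False} for i in range(25)]
pending = lambda: [r for r in rows if not r['done']]
page_desc(lambda off, n: pending()[off:off + n],
          lambda: len(pending()),
          10,
          lambda batch: [r.update(done=True) for r in batch],
          lambda: False)
assert all(r['done'] for r in rows)
```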
@@ -1,5 +1,8 @@
-from typing import Dict
+from typing import Dict, List
 
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
@@ -18,3 +21,15 @@ class XinferenceImage(MaxKBBaseModel, BaseChatOpenAI):
             stream_usage=True,
             **optional_params,
         )
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.usage_metadata.get('input_tokens', 0)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('output_tokens', 0)
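The two methods added here (and again to `XinferenceChatModel` in the next file) share one fallback pattern: trust the token usage the provider reported for the last generation when it exists, otherwise count locally with MaxKB's tokenizer. Note the asymmetry visible in the diff: message counting falls back through `usage_metadata['input_tokens']`, while plain-text counting reads `output_tokens` from `get_last_generation_info()`. A self-contained sketch of the pattern, with a whitespace splitter standing in for the real tokenizer and a bare class standing in for the model plumbing:

```python
from typing import Optional

class UsageAwareCounter:
    """Stand-in for the MaxKBBaseModel/BaseChatOpenAI machinery above."""

    def __init__(self) -> None:
        # Filled from the provider's response metadata after a call,
        # e.g. {'input_tokens': 42, 'output_tokens': 7}.
        self.usage_metadata: Optional[dict] = None

    def _local_count(self, text: str) -> int:
        return len(text.split())  # crude substitute for a real tokenizer

    def get_num_tokens(self, text: str) -> int:
        if not self.usage_metadata:        # nothing reported yet: estimate locally
            return self._local_count(text)
        return self.usage_metadata.get('output_tokens', 0)

counter = UsageAwareCounter()
print(counter.get_num_tokens('hello world'))                        # 2 (local estimate)
counter.usage_metadata = {'input_tokens': 42, 'output_tokens': 7}
print(counter.get_num_tokens('hello world'))                        # 7 (provider-reported)
```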
@@ -1,8 +1,11 @@
 # coding=utf-8
 
-from typing import Dict
+from typing import Dict, List
 from urllib.parse import urlparse, ParseResult
 
+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
 
@@ -33,3 +36,15 @@ class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
             openai_api_key=model_credential.get('api_key'),
             **optional_params
         )
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.usage_metadata.get('input_tokens', 0)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('output_tokens', 0)
@@ -48,6 +48,16 @@
             type="textarea"
           />
         </el-form-item>
+        <el-form-item :label="$t('views.problem.relateParagraph.selectParagraph')" prop="state">
+          <el-radio-group v-model="state" class="radio-block">
+            <el-radio value="error" size="large" class="mb-16">{{
+              $t('views.document.form.selectVectorization.error')
+            }}</el-radio>
+            <el-radio value="all" size="large">{{
+              $t('views.document.form.selectVectorization.all')
+            }}</el-radio>
+          </el-radio-group>
+        </el-form-item>
       </el-form>
     </div>
     <template #footer>
@@ -87,7 +97,11 @@ const dialogVisible = ref<boolean>(false)
 const modelOptions = ref<any>(null)
 const idList = ref<string[]>([])
 const apiType = ref('') // 'document' or 'paragraph'
 
+const state = ref<'all' | 'error'>('error')
+const stateMap = {
+  all: ['0', '1', '2', '3', '4', '5', 'n'],
+  error: ['0', '1', '3', '4', '5', 'n']
+}
 const FormRef = ref()
 const userId = user.userInfo?.id as string
 const form = ref(prompt.get(userId))
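The digit lists in `stateMap` mirror the backend's status characters. Lining them up against the default `state_list` in `generate.py` (PENDING, STARTED, SUCCESS, FAILURE, REVOKE, REVOKED, IGNORED) suggests the mapping below, in which case the `error` scope is simply `all` minus the success flag `'2'`, so paragraphs that already generated questions are left alone. The exact character assignments are an assumption; only the omission of `'2'` is taken from the diff itself:

```python
# Assumed State -> status-character mapping, in the order the default
# state_list enumerates the states in generate.py.
STATE_CHARS = {
    'PENDING': '0', 'STARTED': '1', 'SUCCESS': '2',
    'FAILURE': '3', 'REVOKE': '4', 'REVOKED': '5', 'IGNORED': 'n',
}

all_states = list(STATE_CHARS.values())
error_states = [c for c in all_states if c != STATE_CHARS['SUCCESS']]
print(all_states)    # ['0', '1', '2', '3', '4', '5', 'n']  == stateMap.all
print(error_states)  # ['0', '1', '3', '4', '5', 'n']       == stateMap.error
```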
@@ -131,14 +145,22 @@ const submitHandle = async (formEl: FormInstance) => {
     // Save the prompt
     prompt.save(user.userInfo?.id as string, form.value)
     if (apiType.value === 'paragraph') {
-      const data = { ...form.value, paragraph_id_list: idList.value }
+      const data = {
+        ...form.value,
+        paragraph_id_list: idList.value,
+        state_list: stateMap[state.value]
+      }
       paragraphApi.batchGenerateRelated(id, documentId, data, loading).then(() => {
         MsgSuccess(t('views.document.generateQuestion.successMessage'))
         emit('refresh')
         dialogVisible.value = false
       })
     } else if (apiType.value === 'document') {
-      const data = { ...form.value, document_id_list: idList.value }
+      const data = {
+        ...form.value,
+        document_id_list: idList.value,
+        state_list: stateMap[state.value]
+      }
       documentApi.batchGenerateRelated(id, data, loading).then(() => {
         MsgSuccess(t('views.document.generateQuestion.successMessage'))
         emit('refresh')