feat: Support tag in knowledge_workflow

This commit is contained in:
zhangzhanwei 2025-12-09 12:03:36 +08:00 committed by zhanweizhang7
parent 735d3d0121
commit 6e87ceffce
6 changed files with 113 additions and 12 deletions

View File

@ -7,7 +7,7 @@
@desc:
"""
from functools import reduce
from typing import Dict, List
from typing import Dict, List, Any
import uuid_utils.compat as uuid
from django.db.models import QuerySet
from django.db.models.aggregates import Max
@ -18,7 +18,8 @@ from application.flow.i_step_node import NodeResult
from application.flow.step_node.knowledge_write_node.i_knowledge_write_node import IKnowledgeWriteNode
from common.chunk import text_to_chunk
from common.utils.common import bulk_create_in_batches
from knowledge.models import Document, KnowledgeType, Paragraph, File, FileSourceType, Problem, ProblemParagraphMapping
from knowledge.models import Document, KnowledgeType, Paragraph, File, FileSourceType, Problem, ProblemParagraphMapping, \
Tag, DocumentTag
from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage
from knowledge.serializers.document import DocumentSerializers
@ -33,10 +34,16 @@ class ParagraphInstanceSerializer(serializers.Serializer):
chunks = serializers.ListField(required=False, child=serializers.CharField(required=True))
class TagInstanceSerializer(serializers.Serializer):
    """Validate one tag attached to a document: a required ``key``/``value`` string pair."""
    # The key is limited to 64 characters, the value to 128.
    key = serializers.CharField(label=_('Tag Key'), max_length=64, required=True)
    value = serializers.CharField(label=_('Tag Value'), max_length=128, required=True)
class KnowledgeWriteParamSerializer(serializers.Serializer):
name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1,
source=_('document name'))
meta = serializers.DictField(required=False)
tags = serializers.ListField(required=False, label=_('Tags'), child=TagInstanceSerializer())
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
source_file_id = serializers.UUIDField(required=False, allow_null=True)
@ -51,6 +58,7 @@ def convert_uuid_to_str(obj):
else:
return obj
def link_file(source_file_id, document_id):
if source_file_id is None:
return
@ -70,6 +78,7 @@ def link_file(source_file_id, document_id):
# 保存文件内容和元数据
new_file.save(file_content)
def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: Dict):
paragraph = Paragraph(
id=uuid.uuid7(),
@ -77,7 +86,7 @@ def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: D
content=instance.get("content"),
knowledge_id=knowledge_id,
title=instance.get("title") if 'title' in instance else '',
chunks = instance.get('chunks') if 'chunks' in instance else text_to_chunk(instance.get("content")),
chunks=instance.get('chunks') if 'chunks' in instance else text_to_chunk(instance.get("content")),
)
problem_paragraph_object_list = [ProblemParagraphObject(
@ -136,6 +145,65 @@ def get_document_paragraph_model(knowledge_id: str, instance: Dict):
instance.get('paragraphs') if 'paragraphs' in instance else []
)
def save_knowledge_tags(knowledge_id: str, tags: List[Dict[str, Any]]):
    """
    Make sure every requested (key, value) tag exists for a knowledge base.

    Tags already stored for ``knowledge_id`` are reused; the missing ones are
    created with one ``bulk_create`` call. A (key, value) pair that appears more
    than once in ``tags`` is created only once.

    :param knowledge_id: id of the knowledge base the tags belong to
    :param tags: dicts carrying a ``key`` and a ``value`` entry
    :return: ``(all_tag_dict, new_tag_dict)``; both map ``(key, value)`` to the tag id
             as a string. ``all_tag_dict`` contains the stored plus the newly created
             tags, ``new_tag_dict`` only the tags created by this call.
    """
    existed_tags_dict = {
        (key, value): str(tag_id)
        for key, value, tag_id in QuerySet(Tag).filter(knowledge_id=knowledge_id).values_list("key", "value", "id")
    }

    tag_model_list = []
    new_tag_dict = {}
    for tag in tags:
        tag_key = (tag.get("key"), tag.get("value"))
        # Skip tags that are already stored and repeats inside ``tags`` itself;
        # without the second check a repeated pair would be inserted twice.
        if tag_key in existed_tags_dict or tag_key in new_tag_dict:
            continue
        tag_model = Tag(
            id=uuid.uuid7(),
            knowledge_id=knowledge_id,
            key=tag_key[0],
            value=tag_key[1],
        )
        tag_model_list.append(tag_model)
        new_tag_dict[tag_key] = str(tag_model.id)

    if tag_model_list:
        Tag.objects.bulk_create(tag_model_list)

    all_tag_dict = {**existed_tags_dict, **new_tag_dict}
    return all_tag_dict, new_tag_dict
def batch_add_document_tag(document_tag_map: Dict[str, List[str]]):
    """
    Bulk-create the document-tag relations that do not exist yet.

    :param document_tag_map: ``{document_id: [tag_id1, tag_id2, ...]}``
    :return: None
    """
    # Collect each requested (document_id, tag_id) pair once, in input order.
    # The dict key is the string form of the pair, so a repeated pair is dropped
    # and the same key can be compared with the rows read back below.
    wanted_pairs = {}
    for doc_id, tag_ids in document_tag_map.items():
        for tag_id in tag_ids:
            wanted_pairs.setdefault((str(doc_id), str(tag_id)), (doc_id, tag_id))
    if not wanted_pairs:
        # Nothing to link, so the lookup query is skipped as well.
        return

    all_document_ids = list(document_tag_map)
    all_tag_ids = list({pair_tag_id for pair_doc_id, pair_tag_id in wanted_pairs.values()})

    # Look up the relations that already exist.
    # NOTE(review): the ids read back are presumably UUID objects while callers pass
    # strings (confirm against the models); both sides are normalised to ``str`` so
    # the membership test below can actually match.
    existed_relations = {
        (str(existed_doc_id), str(existed_tag_id))
        for existed_doc_id, existed_tag_id in QuerySet(DocumentTag).filter(
            document_id__in=all_document_ids,
            tag_id__in=all_tag_ids,
        ).values_list('document_id', 'tag_id')
    }

    new_relations = [
        DocumentTag(
            id=uuid.uuid7(),
            document_id=doc_id,
            tag_id=tag_id,
        )
        for pair_key, (doc_id, tag_id) in wanted_pairs.items()
        if pair_key not in existed_relations
    ]

    if new_relations:
        QuerySet(DocumentTag).bulk_create(new_relations)
class BaseKnowledgeWriteNode(IKnowledgeWriteNode):
@ -153,6 +221,11 @@ class BaseKnowledgeWriteNode(IKnowledgeWriteNode):
document_model_list = []
paragraph_model_list = []
problem_paragraph_object_list = []
# 所有标签
knowledge_tag_list = []
# 文档标签映射关系
document_tags_map = {}
knowledge_tag_dict = {}
for document in document_list:
document_paragraph_dict_model = get_document_paragraph_model(
@ -162,10 +235,38 @@ class BaseKnowledgeWriteNode(IKnowledgeWriteNode):
document_instance = document_paragraph_dict_model.get('document')
link_file(document.get("source_file_id"), document_instance.id)
document_model_list.append(document_instance)
# 收集标签
single_document_tag_list = document.get("tags", [])
# 去重传入的标签
for tag in single_document_tag_list:
tag_key = (tag['key'], tag['value'])
if tag_key not in knowledge_tag_dict:
knowledge_tag_dict[tag_key]= tag
if single_document_tag_list:
document_tags_map[str(document_instance.id)] = single_document_tag_list
for paragraph in document_paragraph_dict_model.get("paragraph_model_list"):
paragraph_model_list.append(paragraph)
for problem_paragraph_object in document_paragraph_dict_model.get("problem_paragraph_object_list"):
problem_paragraph_object_list.append(problem_paragraph_object)
knowledge_tag_list = list(knowledge_tag_dict.values())
# 保存所有文档中含有的标签到知识库
if knowledge_tag_list:
all_tag_dict, new_tag_dict = save_knowledge_tags(knowledge_id, knowledge_tag_list)
# 构建文档-标签ID映射
document_tag_id_map = {}
# 为每个文档添加其对应的标签
for doc_id, doc_tags in document_tags_map.items():
doc_tag_ids = [
all_tag_dict[(tag.get("key"),tag.get("value"))]
for tag in doc_tags
if (tag.get("key"),tag.get("value")) in all_tag_dict
]
if doc_tag_ids:
document_tag_id_map[doc_id] = doc_tag_ids
if document_tag_id_map:
batch_add_document_tag(document_tag_id_map)
problem_model_list, problem_paragraph_mapping_list = (
ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list()

View File

@ -8810,7 +8810,7 @@ msgstr ""
msgid "Audio file recognition - Tongyi Qwen"
msgstr ""
msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice"
msgid "Real-time speech recognition - Fun-ASR/Paraformer"
msgstr ""
msgid "Qwen-Omni"

View File

@ -8936,8 +8936,8 @@ msgstr "阻止不安全的重定向到内部主机"
msgid "Audio file recognition - Tongyi Qwen"
msgstr "录音文件识别-通义千问"
msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice"
msgstr "录音文件识别-Fun-ASR/Paraformer/SenseVoice"
msgid "Real-time speech recognition - Fun-ASR/Paraformer"
msgstr "实时语音识别-Fun-ASR/Paraformer"
msgid "Qwen-Omni"
msgstr "多模态"

View File

@ -8936,8 +8936,8 @@ msgstr "阻止不安全的重定向到內部主機"
msgid "Audio file recognition - Tongyi Qwen"
msgstr "錄音文件識別-通義千問"
msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice"
msgstr "錄音文件識別-Fun-ASR/Paraformer/SenseVoice"
msgid "Real-time speech recognition - Fun-ASR/Paraformer"
msgstr "實時語音識別-Fun-ASR/Paraformer"
msgid "Qwen-Omni"
msgstr "多模態"

View File

@ -18,13 +18,13 @@ from django.utils.translation import gettext as _
class AliyunBaiLianDefaultSTTModelCredential(BaseForm, BaseModelCredential):
type = forms.Radio(_("Type"), required=True, text_field='label', default_value='qwen', provider='', method='',
type = forms.SingleSelect(_("API"), required=True, text_field='label', default_value='qwen', provider='', method='',
value_field='value', option_list=[
{'label': _('Audio file recognition - Tongyi Qwen'),
'value': 'qwen'},
{'label': _('Qwen-Omni'),
'value': 'omni'},
{'label': _('Audio file recognition - Fun-ASR/Paraformer/SenseVoice'),
{'label': _('Real-time speech recognition - Fun-ASR/Paraformer'),
'value': 'other'}
])
api_url = forms.TextInputField(_('API URL'), required=True, relation_show_field_dict={'type': ['qwen', 'omni']})

View File

@ -68,7 +68,7 @@ class AliyunBaiLianOmiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
"format": "mp3",
},
},
{"type": "text", "text": self.params.get('CueWord')},
{"type": "text", "text": self.params.get('CueWord') or '这段音频在说什么'},
],
},
],
@ -77,7 +77,7 @@ class AliyunBaiLianOmiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
# stream 必须设置为 True否则会报错
stream=True,
stream_options={"include_usage": True},
extra_body=self.params
extra_body = {'enable_thinking': False, **self.params},
)
result = []
for chunk in completion: