From 6e87ceffceb1874356d00d03f4591f7421f9b37a Mon Sep 17 00:00:00 2001 From: zhangzhanwei Date: Tue, 9 Dec 2025 12:03:36 +0800 Subject: [PATCH] feat: Support tag in knowledge_workflow --- .../impl/base_knowledge_write_node.py | 107 +++++++++++++++++- apps/locales/en_US/LC_MESSAGES/django.po | 2 +- apps/locales/zh_CN/LC_MESSAGES/django.po | 4 +- apps/locales/zh_Hant/LC_MESSAGES/django.po | 4 +- .../credential/stt/default_stt.py | 4 +- .../model/stt/omni_stt.py | 4 +- 6 files changed, 113 insertions(+), 12 deletions(-) diff --git a/apps/application/flow/step_node/knowledge_write_node/impl/base_knowledge_write_node.py b/apps/application/flow/step_node/knowledge_write_node/impl/base_knowledge_write_node.py index abcec8078..41c9d5c07 100644 --- a/apps/application/flow/step_node/knowledge_write_node/impl/base_knowledge_write_node.py +++ b/apps/application/flow/step_node/knowledge_write_node/impl/base_knowledge_write_node.py @@ -7,7 +7,7 @@ @desc: """ from functools import reduce -from typing import Dict, List +from typing import Dict, List, Any import uuid_utils.compat as uuid from django.db.models import QuerySet from django.db.models.aggregates import Max @@ -18,7 +18,8 @@ from application.flow.i_step_node import NodeResult from application.flow.step_node.knowledge_write_node.i_knowledge_write_node import IKnowledgeWriteNode from common.chunk import text_to_chunk from common.utils.common import bulk_create_in_batches -from knowledge.models import Document, KnowledgeType, Paragraph, File, FileSourceType, Problem, ProblemParagraphMapping +from knowledge.models import Document, KnowledgeType, Paragraph, File, FileSourceType, Problem, ProblemParagraphMapping, \ + Tag, DocumentTag from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage from knowledge.serializers.document import DocumentSerializers @@ -33,10 +34,16 @@ class ParagraphInstanceSerializer(serializers.Serializer): chunks = serializers.ListField(required=False, child=serializers.CharField(required=True)) +class TagInstanceSerializer(serializers.Serializer): + key = serializers.CharField(required=True, max_length=64, label=_('Tag Key')) + value = serializers.CharField(required=True, max_length=128, label=_('Tag Value')) + + class KnowledgeWriteParamSerializer(serializers.Serializer): name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1, source=_('document name')) meta = serializers.DictField(required=False) + tags = serializers.ListField(required=False, label=_('Tags'), child=TagInstanceSerializer()) paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True) source_file_id = serializers.UUIDField(required=False, allow_null=True) @@ -51,6 +58,7 @@ def convert_uuid_to_str(obj): else: return obj + def link_file(source_file_id, document_id): if source_file_id is None: return @@ -70,6 +78,7 @@ def link_file(source_file_id, document_id): # 保存文件内容和元数据 new_file.save(file_content) + def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: Dict): paragraph = Paragraph( id=uuid.uuid7(), @@ -77,7 +86,7 @@ def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: D content=instance.get("content"), knowledge_id=knowledge_id, title=instance.get("title") if 'title' in instance else '', - chunks = instance.get('chunks') if 'chunks' in instance else text_to_chunk(instance.get("content")), + chunks=instance.get('chunks') if 'chunks' in instance else text_to_chunk(instance.get("content")), ) problem_paragraph_object_list = [ProblemParagraphObject( @@ -136,6 +145,65 @@ def get_document_paragraph_model(knowledge_id: str, instance: Dict): instance.get('paragraphs') if 'paragraphs' in instance else [] ) +def save_knowledge_tags(knowledge_id: str, tags: List[Dict[str,Any]]): + + existed_tags_dict = { + (key, value): str(tag_id) + for key,value,tag_id in QuerySet(Tag).filter(knowledge_id=knowledge_id).values_list("key", "value", "id") + } + + tag_model_list = [] + new_tag_dict = {} + for tag in tags: + key = tag.get("key") + value = tag.get("value") + + if (key,value) not in existed_tags_dict: + tag_model = Tag( + id=uuid.uuid7(), + knowledge_id=knowledge_id, + key=key, + value=value + ) + tag_model_list.append(tag_model) + new_tag_dict[(key,value)] = str(tag_model.id) + + if tag_model_list: + Tag.objects.bulk_create(tag_model_list) + + all_tag_dict={**existed_tags_dict,**new_tag_dict} + + return all_tag_dict, new_tag_dict + +def batch_add_document_tag(document_tag_map: Dict[str, List[str]]): + """ + 批量添加文档-标签关联 + document_tag_map: {document_id: [tag_id1, tag_id2, ...]} + """ + all_document_ids = list(document_tag_map.keys()) + all_tag_ids = list(set(tag_id for tag_ids in document_tag_map.values() for tag_id in tag_ids)) + + # 查询已存在的文档-标签关联 + existed_relations = set( + QuerySet(DocumentTag).filter( + document_id__in=all_document_ids, + tag_id__in=all_tag_ids + ).values_list('document_id', 'tag_id') + ) + + new_relations = [ + DocumentTag( + id=uuid.uuid7(), + document_id=doc_id, + tag_id=tag_id, + ) + for doc_id, tag_ids in document_tag_map.items() + for tag_id in tag_ids + if (doc_id,tag_id) not in existed_relations + ] + + if new_relations: + QuerySet(DocumentTag).bulk_create(new_relations) class BaseKnowledgeWriteNode(IKnowledgeWriteNode): @@ -153,6 +221,11 @@ class BaseKnowledgeWriteNode(IKnowledgeWriteNode): document_model_list = [] paragraph_model_list = [] problem_paragraph_object_list = [] + # 所有标签 + knowledge_tag_list = [] + # 文档标签映射关系 + document_tags_map = {} + knowledge_tag_dict = {} for document in document_list: document_paragraph_dict_model = get_document_paragraph_model( @@ -162,10 +235,38 @@ class BaseKnowledgeWriteNode(IKnowledgeWriteNode): document_instance = document_paragraph_dict_model.get('document') link_file(document.get("source_file_id"), document_instance.id) document_model_list.append(document_instance) + # 收集标签 + single_document_tag_list = document.get("tags", []) + # 去重传入的标签 + for tag in single_document_tag_list: + tag_key = (tag['key'], tag['value']) + if tag_key not in knowledge_tag_dict: + knowledge_tag_dict[tag_key]= tag + + if single_document_tag_list: + document_tags_map[str(document_instance.id)] = single_document_tag_list + for paragraph in document_paragraph_dict_model.get("paragraph_model_list"): paragraph_model_list.append(paragraph) for problem_paragraph_object in document_paragraph_dict_model.get("problem_paragraph_object_list"): problem_paragraph_object_list.append(problem_paragraph_object) + knowledge_tag_list = list(knowledge_tag_dict.values()) + # 保存所有文档中含有的标签到知识库 + if knowledge_tag_list: + all_tag_dict, new_tag_dict = save_knowledge_tags(knowledge_id, knowledge_tag_list) + # 构建文档-标签ID映射 + document_tag_id_map = {} + # 为每个文档添加其对应的标签 + for doc_id, doc_tags in document_tags_map.items(): + doc_tag_ids = [ + all_tag_dict[(tag.get("key"),tag.get("value"))] + for tag in doc_tags + if (tag.get("key"),tag.get("value")) in all_tag_dict + ] + if doc_tag_ids: + document_tag_id_map[doc_id] = doc_tag_ids + if document_tag_id_map: + batch_add_document_tag(document_tag_id_map) problem_model_list, problem_paragraph_mapping_list = ( ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list() diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index 3e7319209..2a60f5396 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -8810,7 +8810,7 @@ msgstr "" msgid "Audio file recognition - Tongyi Qwen" msgstr "" -msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice" +msgid "Real-time speech recognition - Fun-ASR/Paraformer" msgstr "" msgid "Qwen-Omni" diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index f65464371..c4674986f 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -8936,8 +8936,8 @@ msgstr "阻止不安全的重定向到内部主机" msgid "Audio file recognition - Tongyi Qwen" msgstr "录音文件识别-通义千问" -msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice" -msgstr "录音文件识别-Fun-ASR/Paraformer/SenseVoice" +msgid "Real-time speech recognition - Fun-ASR/Paraformer" +msgstr "实时语音识别-Fun-ASR/Paraformer" msgid "Qwen-Omni" msgstr "多模态" \ No newline at end of file diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index b9b712603..5f49094ab 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -8936,8 +8936,8 @@ msgstr "阻止不安全的重定向到內部主機" msgid "Audio file recognition - Tongyi Qwen" msgstr "錄音文件識別-通義千問" -msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice" -msgstr "錄音文件識別-Fun-ASR/Paraformer/SenseVoice" +msgid "Real-time speech recognition - Fun-ASR/Paraformer" +msgstr "實時語音識別-Fun-ASR/Paraformer" msgid "Qwen-Omni" msgstr "多模態" \ No newline at end of file diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/default_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/default_stt.py index 1a7b93967..06f410b69 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/default_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/default_stt.py @@ -18,13 +18,13 @@ from django.utils.translation import gettext as _ class AliyunBaiLianDefaultSTTModelCredential(BaseForm, BaseModelCredential): - type = forms.Radio(_("Type"), required=True, text_field='label', default_value='qwen', provider='', method='', + type = forms.SingleSelect(_("API"), required=True, text_field='label', default_value='qwen', provider='', method='', value_field='value', option_list=[ {'label': _('Audio file recognition - Tongyi Qwen'), 'value': 'qwen'}, {'label': _('Qwen-Omni'), 'value': 'omni'}, - {'label': _('Audio file recognition - Fun-ASR/Paraformer/SenseVoice'), + {'label': _('Real-time speech recognition - Fun-ASR/Paraformer'), 'value': 'other'} ]) api_url = forms.TextInputField(_('API URL'), required=True, relation_show_field_dict={'type': ['qwen', 'omni']}) diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/omni_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/omni_stt.py index 1951953e2..de2ff64d2 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/omni_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/omni_stt.py @@ -68,7 +68,7 @@ class AliyunBaiLianOmiSpeechToText(MaxKBBaseModel, BaseSpeechToText): "format": "mp3", }, }, - {"type": "text", "text": self.params.get('CueWord')}, + {"type": "text", "text": self.params.get('CueWord') or '这段音频在说什么'}, ], }, ], @@ -77,7 +77,7 @@ class AliyunBaiLianOmiSpeechToText(MaxKBBaseModel, BaseSpeechToText): # stream 必须设置为 True,否则会报错 stream=True, stream_options={"include_usage": True}, - extra_body=self.params + extra_body = {'enable_thinking': False, **self.params}, ) result = [] for chunk in completion: