From 75b9b17e2ec6c6aba7e5b9d8457bdc2e24d6af15 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 18 Jul 2024 10:26:16 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=90=91=E9=87=8F?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/chat_serializers.py | 8 +++--- apps/common/event/common.py | 8 +++--- apps/common/event/listener_manage.py | 6 ++--- .../dataset/serializers/common_serializers.py | 4 +-- .../serializers/dataset_serializers.py | 3 ++- .../serializers/document_serializers.py | 25 ++++++++++++------- .../serializers/paragraph_serializers.py | 18 +++++++------ 7 files changed, 43 insertions(+), 29 deletions(-) diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 402eff691..48b0ca9ff 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -39,6 +39,7 @@ from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock from common.util.rsa_util import rsa_long_decrypt from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping +from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers from setting.models import Model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants @@ -533,9 +534,10 @@ class ChatRecordSerializer(serializers.Serializer): raise AppApiException(500, "文档id不正确") @staticmethod - def post_embedding_paragraph(chat_record, paragraph_id): + def post_embedding_paragraph(chat_record, paragraph_id, dataset_id): + model = get_embedding_model_by_dataset_id(dataset_id) # 发送向量化事件 - ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id) + ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model) return chat_record @post(post_function=post_embedding_paragraph) @@ -573,7 +575,7 @@ class ChatRecordSerializer(serializers.Serializer): chat_record.improve_paragraph_id_list.append(paragraph.id) # 添加标注 chat_record.save() - return ChatRecordSerializerModel(chat_record).data, paragraph.id + return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id class Operate(serializers.Serializer): chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) diff --git a/apps/common/event/common.py b/apps/common/event/common.py index e35123758..9f4a945bf 100644 --- a/apps/common/event/common.py +++ b/apps/common/event/common.py @@ -14,14 +14,14 @@ embedding_thread_pool = ThreadPoolExecutor(3) def poxy(poxy_function): - def inner(args): - work_thread_pool.submit(poxy_function, args) + def inner(args, **keywords): + work_thread_pool.submit(poxy_function, args, **keywords) return inner def embedding_poxy(poxy_function): - def inner(args): - embedding_thread_pool.submit(poxy_function, args) + def inner(args, **keywords): + embedding_thread_pool.submit(poxy_function, args, **keywords) return inner diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 2df7b3415..e1205f3af 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -85,8 +85,8 @@ class ListenerManagement: delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list") @staticmethod - def embedding_by_problem(args): - VectorStore.get_embedding_vector().save(**args) + def embedding_by_problem(args, embedding_model: Embeddings): + VectorStore.get_embedding_vector().save(**args, embedding=embedding_model) @staticmethod @embedding_poxy @@ -165,7 +165,7 @@ class ListenerManagement: document_list = QuerySet(Document).filter(dataset_id=dataset_id) max_kb.info(f"数据集文档:{[d.name for d in document_list]}") for document in document_list: - ListenerManagement.embedding_by_document(document.id, embedding_model) + ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model) except Exception as e: max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}') finally: diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 045588f6d..887772f11 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -145,5 +145,5 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List): def get_embedding_model_by_dataset_id(dataset_id: str): - dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id) - return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode)) + dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first() + return EmbeddingModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode)) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index c6d06fb2a..521aa661a 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -745,7 +745,8 @@ class DataSetSerializers(serializers.ModelSerializer): def re_embedding(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id')) + model = get_embedding_model_by_dataset_id(self.data.get('id')) + ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model) def list_application(self, with_valid=True): if with_valid: diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 9496ac7e1..3bd5a1c9c 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -41,7 +41,8 @@ from common.util.file_util import get_file_content from common.util.fork import Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image -from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage +from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \ + get_embedding_model_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from smartdoc.conf import PROJECT_DIR @@ -392,7 +393,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): problem_paragraph_mapping_list) > 0 else None # 向量化 if with_embedding: - ListenerManagement.embedding_by_document_signal.send(document_id) + model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id) + ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model) else: document.status = Status.error document.save() @@ -405,6 +407,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): class Operate(ApiMixin, serializers.Serializer): document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( "文档id")) + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id")) @staticmethod def get_request_params_api(): @@ -530,7 +533,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) document_id = self.data.get("document_id") - ListenerManagement.embedding_by_document_signal.send(document_id) + model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id')) + ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model) @transaction.atomic def delete(self): @@ -599,8 +603,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): return True @staticmethod - def post_embedding(result, document_id): - ListenerManagement.embedding_by_document_signal.send(document_id) + def post_embedding(result, document_id, dataset_id): + model = get_embedding_model_by_dataset_id(dataset_id) + ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model) return result @staticmethod @@ -646,7 +651,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): document_id = str(document_model.id) return DocumentSerializers.Operate( data={'dataset_id': dataset_id, 'document_id': document_id}).one( - with_valid=True), document_id + with_valid=True), document_id, dataset_id @staticmethod def get_sync_handler(dataset_id): @@ -803,9 +808,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api()) @staticmethod - def post_embedding(document_list): + def post_embedding(document_list, dataset_id): for document_dict in document_list: - ListenerManagement.embedding_by_document_signal.send(document_dict.get('id')) + model = get_embedding_model_by_dataset_id(dataset_id) + ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model) return document_list @post(post_function=post_embedding) @@ -846,7 +852,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): return [], query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]}) return native_search(query_set, select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False), + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), + with_search_one=False), dataset_id @staticmethod def _batch_sync(document_id_list: List[str]): diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 61ae860b6..7f84f61c0 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -22,7 +22,7 @@ from common.util.common import post from common.util.field_message import ErrMessage from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \ - ProblemParagraphManage + ProblemParagraphManage, get_embedding_model_by_dataset_id from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers from embedding.models import SourceType @@ -132,6 +132,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): paragraph_id=self.data.get('paragraph_id'), dataset_id=self.data.get('dataset_id')) problem_paragraph_mapping.save() + model = get_embedding_model_by_dataset_id(self.data.get('dataset_id')) if with_embedding: ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, 'is_active': True, @@ -140,7 +141,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): 'document_id': self.data.get('document_id'), 'paragraph_id': self.data.get('paragraph_id'), 'dataset_id': self.data.get('dataset_id'), - }) + }, embedding_model=model) return ProblemSerializers.Operate( data={'dataset_id': self.data.get('dataset_id'), @@ -227,6 +228,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): problem_id=problem.id) problem_paragraph_mapping.save() if with_embedding: + model = get_embedding_model_by_dataset_id(self.data.get('dataset_id')) ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, 'is_active': True, 'source_type': SourceType.PROBLEM, @@ -234,7 +236,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): 'document_id': self.data.get('document_id'), 'paragraph_id': self.data.get('paragraph_id'), 'dataset_id': self.data.get('dataset_id'), - }) + }, embedding_model=model) def un_association(self, with_valid=True): if with_valid: @@ -454,13 +456,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): raise AppApiException(500, "段落id不存在") @staticmethod - def post_embedding(paragraph, instance): + def post_embedding(paragraph, instance, dataset_id): if 'is_active' in instance and instance.get('is_active') is not None: s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get( 'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal) s.send(paragraph.get('id')) else: - ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id')) + model = get_embedding_model_by_dataset_id(dataset_id) + ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model) return paragraph @post(post_embedding) @@ -508,7 +511,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): _paragraph.save() update_document_char_length(self.data.get('document_id')) - return self.one(), instance + return self.one(), instance, self.data.get('dataset_id') def get_problem_list(self): ProblemParagraphMapping(ProblemParagraphMapping) @@ -582,7 +585,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): # 修改长度 update_document_char_length(document_id) if with_embedding: - ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id)) + model = get_embedding_model_by_dataset_id(dataset_id) + ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model) return ParagraphSerializers.Operate( data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one( with_valid=True)