diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index ed0914e55..355647032 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -45,7 +45,9 @@ class ChatMessage:
                  source_id: str,
                  answer: str,
                  message_tokens: int,
-                 answer_token: int):
+                 answer_token: int,
+                 chat_model=None,
+                 chat_message=None):
         self.id = id
         self.problem = problem
         self.title = title
@@ -59,6 +61,8 @@ class ChatMessage:
         self.answer = answer
         self.message_tokens = message_tokens
         self.answer_token = answer_token
+        self.chat_model = chat_model
+        self.chat_message = chat_message
 
     def get_chat_message(self):
         return MessageManagement.get_message(self.problem, self.paragraph, self.problem)
@@ -85,10 +89,13 @@ class ChatInfo:
     def append_chat_message(self, chat_message: ChatMessage):
         self.chat_message_list.append(chat_message)
         if self.application_id is not None:
+            # Insert into the database
             event.ListenerChatMessage.record_chat_message_signal.send(
                 event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id,
                                             chat_message)
             )
+            # Asynchronously update the token counts
+            event.ListenerChatMessage.update_chat_message_token_signal.send(chat_message)
 
     def get_context_message(self):
         start_index = len(self.chat_message_list) - self.dialogue_number
@@ -176,8 +183,10 @@ class ChatMessageSerializer(serializers.Serializer):
             ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id, paragraph_id,
                         source_type,
-                        source_id, all_text, chat_model.get_num_tokens_from_messages(chat_message),
-                        chat_model.get_num_tokens(all_text)))
+                        source_id, all_text,
+                        0,
+                        0,
+                        chat_message=chat_message, chat_model=chat_model))
         # Reset the cache
         chat_cache.set(chat_id, chat_info, timeout=60 * 30)
diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py
index 00953487f..1d176d64d 100644
--- a/apps/application/serializers/chat_serializers.py
+++ b/apps/application/serializers/chat_serializers.py
@@ -61,7 +61,7 @@ class ChatSerializers(serializers.Serializer):
         query_dict = {'application_id': self.data.get("application_id"), 'create_time__gte': end_time}
         if 'abstract' in self.data and self.data.get('abstract') is not None:
             query_dict['abstract'] = self.data.get('abstract')
-        return QuerySet(Chat).filter(**query_dict)
+        return QuerySet(Chat).filter(**query_dict).order_by("-create_time")
 
     def list(self, with_valid=True):
         if with_valid:
diff --git a/apps/common/event/listener_chat_message.py b/apps/common/event/listener_chat_message.py
index 8415e3a7a..26c9fe3bc 100644
--- a/apps/common/event/listener_chat_message.py
+++ b/apps/common/event/listener_chat_message.py
@@ -6,6 +6,7 @@
 @date:2023/10/20 14:01
 @desc:
 """
+import logging
 
 from blinker import signal
 from django.db.models import QuerySet
@@ -25,9 +26,9 @@ class RecordChatMessageArgs:
 
 class ListenerChatMessage:
     record_chat_message_signal = signal("record_chat_message")
+    update_chat_message_token_signal = signal("update_chat_message_token")
 
     @staticmethod
-    @poxy
     def record_chat_message(args: RecordChatMessageArgs):
         if not QuerySet(Chat).filter(id=args.chat_id).exists():
             Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()
@@ -49,6 +50,18 @@ class ListenerChatMessage:
         except Exception as e:
             print(e)
 
+    @staticmethod
+    @poxy
+    def update_token(chat_message: ChatMessage):
+        if chat_message.chat_model is not None:
+            logging.getLogger("max_kb").info("start updating token counts")
+            message_token = chat_message.chat_model.get_num_tokens_from_messages(chat_message.chat_message)
+            answer_token = chat_message.chat_model.get_num_tokens(chat_message.answer)
+            # Update the token counts
+            QuerySet(ChatRecord).filter(id=chat_message.id).update(
+                **{'message_tokens': message_token, 'answer_tokens': answer_token})
+
     def run(self):
         # Record the conversation
         ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)
+        ListenerChatMessage.update_chat_message_token_signal.connect(self.update_token)
diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py
index 28143cab7..fb717b49c 100644
--- a/apps/common/event/listener_manage.py
+++ b/apps/common/event/listener_manage.py
@@ -109,9 +109,10 @@ class ListenerManagement:
         :param dataset_id: dataset id
         :return: None
         """
-        max_kb.info(f"Embedding dataset {dataset_id}")
+        max_kb.info(f"Start ---> embedding dataset: {dataset_id}")
         try:
             document_list = QuerySet(Document).filter(dataset_id=dataset_id)
+            max_kb.info(f"Dataset documents: {[d.name for d in document_list]}")
             for document in document_list:
                 ListenerManagement.embedding_by_document(document.id)
         except Exception as e:
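Note on the listener changes above: token counting now leaves the request path entirely. The serializer stores the chat model and raw messages on the `ChatMessage`, records zeros for the counts, and fires `update_chat_message_token_signal`; the `@poxy`-decorated `update_token` listener recomputes the counts and persists them on `ChatRecord`. A minimal standalone sketch of the blinker hand-off (illustrative names, not the MaxKB classes themselves):

```python
from blinker import signal

update_token_signal = signal("update_chat_message_token")

def update_token(chat_message):
    # The real listener recomputes message/answer tokens with the chat model
    # and writes them to ChatRecord; here we only show the hand-off.
    print(f"recount tokens for chat record {chat_message['id']}")

update_token_signal.connect(update_token)

# The serializer side only sends; the (possibly async) listener does the work.
update_token_signal.send({'id': 'record-1'})
```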
diff --git a/apps/common/util/common.py b/apps/common/util/common.py
index d9aff2571..52d90ec85 100644
--- a/apps/common/util/common.py
+++ b/apps/common/util/common.py
@@ -43,3 +43,14 @@ def get_exec_method(clazz_: str, method_: str):
     package = ".".join([clazz_split[index] for index in range(len(clazz_split) - 1)])
     package_model = importlib.import_module(package)
     return getattr(getattr(package_model, clazz_name), method_)
+
+
+def post(post_function):
+    def inner(func):
+        def run(*args, **kwargs):
+            result = func(*args, **kwargs)
+            return post_function(*result)
+
+        return run
+
+    return inner
diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py
index 9b9bd8a91..fad4a1a45 100644
--- a/apps/dataset/serializers/dataset_serializers.py
+++ b/apps/dataset/serializers/dataset_serializers.py
@@ -23,6 +23,7 @@ from common.db.sql_execute import select_list
 from common.event.listener_manage import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
+from common.util.common import post
 from common.util.file_util import get_file_content
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
@@ -207,6 +208,13 @@ class DataSetSerializers(serializers.ModelSerializer):
             super().is_valid(raise_exception=True)
             return True
 
+        @staticmethod
+        def post_embedding_dataset(document_list, dataset_id):
+            # Send the embedding event
+            ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
+            return document_list
+
+        @post(post_function=post_embedding_dataset)
         @transaction.atomic
         def save(self, user: User):
             dataset_id = uuid.uuid1()
@@ -234,11 +242,11 @@ class DataSetSerializers(serializers.ModelSerializer):
             QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
             # Bulk insert problems
             QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
-            # Send the embedding event
-            ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
+            # Response data
             return {**DataSetSerializers(dataset).data,
-                    'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
+                    'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(
+                        with_valid=True)}, dataset_id
 
         @staticmethod
         def get_response_body_api():
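The new `post` helper is the hinge of this refactor: the wrapped function must return a tuple, the tuple is unpacked into `post_function` after the wrapped call returns, and `post_function`'s return value is what the caller finally sees. That is why `save` now returns `({...}, dataset_id)` instead of sending the signal itself. A standalone sketch of that contract (illustrative names, not MaxKB code):

```python
def post(post_function):
    def inner(func):
        def run(*args, **kwargs):
            result = func(*args, **kwargs)
            # result must be a tuple; it is splatted into the post hook
            return post_function(*result)
        return run
    return inner

def send_embedding_event(response, dataset_id):
    # Stand-in for ListenerManagement.embedding_by_dataset_signal.send(...)
    print(f"embedding dataset {dataset_id}")
    return response

@post(post_function=send_embedding_event)
def save_dataset():
    response = {'id': 'ds-1'}
    # The second element exists only to feed the post hook.
    return response, 'ds-1'

assert save_dataset() == {'id': 'ds-1'}
```

Ordering matters here: because `@post` sits above `@transaction.atomic`, the hook runs only after the atomic block exits, so (assuming no enclosing transaction) the embedding signal fires against committed rows that the embedding worker can actually see.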
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index f687f1516..7e6f5dd04 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -21,6 +21,7 @@ from common.db.search import native_search, native_page_search
 from common.event.listener_manage import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
+from common.util.common import post
 from common.util.file_util import get_file_content
 from common.util.split_model import SplitModel, get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem
@@ -207,7 +208,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 raise AppApiException(10000, "Dataset id does not exist")
             return True
 
-        def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs):
+        @staticmethod
+        def post_embedding(result, document_id):
+            ListenerManagement.embedding_by_document_signal.send(document_id)
+            return result
+
+        @post(post_function=post_embedding)
+        @transaction.atomic
+        def save(self, instance: Dict, with_valid=False, **kwargs):
             if with_valid:
                 DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
                 self.is_valid(raise_exception=True)
@@ -222,11 +230,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
             # Bulk insert problems
             QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
-            if with_embedding:
-                ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
+            document_id = str(document_model.id)
             return DocumentSerializers.Operate(
-                data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
-                with_valid=True)
+                data={'dataset_id': dataset_id, 'document_id': document_id}).one(
+                with_valid=True), document_id
 
         @staticmethod
         def get_document_paragraph_model(dataset_id, instance: Dict):
@@ -333,14 +340,43 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
         def get_request_body_api():
             return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
 
+        @staticmethod
+        def post_embedding(document_list):
+            for document_dict in document_list:
+                ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
+            return document_list
+
+        @post(post_function=post_embedding)
+        @transaction.atomic
         def batch_save(self, instance_list: List[Dict], with_valid=True):
             if with_valid:
                 self.is_valid(raise_exception=True)
                 DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
-            create_data = {'dataset_id': self.data.get("dataset_id")}
-            return [DocumentSerializers.Create(data=create_data).save(instance,
-                                                                      with_valid=True)
-                    for instance in instance_list]
+            dataset_id = self.data.get("dataset_id")
+            document_model_list = []
+            paragraph_model_list = []
+            problem_model_list = []
+            # Assemble document, paragraph and problem models
+            for document in instance_list:
+                document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
+                                                                                                        document)
+                document_model_list.append(document_paragraph_dict_model.get('document'))
+                for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
+                    paragraph_model_list.append(paragraph)
+                for problem in document_paragraph_dict_model.get('problem_model_list'):
+                    problem_model_list.append(problem)
+
+            # Bulk insert documents
+            QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
+            # Bulk insert paragraphs
+            QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
+            # Bulk insert problems
+            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
+            # Query the documents back
+            query_set = QuerySet(model=Document)
+            query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
+            return native_search(query_set, select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')),
+                                 with_search_one=False),
 
 
 def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
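One subtlety in `batch_save` above: the trailing comma on the final `return` is load-bearing. It turns the `native_search(...)` result into a one-element tuple, which is exactly what `post_function(*result)` needs so that `post_embedding(document_list)` receives a single argument. An illustrative check:

```python
# Shape of what batch_save now returns (illustrative data):
document_list = [{'id': 'doc-1'}, {'id': 'doc-2'}]
result = document_list,  # note the comma: a one-element tuple
assert isinstance(result, tuple) and len(result) == 1
# so post_function(*result) == post_embedding(document_list)
```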
diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py
index 28dbeec04..74491c7cb 100644
--- a/apps/dataset/serializers/paragraph_serializers.py
+++ b/apps/dataset/serializers/paragraph_serializers.py
@@ -19,6 +19,7 @@ from common.db.search import page_search
 from common.event.listener_manage import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
+from common.util.common import post
 from dataset.models import Paragraph, Problem, Document
 from dataset.serializers.common_serializers import update_document_char_length
 from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
@@ -82,6 +83,17 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
             if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
                 raise AppApiException(500, "Paragraph id does not exist")
 
+        @staticmethod
+        def post_embedding(paragraph, instance):
+            if 'is_active' in instance and instance.get('is_active') is not None:
+                s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
+                    'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
+                s.send(paragraph.get('id'))
+            else:
+                ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
+            return paragraph
+
+        @post(post_embedding)
         @transaction.atomic
         def edit(self, instance: Dict):
             self.is_valid()
@@ -125,11 +137,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
                 _paragraph.save()
             update_document_char_length(self.data.get('document_id'))
-            if 'is_active' in instance and instance.get('is_active') is not None:
-                s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
-                    'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
-                s.send(self.data.get('paragraph_id'))
-            return self.one()
+            return self.one(), instance
 
         def get_problem_list(self):
             return [ProblemSerializer(problem).data for problem in
diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py
index 9113bb52d..b1dda12a8 100644
--- a/apps/embedding/vector/base_vector.py
+++ b/apps/embedding/vector/base_vector.py
@@ -68,17 +68,12 @@ class BaseVectorStore(ABC):
         :param trample_num: number of downvotes
         :return: bool
         """
-        # Acquire the lock
-        lock.acquire()
-        try:
-            if embedding is None:
-                embedding = EmbeddingModel.get_embedding_model()
-            self.save_pre_handler()
-            self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num,
-                       trample_num, embedding)
-        finally:
-            # Release the lock
-            lock.release()
+
+        if embedding is None:
+            embedding = EmbeddingModel.get_embedding_model()
+        self.save_pre_handler()
+        self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num,
+                   trample_num, embedding)
 
     def batch_save(self, data_list: List[Dict], embedding=None):
         # Acquire the lock
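In the paragraph hunk, `edit` now returns `(self.one(), instance)` so the `post_embedding` hook can pick a signal based on the raw patch payload: an explicit `is_active` flag toggles the paragraph's vectors, while any other edit triggers a re-embedding. A sketch of the dispatch rule it applies (with string stand-ins for the MaxKB signals):

```python
def dispatch(paragraph, instance):
    # Mirrors post_embedding above: an explicit is_active flag enables or
    # disables the vectors; any other edit re-embeds the paragraph.
    if 'is_active' in instance and instance.get('is_active') is not None:
        return 'enable' if instance['is_active'] else 'disable'
    return 're-embed'

assert dispatch({'id': 'p1'}, {'is_active': False}) == 'disable'
assert dispatch({'id': 'p1'}, {'content': 'new text'}) == 're-embed'
```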
diff --git a/install_model.py b/install_model.py
index d9395bc5d..22394b78d 100644
--- a/install_model.py
+++ b/install_model.py
@@ -16,35 +16,35 @@ prefix_dir = "/opt/maxkb/model"
 model_config = [
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'gpt2'
         },
         'download_function': GPT2TokenizerFast.from_pretrained
     },
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'gpt2-medium'
         },
         'download_function': GPT2TokenizerFast.from_pretrained
     },
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'gpt2-large'
         },
         'download_function': GPT2TokenizerFast.from_pretrained
     },
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'gpt2-xl'
         },
         'download_function': GPT2TokenizerFast.from_pretrained
     },
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'distilgpt2'
         },
         'download_function': GPT2TokenizerFast.from_pretrained
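Each `model_config` entry pairs `download_params` with a `download_function`, and the `base` → `base/hub` change presumably aligns the download location with the hub-style cache path the runtime expects under `/opt/maxkb/model`. The loop that consumes the list is outside this hunk, so the following is an assumption about the surrounding script rather than code from this PR:

```python
# Assumed consumption loop (not shown in this hunk): apply each download
# function to its params so every tokenizer lands in its cache_dir.
for entry in model_config:
    entry['download_function'](**entry['download_params'])
# e.g. GPT2TokenizerFast.from_pretrained(
#          pretrained_model_name_or_path='gpt2',
#          cache_dir='/opt/maxkb/model/base/hub')
```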