From 7c5957e0a3dbbdd8fea15f9d2a81f25f36345be4 Mon Sep 17 00:00:00 2001 From: zhangshaohu Date: Wed, 21 Aug 2024 14:46:11 +0800 Subject: [PATCH] feat: separate tasks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/application_serializers.py | 2 +- .../serializers/chat_serializers.py | 8 +- apps/application/task/__init__.py | 0 apps/common/config/embedding_config.py | 25 +- apps/common/event/__init__.py | 1 - apps/common/event/listener_manage.py | 94 ++------ apps/common/job/client_access_num_job.py | 18 +- apps/common/lock/base_lock.py | 20 ++ apps/common/lock/impl/file_lock.py | 77 ++++++ apps/common/management/__init__.py | 0 apps/common/management/commands/__init__.py | 0 apps/common/management/commands/celery.py | 44 ++++ apps/common/management/commands/gunicorn.py | 47 ---- apps/common/management/commands/restart.py | 6 + .../management/commands/services/__init__.py | 0 .../management/commands/services/command.py | 130 ++++++++++ .../management/commands/services/hands.py | 28 +++ .../commands/services/services/__init__.py | 3 + .../commands/services/services/base.py | 207 ++++++++++++++++ .../commands/services/services/celery_base.py | 43 ++++ .../services/services/celery_default.py | 10 + .../commands/services/services/gunicorn.py | 36 +++ .../commands/services/services/local_model.py | 44 ++++ .../management/commands/services/utils.py | 140 +++++++++++ apps/common/management/commands/start.py | 6 + apps/common/management/commands/status.py | 6 + apps/common/management/commands/stop.py | 6 + apps/common/task/__init__.py | 0 apps/common/util/test.py | 4 +- .../dataset/serializers/common_serializers.py | 14 ++ .../serializers/dataset_serializers.py | 43 +--- .../serializers/document_serializers.py | 66 ++---- .../serializers/paragraph_serializers.py | 79 +++---- .../serializers/problem_serializers.py | 12 +- apps/dataset/task/__init__.py | 9 + apps/dataset/task/sync.py | 42 ++++ apps/dataset/task/tools.py | 62 +++++ apps/embedding/task/__init__.py | 1 + apps/embedding/task/embedding.py | 218 +++++++++++++++++ apps/embedding/vector/base_vector.py | 4 +- apps/embedding/vector/pg_vector.py | 8 +- apps/function_lib/task/__init__.py | 0 apps/ops/__init__.py | 9 + apps/ops/celery/__init__.py | 33 +++ apps/ops/celery/const.py | 4 + apps/ops/celery/decorator.py | 104 ++++++++ apps/ops/celery/heatbeat.py | 25 ++ apps/ops/celery/logger.py | 223 ++++++++++++++++++ apps/ops/celery/signal_handler.py | 63 +++++ apps/ops/celery/utils.py | 68 ++++++ apps/setting/models_provider/__init__.py | 8 +- .../local_model_provider/model/embedding.py | 47 +++- .../serializers/model_apply_serializers.py | 53 +++++ apps/setting/urls.py | 9 + apps/setting/views/__init__.py | 1 + apps/setting/views/model_apply.py | 38 +++ apps/smartdoc/settings/__init__.py | 1 + apps/smartdoc/settings/base.py | 4 +- apps/smartdoc/settings/lib.py | 40 ++++ apps/smartdoc/settings/logging.py | 8 + apps/smartdoc/wsgi.py | 1 - apps/users/apps.py | 2 + apps/users/serializers/user_serializers.py | 10 +- apps/users/task/__init__.py | 0 main.py | 47 +++- pyproject.toml | 6 +- 66 files changed, 2074 insertions(+), 293 deletions(-) create mode 100644 apps/application/task/__init__.py create mode 100644 apps/common/lock/base_lock.py create mode 100644 apps/common/lock/impl/file_lock.py create mode 100644 apps/common/management/__init__.py create mode 100644 apps/common/management/commands/__init__.py create mode 100644
apps/common/management/commands/celery.py delete mode 100644 apps/common/management/commands/gunicorn.py create mode 100644 apps/common/management/commands/restart.py create mode 100644 apps/common/management/commands/services/__init__.py create mode 100644 apps/common/management/commands/services/command.py create mode 100644 apps/common/management/commands/services/hands.py create mode 100644 apps/common/management/commands/services/services/__init__.py create mode 100644 apps/common/management/commands/services/services/base.py create mode 100644 apps/common/management/commands/services/services/celery_base.py create mode 100644 apps/common/management/commands/services/services/celery_default.py create mode 100644 apps/common/management/commands/services/services/gunicorn.py create mode 100644 apps/common/management/commands/services/services/local_model.py create mode 100644 apps/common/management/commands/services/utils.py create mode 100644 apps/common/management/commands/start.py create mode 100644 apps/common/management/commands/status.py create mode 100644 apps/common/management/commands/stop.py create mode 100644 apps/common/task/__init__.py create mode 100644 apps/dataset/task/__init__.py create mode 100644 apps/dataset/task/sync.py create mode 100644 apps/dataset/task/tools.py create mode 100644 apps/embedding/task/__init__.py create mode 100644 apps/embedding/task/embedding.py create mode 100644 apps/function_lib/task/__init__.py create mode 100644 apps/ops/__init__.py create mode 100644 apps/ops/celery/__init__.py create mode 100644 apps/ops/celery/const.py create mode 100644 apps/ops/celery/decorator.py create mode 100644 apps/ops/celery/heatbeat.py create mode 100644 apps/ops/celery/logger.py create mode 100644 apps/ops/celery/signal_handler.py create mode 100644 apps/ops/celery/utils.py create mode 100644 apps/setting/serializers/model_apply_serializers.py create mode 100644 apps/setting/views/model_apply.py create mode 100644 apps/smartdoc/settings/lib.py create mode 100644 apps/users/task/__init__.py diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 909c4a627..9586cb20a 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -14,6 +14,7 @@ import uuid from functools import reduce from typing import Dict, List +from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.core import cache, validators from django.core import signing @@ -46,7 +47,6 @@ from setting.models.model_management import Model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR -from django.conf import settings chat_cache = cache.caches['chat_cache'] diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 0255f1c6c..d56e062da 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -37,8 +37,10 @@ from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping -from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id +from 
dataset.serializers.common_serializers import get_embedding_model_by_dataset_id, \ + get_embedding_model_id_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers +from embedding.task import embedding_by_paragraph from smartdoc.conf import PROJECT_DIR chat_cache = caches['chat_cache'] @@ -516,9 +518,9 @@ class ChatRecordSerializer(serializers.Serializer): @staticmethod def post_embedding_paragraph(chat_record, paragraph_id, dataset_id): - model = get_embedding_model_by_dataset_id(dataset_id) + model_id = get_embedding_model_id_by_dataset_id(dataset_id) # 发送向量化事件 - ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model) + embedding_by_paragraph(paragraph_id, model_id) return chat_record @post(post_function=post_embedding_paragraph) diff --git a/apps/application/task/__init__.py b/apps/application/task/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/common/config/embedding_config.py b/apps/common/config/embedding_config.py index 5938797c3..a6e9ab9aa 100644 --- a/apps/common/config/embedding_config.py +++ b/apps/common/config/embedding_config.py @@ -6,10 +6,13 @@ @date:2023/10/23 16:03 @desc: """ +import threading import time from common.cache.mem_cache import MemCache +lock = threading.Lock() + class ModelManage: cache = MemCache('model', {}) @@ -17,15 +20,21 @@ class ModelManage: @staticmethod def get_model(_id, get_model): - model_instance = ModelManage.cache.get(_id) - if model_instance is None or not model_instance.is_cache_model(): - model_instance = get_model(_id) - ModelManage.cache.set(_id, model_instance, timeout=60 * 30) + # 获取锁 + lock.acquire() + try: + model_instance = ModelManage.cache.get(_id) + if model_instance is None or not model_instance.is_cache_model(): + model_instance = get_model(_id) + ModelManage.cache.set(_id, model_instance, timeout=60 * 30) + return model_instance + # 续期 + ModelManage.cache.touch(_id, timeout=60 * 30) + ModelManage.clear_timeout_cache() return model_instance - # 续期 - ModelManage.cache.touch(_id, timeout=60 * 30) - ModelManage.clear_timeout_cache() - return model_instance + finally: + # 释放锁 + lock.release() @staticmethod def clear_timeout_cache(): diff --git a/apps/common/event/__init__.py b/apps/common/event/__init__.py index 462d8da79..bd4d36359 100644 --- a/apps/common/event/__init__.py +++ b/apps/common/event/__init__.py @@ -12,7 +12,6 @@ from .listener_manage import * def run(): - listener_manage.ListenerManagement().run() QuerySet(Document).filter(status__in=[Status.embedding, Status.queue_up]).update(**{'status': Status.error}) QuerySet(Model).filter(status=setting.models.Status.DOWNLOAD).update(status=setting.models.Status.ERROR, meta={'message': "下载程序被中断,请重试"}) diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 9c101cd40..77daabd80 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -13,22 +13,20 @@ import traceback from typing import List import django.db.models -from blinker import signal from django.db.models import QuerySet from langchain_core.embeddings import Embeddings from common.config.embedding_config import VectorStore from common.db.search import native_search, get_dynamics_model -from common.event.common import poxy, embedding_poxy +from common.event.common import embedding_poxy from common.util.file_util import get_file_content -from common.util.fork import ForkManage, Fork from common.util.lock import try_lock, un_lock from dataset.models 
import Paragraph, Status, Document, ProblemParagraphMapping -from embedding.models import SourceType +from embedding.models import SourceType, SearchMode from smartdoc.conf import PROJECT_DIR -max_kb_error = logging.getLogger("max_kb_error") -max_kb = logging.getLogger("max_kb") +max_kb_error = logging.getLogger(__file__) +max_kb = logging.getLogger(__file__) class SyncWebDatasetArgs: @@ -70,23 +68,6 @@ class UpdateEmbeddingDocumentIdArgs: class ListenerManagement: - embedding_by_problem_signal = signal("embedding_by_problem") - embedding_by_paragraph_signal = signal("embedding_by_paragraph") - embedding_by_dataset_signal = signal("embedding_by_dataset") - embedding_by_document_signal = signal("embedding_by_document") - delete_embedding_by_document_signal = signal("delete_embedding_by_document") - delete_embedding_by_document_list_signal = signal("delete_embedding_by_document_list") - delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset") - delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph") - delete_embedding_by_source_signal = signal("delete_embedding_by_source") - enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph') - disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph') - init_embedding_model_signal = signal('init_embedding_model') - sync_web_dataset_signal = signal('sync_web_dataset') - sync_web_document_signal = signal('sync_web_document') - update_problem_signal = signal('update_problem') - delete_embedding_by_source_ids_signal = signal('delete_embedding_by_source_ids') - delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list") @staticmethod def embedding_by_problem(args, embedding_model: Embeddings): @@ -160,7 +141,6 @@ class ListenerManagement: max_kb.info(f'结束--->向量化段落:{paragraph_id}') @staticmethod - @embedding_poxy def embedding_by_document(document_id, embedding_model: Embeddings): """ 向量化文档 @@ -227,7 +207,7 @@ class ListenerManagement: @staticmethod def delete_embedding_by_document_list(document_id_list: List[str]): - VectorStore.get_embedding_vector().delete_bu_document_id_list(document_id_list) + VectorStore.get_embedding_vector().delete_by_document_id_list(document_id_list) @staticmethod def delete_embedding_by_dataset(dataset_id): @@ -249,25 +229,6 @@ class ListenerManagement: def enable_embedding_by_paragraph(paragraph_id): VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True}) - @staticmethod - @poxy - def sync_web_document(args: SyncWebDocumentArgs): - for source_url in args.source_url_list: - result = Fork(base_fork_url=source_url, selector_list=args.selector.split(' ')).fork() - args.handler(source_url, args.selector, result) - - @staticmethod - @poxy - def sync_web_dataset(args: SyncWebDatasetArgs): - if try_lock('sync_web_dataset' + args.lock_key): - try: - ForkManage(args.url, args.selector.split(" ") if args.selector is not None else []).fork(2, set(), - args.handler) - except Exception as e: - logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') - finally: - un_lock('sync_web_dataset' + args.lock_key) - @staticmethod def update_problem(args: UpdateProblemArgs): problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id) @@ -281,6 +242,9 @@ class ListenerManagement: VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, {'dataset_id': args.target_dataset_id}) else: + # 删除向量数据 + 
ListenerManagement.delete_embedding_by_paragraph_ids(args.paragraph_id_list) + # 向量数据 ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list, embedding_model=args.target_embedding_model) @@ -306,38 +270,10 @@ class ListenerManagement: def delete_embedding_by_dataset_id_list(source_ids: List[str]): VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids) - def run(self): - # 添加向量 根据问题id - ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem) - # 添加向量 根据段落id - ListenerManagement.embedding_by_paragraph_signal.connect(self.embedding_by_paragraph) - # 添加向量 根据知识库id - ListenerManagement.embedding_by_dataset_signal.connect( - self.embedding_by_dataset) - # 添加向量 根据文档id - ListenerManagement.embedding_by_document_signal.connect( - self.embedding_by_document) - # 删除 向量 根据文档 - ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document) - # 删除 向量 根据文档id列表 - ListenerManagement.delete_embedding_by_document_list_signal.connect(self.delete_embedding_by_document_list) - # 删除 向量 根据知识库id - ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset) - # 删除向量 根据段落id - ListenerManagement.delete_embedding_by_paragraph_signal.connect( - self.delete_embedding_by_paragraph) - # 删除向量 根据资源id - ListenerManagement.delete_embedding_by_source_signal.connect(self.delete_embedding_by_source) - # 禁用段落 - ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph) - # 启动段落向量 - ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph) - - # 同步web站点知识库 - ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset) - # 同步web站点 文档 - ListenerManagement.sync_web_document_signal.connect(self.sync_web_document) - # 更新问题向量 - ListenerManagement.update_problem_signal.connect(self.update_problem) - ListenerManagement.delete_embedding_by_source_ids_signal.connect(self.delete_embedding_by_source_ids) - ListenerManagement.delete_embedding_by_dataset_id_list_signal.connect(self.delete_embedding_by_dataset_id_list) + @staticmethod + def hit_test(query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, + search_mode: SearchMode, + embedding: Embeddings): + return VectorStore.get_embedding_vector().hit_test(query_text, dataset_id, exclude_document_id_list, top_number, + similarity, search_mode, embedding) diff --git a/apps/common/job/client_access_num_job.py b/apps/common/job/client_access_num_job.py index 4c03fd210..9d9105454 100644 --- a/apps/common/job/client_access_num_job.py +++ b/apps/common/job/client_access_num_job.py @@ -13,9 +13,11 @@ from django.db.models import QuerySet from django_apscheduler.jobstores import DjangoJobStore from application.models.api_key_model import ApplicationPublicAccessClient +from common.lock.impl.file_lock import FileLock scheduler = BackgroundScheduler() scheduler.add_jobstore(DjangoJobStore(), "default") +lock = FileLock() def client_access_num_reset_job(): @@ -25,9 +27,13 @@ def client_access_num_reset_job(): def run(): - scheduler.start() - access_num_reset = scheduler.get_job(job_id='access_num_reset') - if access_num_reset is not None: - access_num_reset.remove() - scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0', - id='access_num_reset') + if lock.try_lock('client_access_num_reset_job', 30 * 30): + try: + scheduler.start() + access_num_reset = scheduler.get_job(job_id='access_num_reset') 
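+            # (annotation, assumed rationale: the try_lock/un_lock pair around this
+            # block is what keeps several server processes, now possible under the
+            # new services supervisor, from each starting an APScheduler; 30 * 30
+            # gives a 900-second lock timeout. Removing a stale 'access_num_reset'
+            # job before re-adding it makes the registration idempotent across restarts.)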
+ if access_num_reset is not None: + access_num_reset.remove() + scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0', + id='access_num_reset') + finally: + lock.un_lock('client_access_num_reset_job') diff --git a/apps/common/lock/base_lock.py b/apps/common/lock/base_lock.py new file mode 100644 index 000000000..2ca5b21da --- /dev/null +++ b/apps/common/lock/base_lock.py @@ -0,0 +1,20 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_lock.py + @date:2024/8/20 10:33 + @desc: +""" + +from abc import ABC, abstractmethod + + +class BaseLock(ABC): + @abstractmethod + def try_lock(self, key, timeout): + pass + + @abstractmethod + def un_lock(self, key): + pass diff --git a/apps/common/lock/impl/file_lock.py b/apps/common/lock/impl/file_lock.py new file mode 100644 index 000000000..f8ea6396c --- /dev/null +++ b/apps/common/lock/impl/file_lock.py @@ -0,0 +1,77 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: file_lock.py + @date:2024/8/20 10:48 + @desc: +""" +import errno +import hashlib +import os +import time + +import six + +from common.lock.base_lock import BaseLock +from smartdoc.const import PROJECT_DIR + + +def key_to_lock_name(key): + """ + Combine part of a key with its hash to prevent very long filenames + """ + MAX_LENGTH = 50 + key_hash = hashlib.md5(six.b(key)).hexdigest() + lock_name = key[:MAX_LENGTH - len(key_hash) - 1] + '_' + key_hash + return lock_name + + +class FileLock(BaseLock): + """ + File locking backend. + """ + + def __init__(self, settings=None): + if settings is None: + settings = {} + self.location = settings.get('location') + if self.location is None: + self.location = os.path.join(PROJECT_DIR, 'data', 'lock') + try: + os.makedirs(self.location) + except OSError as error: + # Directory exists? 
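+            # (annotation: EAFP, attempt the mkdir and tolerate EEXIST rather than
+            # checking first; on Python 3 this whole try/except could equivalently
+            # be written as os.makedirs(self.location, exist_ok=True).)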
+ if error.errno != errno.EEXIST: + # Re-raise unexpected OSError + raise + + def _get_lock_path(self, key): + lock_name = key_to_lock_name(key) + return os.path.join(self.location, lock_name) + + def try_lock(self, key, timeout): + lock_path = self._get_lock_path(key) + try: + # 创建锁文件,如果没创建成功则拿不到 + fd = os.open(lock_path, os.O_CREAT | os.O_EXCL) + except OSError as error: + if error.errno == errno.EEXIST: + # File already exists, check its modification time + mtime = os.path.getmtime(lock_path) + ttl = mtime + timeout - time.time() + if ttl > 0: + return False + else: + # 如果超时时间已到,直接上锁成功继续执行 + os.utime(lock_path, None) + return True + else: + return False + else: + os.close(fd) + return True + + def un_lock(self, key): + lock_path = self._get_lock_path(key) + os.remove(lock_path) diff --git a/apps/common/management/__init__.py b/apps/common/management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/common/management/commands/__init__.py b/apps/common/management/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/common/management/commands/celery.py b/apps/common/management/commands/celery.py new file mode 100644 index 000000000..95a21eb76 --- /dev/null +++ b/apps/common/management/commands/celery.py @@ -0,0 +1,44 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: celery.py + @date:2024/8/19 11:57 + @desc: +""" +import os +import subprocess + +from django.core.management.base import BaseCommand + +from smartdoc.const import BASE_DIR + + +class Command(BaseCommand): + help = 'celery' + + def add_arguments(self, parser): + parser.add_argument( + 'service', nargs='+', type=str, choices=("celery", "model"), help='Service', + ) + + def handle(self, *args, **options): + service = options.get('service') + os.environ.setdefault('CELERY_NAME', ','.join(service)) + server_hostname = os.environ.get("SERVER_HOSTNAME") + if not server_hostname: + server_hostname = '%h' + cmd = [ + 'celery', + '-A', 'ops', + 'worker', + '-P', 'threads', + '-l', 'info', + '-c', '10', + '-Q', ','.join(service), + '--heartbeat-interval', '10', + '-n', f'{",".join(service)}@{server_hostname}', + '--without-mingle', + ] + kwargs = {'cwd': BASE_DIR} + subprocess.run(cmd, **kwargs) diff --git a/apps/common/management/commands/gunicorn.py b/apps/common/management/commands/gunicorn.py deleted file mode 100644 index 436a604b4..000000000 --- a/apps/common/management/commands/gunicorn.py +++ /dev/null @@ -1,47 +0,0 @@ -# coding=utf-8 -""" - @project: MaxKB - @Author:虎 - @file: gunicorn.py - @date:2024/7/19 17:43 - @desc: -""" -import subprocess - -from django.core.management.base import BaseCommand - -from smartdoc.const import BASE_DIR - - -class Command(BaseCommand): - help = 'My custom command' - - # 参数设定 - def add_arguments(self, parser): - parser.add_argument('-b', nargs='+', type=str, help="端口:0.0.0.0:8080") # 0.0.0.0:8080 - parser.add_argument('-k', nargs='?', type=str, - help="workers处理器:gevent") # uvicorn.workers.UvicornWorker - parser.add_argument('-w', type=str, help='worker 数量') # 进程数量 - parser.add_argument('--threads', type=str, help='线程数量') # 线程数量 - parser.add_argument('--worker-connections', type=str, help="每个线程的协程数量") # 10240 - parser.add_argument('--max-requests', type=str, help="最大请求") # 10240 - parser.add_argument('--max-requests-jitter', type=str) - parser.add_argument('--access-logformat', type=str) # %(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s - - def handle(self, *args, **options): - log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' 
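# (annotation: the hard-coded gunicorn invocation deleted here is not lost; the new
#  GunicornService.cmd in apps/common/management/commands/services/services/gunicorn.py
#  further down rebuilds it almost flag for flag, with the bind address and worker
#  count made configurable and --reload added under DEBUG.)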
- cmd = [ - 'gunicorn', 'smartdoc.wsgi:application', - '-b', options.get('b') if options.get('b') is not None else '0.0.0.0:8080', - '-k', options.get('k') if options.get('k') is not None else 'gthread', - '--threads', options.get('threads') if options.get('threads') is not None else '200', - '-w', options.get('w') if options.get('w') is not None else '1', - '--max-requests', options.get('max_requests') if options.get('max_requests') is not None else '10240', - '--max-requests-jitter', - options.get('max_requests_jitter') if options.get('max_requests_jitter') is not None else '2048', - '--access-logformat', - options.get('access_logformat') if options.get('access_logformat') is not None else log_format, - '--access-logfile', '-' - ] - kwargs = {'cwd': BASE_DIR} - subprocess.run(cmd, **kwargs) diff --git a/apps/common/management/commands/restart.py b/apps/common/management/commands/restart.py new file mode 100644 index 000000000..57285f9c9 --- /dev/null +++ b/apps/common/management/commands/restart.py @@ -0,0 +1,6 @@ +from .services.command import BaseActionCommand, Action + + +class Command(BaseActionCommand): + help = 'Restart services' + action = Action.restart.value diff --git a/apps/common/management/commands/services/__init__.py b/apps/common/management/commands/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/common/management/commands/services/command.py b/apps/common/management/commands/services/command.py new file mode 100644 index 000000000..960d019ec --- /dev/null +++ b/apps/common/management/commands/services/command.py @@ -0,0 +1,130 @@ +from django.core.management.base import BaseCommand +from django.db.models import TextChoices + +from .hands import * +from .utils import ServicesUtil + + +class Services(TextChoices): + gunicorn = 'gunicorn', 'gunicorn' + celery_default = 'celery_default', 'celery_default' + local_model = 'local_model', 'local_model' + web = 'web', 'web' + celery = 'celery', 'celery' + celery_model = 'celery_model', 'celery_model' + task = 'task', 'task' + all = 'all', 'all' + + @classmethod + def get_service_object_class(cls, name): + from . 
import services + services_map = { + cls.gunicorn.value: services.GunicornService, + cls.celery_default: services.CeleryDefaultService, + cls.local_model: services.GunicornLocalModelService + } + return services_map.get(name) + + @classmethod + def web_services(cls): + return [cls.gunicorn, cls.local_model] + + @classmethod + def celery_services(cls): + return [cls.celery_default, cls.celery_model] + + @classmethod + def task_services(cls): + return cls.celery_services() + + @classmethod + def all_services(cls): + return cls.web_services() + cls.task_services() + + @classmethod + def export_services_values(cls): + return [cls.all.value, cls.web.value, cls.task.value] + [s.value for s in cls.all_services()] + + @classmethod + def get_service_objects(cls, service_names, **kwargs): + services = set() + for name in service_names: + method_name = f'{name}_services' + if hasattr(cls, method_name): + _services = getattr(cls, method_name)() + elif hasattr(cls, name): + _services = [getattr(cls, name)] + else: + continue + services.update(set(_services)) + + service_objects = [] + for s in services: + service_class = cls.get_service_object_class(s.value) + if not service_class: + continue + kwargs.update({ + 'name': s.value + }) + service_object = service_class(**kwargs) + service_objects.append(service_object) + return service_objects + + +class Action(TextChoices): + start = 'start', 'start' + status = 'status', 'status' + stop = 'stop', 'stop' + restart = 'restart', 'restart' + + +class BaseActionCommand(BaseCommand): + help = 'Service Base Command' + + action = None + util = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def add_arguments(self, parser): + parser.add_argument( + 'services', nargs='+', choices=Services.export_services_values(), help='Service', + ) + parser.add_argument('-d', '--daemon', nargs="?", const=True) + parser.add_argument('-w', '--worker', type=int, nargs="?", default=4) + parser.add_argument('-f', '--force', nargs="?", const=True) + + def initial_util(self, *args, **options): + service_names = options.get('services') + service_kwargs = { + 'worker_gunicorn': options.get('worker') + } + services = Services.get_service_objects(service_names=service_names, **service_kwargs) + + kwargs = { + 'services': services, + 'run_daemon': options.get('daemon', False), + 'stop_daemon': self.action == Action.stop.value and Services.all.value in service_names, + 'force_stop': options.get('force') or False, + } + self.util = ServicesUtil(**kwargs) + + def handle(self, *args, **options): + self.initial_util(*args, **options) + assert self.action in Action.values, f'The action {self.action} is not in the optional list' + _handle = getattr(self, f'_handle_{self.action}', lambda: None) + _handle() + + def _handle_start(self): + self.util.start_and_watch() + os._exit(0) + + def _handle_stop(self): + self.util.stop() + + def _handle_restart(self): + self.util.restart() + + def _handle_status(self): + self.util.show_status() diff --git a/apps/common/management/commands/services/hands.py b/apps/common/management/commands/services/hands.py new file mode 100644 index 000000000..239411c8b --- /dev/null +++ b/apps/common/management/commands/services/hands.py @@ -0,0 +1,28 @@ +import logging +import os +import sys + +from django.conf import settings + +from smartdoc.const import CONFIG, PROJECT_DIR + +try: + from apps.smartdoc import const + + __version__ = const.VERSION +except ImportError as e: + print("Not found __version__: {}".format(e)) + print("Python is: 
") + logging.info(sys.executable) + __version__ = 'Unknown' + sys.exit(1) + +HTTP_HOST = CONFIG.HTTP_BIND_HOST or '127.0.0.1' +HTTP_PORT = CONFIG.HTTP_LISTEN_PORT or 8080 +DEBUG = CONFIG.DEBUG or False + +LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'logs') +APPS_DIR = os.path.join(PROJECT_DIR, 'apps') +TMP_DIR = os.path.join(PROJECT_DIR, 'tmp') +if not os.path.exists(TMP_DIR): + os.makedirs(TMP_DIR) diff --git a/apps/common/management/commands/services/services/__init__.py b/apps/common/management/commands/services/services/__init__.py new file mode 100644 index 000000000..102739206 --- /dev/null +++ b/apps/common/management/commands/services/services/__init__.py @@ -0,0 +1,3 @@ +from .celery_default import * +from .gunicorn import * +from .local_model import * \ No newline at end of file diff --git a/apps/common/management/commands/services/services/base.py b/apps/common/management/commands/services/services/base.py new file mode 100644 index 000000000..ddcb4feca --- /dev/null +++ b/apps/common/management/commands/services/services/base.py @@ -0,0 +1,207 @@ +import abc +import time +import shutil +import psutil +import datetime +import threading +import subprocess +from ..hands import * + + +class BaseService(object): + + def __init__(self, **kwargs): + self.name = kwargs['name'] + self._process = None + self.STOP_TIMEOUT = 10 + self.max_retry = 0 + self.retry = 3 + self.LOG_KEEP_DAYS = 7 + self.EXIT_EVENT = threading.Event() + + @property + @abc.abstractmethod + def cmd(self): + return [] + + @property + @abc.abstractmethod + def cwd(self): + return '' + + @property + def is_running(self): + if self.pid == 0: + return False + try: + os.kill(self.pid, 0) + except (OSError, ProcessLookupError): + return False + else: + return True + + def show_status(self): + if self.is_running: + msg = f'{self.name} is running: {self.pid}.' + else: + msg = f'{self.name} is stopped.' 
+ if DEBUG: + msg = '\033[31m{} is stopped.\033[0m\nYou can manual start it to find the error: \n' \ + ' $ cd {}\n' \ + ' $ {}'.format(self.name, self.cwd, ' '.join(self.cmd)) + + print(msg) + + # -- log -- + @property + def log_filename(self): + return f'{self.name}.log' + + @property + def log_filepath(self): + return os.path.join(LOG_DIR, self.log_filename) + + @property + def log_file(self): + return open(self.log_filepath, 'a') + + @property + def log_dir(self): + return os.path.dirname(self.log_filepath) + # -- end log -- + + # -- pid -- + @property + def pid_filepath(self): + return os.path.join(TMP_DIR, f'{self.name}.pid') + + @property + def pid(self): + if not os.path.isfile(self.pid_filepath): + return 0 + with open(self.pid_filepath) as f: + try: + pid = int(f.read().strip()) + except ValueError: + pid = 0 + return pid + + def write_pid(self): + with open(self.pid_filepath, 'w') as f: + f.write(str(self.process.pid)) + + def remove_pid(self): + if os.path.isfile(self.pid_filepath): + os.unlink(self.pid_filepath) + # -- end pid -- + + # -- process -- + @property + def process(self): + if not self._process: + try: + self._process = psutil.Process(self.pid) + except: + pass + return self._process + + # -- end process -- + + # -- action -- + def open_subprocess(self): + kwargs = {'cwd': self.cwd, 'stderr': self.log_file, 'stdout': self.log_file} + self._process = subprocess.Popen(self.cmd, **kwargs) + + def start(self): + if self.is_running: + self.show_status() + return + self.remove_pid() + self.open_subprocess() + self.write_pid() + self.start_other() + + def start_other(self): + pass + + def stop(self, force=False): + if not self.is_running: + self.show_status() + # self.remove_pid() + return + + print(f'Stop service: {self.name}', end='') + sig = 9 if force else 15 + os.kill(self.pid, sig) + + if self.process is None: + print("\033[31m No process found\033[0m") + return + try: + self.process.wait(1) + except: + pass + + for i in range(self.STOP_TIMEOUT): + if i == self.STOP_TIMEOUT - 1: + print("\033[31m Error\033[0m") + if not self.is_running: + print("\033[32m Ok\033[0m") + self.remove_pid() + break + else: + continue + + def watch(self): + self._check() + if not self.is_running: + self._restart() + self._rotate_log() + + def _check(self): + now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + print(f"{now} Check service status: {self.name} -> ", end='') + if self.process: + try: + self.process.wait(1) # 不wait,子进程可能无法回收 + except: + pass + + if self.is_running: + print(f'running at {self.pid}') + else: + print(f'stopped at {self.pid}') + + def _restart(self): + if self.retry > self.max_retry: + logging.info("Service start failed, exit: {}".format(self.name)) + self.EXIT_EVENT.set() + return + self.retry += 1 + logging.info(f'> Find {self.name} stopped, retry {self.retry}, {self.pid}') + self.start() + + def _rotate_log(self): + now = datetime.datetime.now() + _time = now.strftime('%H:%M') + if _time != '23:59': + return + + backup_date = now.strftime('%Y-%m-%d') + backup_log_dir = os.path.join(self.log_dir, backup_date) + if not os.path.exists(backup_log_dir): + os.mkdir(backup_log_dir) + + backup_log_path = os.path.join(backup_log_dir, self.log_filename) + if os.path.isfile(self.log_filepath) and not os.path.isfile(backup_log_path): + logging.info(f'Rotate log file: {self.log_filepath} => {backup_log_path}') + shutil.copy(self.log_filepath, backup_log_path) + with open(self.log_filepath, 'w') as f: + pass + + to_delete_date = now - 
datetime.timedelta(days=self.LOG_KEEP_DAYS) + to_delete_dir = os.path.join(LOG_DIR, to_delete_date.strftime('%Y-%m-%d')) + if os.path.exists(to_delete_dir): + logging.info(f'Remove old log: {to_delete_dir}') + shutil.rmtree(to_delete_dir, ignore_errors=True) + # -- end action -- diff --git a/apps/common/management/commands/services/services/celery_base.py b/apps/common/management/commands/services/services/celery_base.py new file mode 100644 index 000000000..34fbd680e --- /dev/null +++ b/apps/common/management/commands/services/services/celery_base.py @@ -0,0 +1,43 @@ +from .base import BaseService +from ..hands import * + + +class CeleryBaseService(BaseService): + + def __init__(self, queue, num=10, **kwargs): + super().__init__(**kwargs) + self.queue = queue + self.num = num + + @property + def cmd(self): + print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize())) + + os.environ.setdefault('LC_ALL', 'C.UTF-8') + os.environ.setdefault('PYTHONOPTIMIZE', '1') + os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True') + os.environ.setdefault('PYTHONPATH', settings.APPS_DIR) + + if os.getuid() == 0: + os.environ.setdefault('C_FORCE_ROOT', '1') + server_hostname = os.environ.get("SERVER_HOSTNAME") + if not server_hostname: + server_hostname = '%h' + + cmd = [ + 'celery', + '-A', 'ops', + 'worker', + '-P', 'threads', + '-l', 'error', + '-c', str(self.num), + '-Q', self.queue, + '--heartbeat-interval', '10', + '-n', f'{self.queue}@{server_hostname}', + '--without-mingle', + ] + return cmd + + @property + def cwd(self): + return APPS_DIR diff --git a/apps/common/management/commands/services/services/celery_default.py b/apps/common/management/commands/services/services/celery_default.py new file mode 100644 index 000000000..5d3e6d7b8 --- /dev/null +++ b/apps/common/management/commands/services/services/celery_default.py @@ -0,0 +1,10 @@ +from .celery_base import CeleryBaseService + +__all__ = ['CeleryDefaultService'] + + +class CeleryDefaultService(CeleryBaseService): + + def __init__(self, **kwargs): + kwargs['queue'] = 'celery' + super().__init__(**kwargs) diff --git a/apps/common/management/commands/services/services/gunicorn.py b/apps/common/management/commands/services/services/gunicorn.py new file mode 100644 index 000000000..cc42c4f7c --- /dev/null +++ b/apps/common/management/commands/services/services/gunicorn.py @@ -0,0 +1,36 @@ +from .base import BaseService +from ..hands import * + +__all__ = ['GunicornService'] + + +class GunicornService(BaseService): + + def __init__(self, **kwargs): + self.worker = kwargs['worker_gunicorn'] + super().__init__(**kwargs) + + @property + def cmd(self): + print("\n- Start Gunicorn WSGI HTTP Server") + + log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' + bind = f'{HTTP_HOST}:{HTTP_PORT}' + cmd = [ + 'gunicorn', 'smartdoc.wsgi:application', + '-b', bind, + '-k', 'gthread', + '--threads', '200', + '-w', str(self.worker), + '--max-requests', '10240', + '--max-requests-jitter', '2048', + '--access-logformat', log_format, + '--access-logfile', '-' + ] + if DEBUG: + cmd.append('--reload') + return cmd + + @property + def cwd(self): + return APPS_DIR diff --git a/apps/common/management/commands/services/services/local_model.py b/apps/common/management/commands/services/services/local_model.py new file mode 100644 index 000000000..1e5e4bc13 --- /dev/null +++ b/apps/common/management/commands/services/services/local_model.py @@ -0,0 +1,44 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: local_model.py + 
@date:2024/8/21 13:28 + @desc: +""" +from .base import BaseService +from ..hands import * + +__all__ = ['GunicornLocalModelService'] + + +class GunicornLocalModelService(BaseService): + + def __init__(self, **kwargs): + self.worker = kwargs['worker_gunicorn'] + super().__init__(**kwargs) + + @property + def cmd(self): + print("\n- Start Gunicorn Local Model WSGI HTTP Server") + os.environ.setdefault('SERVER_NAME', 'local_model') + log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' + bind = f'127.0.0.1:5432' + cmd = [ + 'gunicorn', 'smartdoc.wsgi:application', + '-b', bind, + '-k', 'gthread', + '--threads', '200', + '-w', "1", + '--max-requests', '10240', + '--max-requests-jitter', '2048', + '--access-logformat', log_format, + '--access-logfile', '-' + ] + if DEBUG: + cmd.append('--reload') + return cmd + + @property + def cwd(self): + return APPS_DIR diff --git a/apps/common/management/commands/services/utils.py b/apps/common/management/commands/services/utils.py new file mode 100644 index 000000000..2426758b8 --- /dev/null +++ b/apps/common/management/commands/services/utils.py @@ -0,0 +1,140 @@ +import threading +import signal +import time +import daemon +from daemon import pidfile +from .hands import * +from .hands import __version__ +from .services.base import BaseService + + +class ServicesUtil(object): + + def __init__(self, services, run_daemon=False, force_stop=False, stop_daemon=False): + self._services = services + self.run_daemon = run_daemon + self.force_stop = force_stop + self.stop_daemon = stop_daemon + self.EXIT_EVENT = threading.Event() + self.check_interval = 30 + self.files_preserve_map = {} + + def restart(self): + self.stop() + time.sleep(5) + self.start_and_watch() + + def start_and_watch(self): + logging.info(time.ctime()) + logging.info(f'MaxKB version {__version__}, more see https://www.jumpserver.org') + self.start() + if self.run_daemon: + self.show_status() + with self.daemon_context: + self.watch() + else: + self.watch() + + def start(self): + for service in self._services: + service: BaseService + service.start() + self.files_preserve_map[service.name] = service.log_file + + time.sleep(1) + + def stop(self): + for service in self._services: + service: BaseService + service.stop(force=self.force_stop) + + if self.stop_daemon: + self._stop_daemon() + + # -- watch -- + def watch(self): + while not self.EXIT_EVENT.is_set(): + try: + _exit = self._watch() + if _exit: + break + time.sleep(self.check_interval) + except KeyboardInterrupt: + print('Start stop services') + break + self.clean_up() + + def _watch(self): + for service in self._services: + service: BaseService + service.watch() + if service.EXIT_EVENT.is_set(): + self.EXIT_EVENT.set() + return True + return False + # -- end watch -- + + def clean_up(self): + if not self.EXIT_EVENT.is_set(): + self.EXIT_EVENT.set() + self.stop() + + def show_status(self): + for service in self._services: + service: BaseService + service.show_status() + + # -- daemon -- + def _stop_daemon(self): + if self.daemon_pid and self.daemon_is_running: + os.kill(self.daemon_pid, 15) + self.remove_daemon_pid() + + def remove_daemon_pid(self): + if os.path.isfile(self.daemon_pid_filepath): + os.unlink(self.daemon_pid_filepath) + + @property + def daemon_pid(self): + if not os.path.isfile(self.daemon_pid_filepath): + return 0 + with open(self.daemon_pid_filepath) as f: + try: + pid = int(f.read().strip()) + except ValueError: + pid = 0 + return pid + + @property + def daemon_is_running(self): + try: + os.kill(self.daemon_pid, 0) + 
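+            # (annotation: signal 0 delivers nothing; os.kill only performs the
+            # existence and permission checks, so a live daemon falls through to
+            # the else branch and returns True.)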
except (OSError, ProcessLookupError): + return False + else: + return True + + @property + def daemon_pid_filepath(self): + return os.path.join(TMP_DIR, 'mk.pid') + + @property + def daemon_log_filepath(self): + return os.path.join(LOG_DIR, 'mk.log') + + @property + def daemon_context(self): + daemon_log_file = open(self.daemon_log_filepath, 'a') + context = daemon.DaemonContext( + pidfile=pidfile.TimeoutPIDLockFile(self.daemon_pid_filepath), + signal_map={ + signal.SIGTERM: lambda x, y: self.clean_up(), + signal.SIGHUP: 'terminate', + }, + stdout=daemon_log_file, + stderr=daemon_log_file, + files_preserve=list(self.files_preserve_map.values()), + detach_process=True, + ) + return context + # -- end daemon -- diff --git a/apps/common/management/commands/start.py b/apps/common/management/commands/start.py new file mode 100644 index 000000000..4c078a876 --- /dev/null +++ b/apps/common/management/commands/start.py @@ -0,0 +1,6 @@ +from .services.command import BaseActionCommand, Action + + +class Command(BaseActionCommand): + help = 'Start services' + action = Action.start.value diff --git a/apps/common/management/commands/status.py b/apps/common/management/commands/status.py new file mode 100644 index 000000000..36f0d3608 --- /dev/null +++ b/apps/common/management/commands/status.py @@ -0,0 +1,6 @@ +from .services.command import BaseActionCommand, Action + + +class Command(BaseActionCommand): + help = 'Show services status' + action = Action.status.value diff --git a/apps/common/management/commands/stop.py b/apps/common/management/commands/stop.py new file mode 100644 index 000000000..a79a5335c --- /dev/null +++ b/apps/common/management/commands/stop.py @@ -0,0 +1,6 @@ +from .services.command import BaseActionCommand, Action + + +class Command(BaseActionCommand): + help = 'Stop services' + action = Action.stop.value diff --git a/apps/common/task/__init__.py b/apps/common/task/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/common/util/test.py b/apps/common/util/test.py index a9536ba9c..00a267c6a 100644 --- a/apps/common/util/test.py +++ b/apps/common/util/test.py @@ -18,6 +18,7 @@ TOKEN_KEY = 'solomon_world_token' TOKEN_SALT = 'solomonwanc@gmail.com' TIME_OUT = 30 * 60 + # 加密 def encrypt(obj): value = signing.dumps(obj, key=TOKEN_KEY, salt=TOKEN_SALT) @@ -29,7 +30,6 @@ def encrypt(obj): def decrypt(src): src = signing.b64_decode(src.encode()).decode() raw = signing.loads(src, key=TOKEN_KEY, salt=TOKEN_SALT) - print(type(raw)) return raw @@ -74,5 +74,3 @@ def check_token(token): if last_token: return last_token == token return False - - diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py index 51b4afefa..8f08a2693 100644 --- a/apps/dataset/serializers/common_serializers.py +++ b/apps/dataset/serializers/common_serializers.py @@ -151,3 +151,17 @@ def get_embedding_model_by_dataset_id(dataset_id: str): def get_embedding_model_by_dataset(dataset): return ModelManage.get_model(str(dataset.embedding_mode_id), lambda _id: get_model(dataset.embedding_mode)) + + +def get_embedding_model_id_by_dataset_id(dataset_id): + dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first() + return str(dataset.embedding_mode_id) + + +def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("知识库未向量模型不一致") + if len(dataset_list) 
== 0: + raise Exception("知识库设置错误,请重新设置知识库") + return str(dataset_list[0].embedding_mode_id) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index bc00de6f7..5b8bbd811 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -27,7 +27,6 @@ from application.models import ApplicationDatasetMapping from common.config.embedding_config import VectorStore from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.sql_execute import select_list -from common.event import ListenerManagement, SyncWebDatasetArgs from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post, flat_map, valid_license @@ -37,9 +36,11 @@ from common.util.fork import ChildLink, Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \ - get_embedding_model_by_dataset_id + get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer +from dataset.task import sync_web_dataset from embedding.models import SearchMode +from embedding.task import embedding_by_dataset, delete_embedding_by_dataset from setting.models import AuthOperate from smartdoc.conf import PROJECT_DIR @@ -363,9 +364,9 @@ class DataSetSerializers(serializers.ModelSerializer): @staticmethod def post_embedding_dataset(document_list, dataset_id): - model = get_embedding_model_by_dataset_id(dataset_id) + model_id = get_embedding_model_id_by_dataset_id(dataset_id) # 发送向量化事件 - ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model) + embedding_by_dataset.delay(dataset_id, model_id) return document_list def save_qa(self, instance: Dict, with_valid=True): @@ -435,23 +436,6 @@ class DataSetSerializers(serializers.ModelSerializer): else: return parsed_url.path.split("/")[-1] - @staticmethod - def get_save_handler(dataset_id, selector): - def handler(child_link: ChildLink, response: Fork.Response): - if response.status == 200: - try: - document_name = child_link.tag.text if child_link.tag is not None and len( - child_link.tag.text.strip()) > 0 else child_link.url - paragraphs = get_split_model('web.md').parse(response.content) - DocumentSerializers.Create(data={'dataset_id': dataset_id}).save( - {'name': document_name, 'paragraphs': paragraphs, - 'meta': {'source_url': child_link.url, 'selector': selector}, - 'type': Type.web}, with_valid=True) - except Exception as e: - logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') - - return handler - def save_web(self, instance: Dict, with_valid=True): if with_valid: self.is_valid(raise_exception=True) @@ -467,9 +451,7 @@ class DataSetSerializers(serializers.ModelSerializer): 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector'), 'embedding_mode_id': instance.get('embedding_mode_id')}}) dataset.save() - ListenerManagement.sync_web_dataset_signal.send( - SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'), - self.get_save_handler(dataset_id, instance.get('selector')))) + sync_web_dataset.delay((str(dataset_id), instance.get('source_url'), 
instance.get('selector'))) return {**DataSetSerializers(dataset).data, 'document_list': []} @@ -642,9 +624,7 @@ class DataSetSerializers(serializers.ModelSerializer): """ url = dataset.meta.get('source_url') selector = dataset.meta.get('selector') if 'selector' in dataset.meta else None - ListenerManagement.sync_web_dataset_signal.send( - SyncWebDatasetArgs(str(dataset.id), url, selector, - self.get_sync_handler(dataset))) + sync_web_dataset.delay(str(dataset.id), url, selector) def complete_sync(self, dataset): """ @@ -658,7 +638,7 @@ class DataSetSerializers(serializers.ModelSerializer): # 删除段落 QuerySet(Paragraph).filter(dataset=dataset).delete() # 删除向量 - ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id')) + delete_embedding_by_dataset(self.data.get('id')) # 同步 self.replace_sync(dataset) @@ -740,16 +720,17 @@ class DataSetSerializers(serializers.ModelSerializer): QuerySet(Paragraph).filter(dataset=dataset).delete() QuerySet(Problem).filter(dataset=dataset).delete() dataset.delete() - ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id')) + delete_embedding_by_dataset(self.data.get('id')) return True def re_embedding(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - model = get_embedding_model_by_dataset_id(self.data.get('id')) + QuerySet(Document).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up}) QuerySet(Paragraph).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up}) - ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model) + embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id')) + embedding_by_dataset.delay(self.data.get('id'), embedding_model_id) def list_application(self, with_valid=True): if with_valid: diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 82531226c..a8ececbc9 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -15,6 +15,7 @@ from functools import reduce from typing import List, Dict import xlwt +from celery_once import AlreadyQueued from django.core import validators from django.db import transaction from django.db.models import QuerySet @@ -25,7 +26,6 @@ from xlwt import Utils from common.db.search import native_search, native_page_search from common.event.common import work_thread_pool -from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs, UpdateEmbeddingDatasetIdArgs from common.exception.app_exception import AppApiException from common.handle.impl.doc_split_handle import DocSplitHandle from common.handle.impl.html_split_handle import HTMLSplitHandle @@ -42,8 +42,11 @@ from common.util.fork import Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \ - get_embedding_model_by_dataset_id + get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer +from dataset.task import sync_web_document +from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \ + delete_embedding_by_document, update_embedding_dataset_id from smartdoc.conf import 
PROJECT_DIR parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()] @@ -235,17 +238,15 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): meta={}) else: document_list.update(dataset_id=target_dataset_id) - model = None + model_id = None if dataset.embedding_mode_id != target_dataset.embedding_mode_id: - model = get_embedding_model_by_dataset_id(target_dataset_id) + model_id = get_embedding_model_id_by_dataset_id(target_dataset_id) pid_list = [paragraph.id for paragraph in paragraph_list] # 修改段落信息 paragraph_list.update(dataset_id=target_dataset_id) # 修改向量信息 - ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs( - pid_list, - target_dataset_id, model)) + update_embedding_dataset_id(pid_list, target_dataset_id, model_id) @staticmethod def get_target_dataset_problem(target_dataset_id: str, @@ -377,7 +378,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): # 删除问题 QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete() # 删除向量库 - ListenerManagement.delete_embedding_by_document_signal.send(document_id) + delete_embedding_by_document(document_id) paragraphs = get_split_model('web.md').parse(result.content) document.char_length = reduce(lambda x, y: x + y, [len(p.get('content')) for p in paragraphs], @@ -398,8 +399,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): problem_paragraph_mapping_list) > 0 else None # 向量化 if with_embedding: - model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id) - ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model) + embedding_model_id = get_embedding_model_id_by_dataset_id(document.dataset_id) + embedding_by_document.delay(document_id, embedding_model_id) else: document.status = Status.error document.save() @@ -538,10 +539,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) document_id = self.data.get("document_id") - model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id')) QuerySet(Document).filter(id=document_id).update(**{'status': Status.queue_up}) QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.queue_up}) - ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model) + embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id')) + try: + embedding_by_document.delay(document_id, embedding_model_id) + except AlreadyQueued as e: + raise AppApiException(500, "任务正在执行中,请勿重复下发") @transaction.atomic def delete(self): @@ -552,7 +556,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): # 删除问题 QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete() # 删除向量库 - ListenerManagement.delete_embedding_by_document_signal.send(document_id) + delete_embedding_by_document(document_id) return True @staticmethod @@ -611,8 +615,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): @staticmethod def post_embedding(result, document_id, dataset_id): - model = get_embedding_model_by_dataset_id(dataset_id) - ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model) + model_id = get_embedding_model_id_by_dataset_id(dataset_id) + embedding_by_document.delay(document_id, model_id) return result @staticmethod @@ -660,29 +664,6 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): data={'dataset_id': dataset_id, 'document_id': document_id}).one( 
with_valid=True), document_id, dataset_id - @staticmethod - def get_sync_handler(dataset_id): - def handler(source_url: str, selector, response: Fork.Response): - if response.status == 200: - try: - paragraphs = get_split_model('web.md').parse(response.content) - # 插入 - DocumentSerializers.Create(data={'dataset_id': dataset_id}).save( - {'name': source_url[0:128], 'paragraphs': paragraphs, - 'meta': {'source_url': source_url, 'selector': selector}, - 'type': Type.web}, with_valid=True) - except Exception as e: - logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') - else: - Document(name=source_url[0:128], - dataset_id=dataset_id, - meta={'source_url': source_url, 'selector': selector}, - type=Type.web, - char_length=0, - status=Status.error).save() - - return handler - def save_web(self, instance: Dict, with_valid=True): if with_valid: DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True) @@ -690,8 +671,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): dataset_id = self.data.get('dataset_id') source_url_list = instance.get('source_url_list') selector = instance.get('selector') - args = SyncWebDocumentArgs(source_url_list, selector, self.get_sync_handler(dataset_id)) - ListenerManagement.sync_web_document_signal.send(args) + sync_web_document.delay(dataset_id, source_url_list, selector) @staticmethod def get_paragraph_model(document_model, paragraph_list: List): @@ -818,8 +798,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): @staticmethod def post_embedding(document_list, dataset_id): for document_dict in document_list: - model = get_embedding_model_by_dataset_id(dataset_id) - ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model) + model_id = get_embedding_model_id_by_dataset_id(dataset_id) + embedding_by_document(document_dict.get('id'), model_id) return document_list @post(post_function=post_embedding) @@ -887,7 +867,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): QuerySet(Paragraph).filter(document_id__in=document_id_list).delete() QuerySet(ProblemParagraphMapping).filter(document_id__in=document_id_list).delete() # 删除向量库 - ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list) + delete_embedding_by_document_list(document_id_list) return True def batch_edit_hit_handling(self, instance: Dict, with_valid=True): diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 07ea6c190..84423dd4b 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -15,16 +15,18 @@ from drf_yasg import openapi from rest_framework import serializers from common.db.search import page_search -from common.event.listener_manage import ListenerManagement, UpdateEmbeddingDocumentIdArgs from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post from common.util.field_message import ErrMessage from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \ - ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset + ProblemParagraphManage, get_embedding_model_id_by_dataset_id from dataset.serializers.problem_serializers import ProblemInstanceSerializer, 
ProblemSerializer, ProblemSerializers from embedding.models import SourceType +from embedding.task.embedding import embedding_by_problem as embedding_by_problem_task, embedding_by_problem, \ + delete_embedding_by_source, enable_embedding_by_paragraph, disable_embedding_by_paragraph, embedding_by_paragraph, \ + delete_embedding_by_paragraph, delete_embedding_by_paragraph_ids, update_embedding_document_id class ParagraphSerializer(serializers.ModelSerializer): @@ -113,7 +115,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): QuerySet(Problem).filter(id__in=[row.problem_id for row in problem_paragraph_mapping])] @transaction.atomic - def save(self, instance: Dict, with_valid=True, with_embedding=True): + def save(self, instance: Dict, with_valid=True, with_embedding=True, embedding_by_problem=None): if with_valid: self.is_valid() ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True) @@ -132,16 +134,16 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): paragraph_id=self.data.get('paragraph_id'), dataset_id=self.data.get('dataset_id')) problem_paragraph_mapping.save() - model = get_embedding_model_by_dataset_id(self.data.get('dataset_id')) + model_id = get_embedding_model_id_by_dataset_id(self.data.get('dataset_id')) if with_embedding: - ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, - 'is_active': True, - 'source_type': SourceType.PROBLEM, - 'source_id': problem_paragraph_mapping.id, - 'document_id': self.data.get('document_id'), - 'paragraph_id': self.data.get('paragraph_id'), - 'dataset_id': self.data.get('dataset_id'), - }, embedding_model=model) + embedding_by_problem_task({'text': problem.content, + 'is_active': True, + 'source_type': SourceType.PROBLEM, + 'source_id': problem_paragraph_mapping.id, + 'document_id': self.data.get('document_id'), + 'paragraph_id': self.data.get('paragraph_id'), + 'dataset_id': self.data.get('dataset_id'), + }, model_id) return ProblemSerializers.Operate( data={'dataset_id': self.data.get('dataset_id'), @@ -228,15 +230,15 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): problem_id=problem.id) problem_paragraph_mapping.save() if with_embedding: - model = get_embedding_model_by_dataset_id(self.data.get('dataset_id')) - ListenerManagement.embedding_by_problem_signal.send({'text': problem.content, - 'is_active': True, - 'source_type': SourceType.PROBLEM, - 'source_id': problem_paragraph_mapping.id, - 'document_id': self.data.get('document_id'), - 'paragraph_id': self.data.get('paragraph_id'), - 'dataset_id': self.data.get('dataset_id'), - }, embedding_model=model) + model_id = get_embedding_model_id_by_dataset_id(self.data.get('dataset_id')) + embedding_by_problem({'text': problem.content, + 'is_active': True, + 'source_type': SourceType.PROBLEM, + 'source_id': problem_paragraph_mapping.id, + 'document_id': self.data.get('document_id'), + 'paragraph_id': self.data.get('paragraph_id'), + 'dataset_id': self.data.get('dataset_id'), + }, model_id) def un_association(self, with_valid=True): if with_valid: @@ -248,7 +250,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): 'problem_id')).first() problem_paragraph_mapping_id = problem_paragraph_mapping.id problem_paragraph_mapping.delete() - ListenerManagement.delete_embedding_by_source_signal.send(problem_paragraph_mapping_id) + delete_embedding_by_source(problem_paragraph_mapping_id) return True @staticmethod @@ -289,7 +291,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): 
QuerySet(ProblemParagraphMapping).filter(paragraph_id__in=paragraph_id_list).delete() update_document_char_length(self.data.get('document_id')) # 删除向量库 - ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_id_list) + delete_embedding_by_paragraph_ids(paragraph_id_list) return True class Migrate(ApiMixin, serializers.Serializer): @@ -338,11 +340,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): # 修改mapping QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, ['document_id']) - - # 修改向量段落信息 - ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs( - [paragraph.id for paragraph in paragraph_list], - target_document_id, target_dataset_id, target_embedding_model=None)) + update_embedding_document_id([paragraph.id for paragraph in paragraph_list], + target_document_id, target_dataset_id, None) # 修改段落信息 paragraph_list.update(document_id=target_document_id) # 不同数据集迁移 @@ -371,16 +370,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): ['problem_id', 'dataset_id', 'document_id']) target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first() dataset = QuerySet(DataSet).filter(id=dataset_id).first() - embedding_model = None + embedding_model_id = None if target_dataset.embedding_mode_id != dataset.embedding_mode_id: - embedding_model = get_embedding_model_by_dataset(target_dataset) + embedding_model_id = str(target_dataset.embedding_mode_id) pid_list = [paragraph.id for paragraph in paragraph_list] # 修改段落信息 paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id) # 修改向量段落信息 - ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs( - pid_list, - target_document_id, target_dataset_id, target_embedding_model=embedding_model)) + update_embedding_document_id(pid_list, target_document_id, target_dataset_id, embedding_model_id) update_document_char_length(document_id) update_document_char_length(target_document_id) @@ -466,12 +463,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): @staticmethod def post_embedding(paragraph, instance, dataset_id): if 'is_active' in instance and instance.get('is_active') is not None: - s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get( - 'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal) - s.send(paragraph.get('id')) + (enable_embedding_by_paragraph if instance.get( + 'is_active') else disable_embedding_by_paragraph)(paragraph.get('id')) + else: - model = get_embedding_model_by_dataset_id(dataset_id) - ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model) + model_id = get_embedding_model_id_by_dataset_id(dataset_id) + embedding_by_paragraph(paragraph.get('id'), model_id) return paragraph @post(post_embedding) @@ -543,7 +540,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): QuerySet(Paragraph).filter(id=paragraph_id).delete() QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete() update_document_char_length(self.data.get('document_id')) - ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id) + delete_embedding_by_paragraph(paragraph_id) @staticmethod def get_request_body_api(): @@ -593,8 +590,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): # 修改长度 update_document_char_length(document_id) if with_embedding: - model = get_embedding_model_by_dataset_id(dataset_id) - ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), 
embedding_model=model) + model_id = get_embedding_model_id_by_dataset_id(dataset_id) + embedding_by_paragraph(str(paragraph.id), model_id) return ParagraphSerializers.Operate( data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one( with_valid=True) diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py index 34064f9a9..65fb44d61 100644 --- a/apps/dataset/serializers/problem_serializers.py +++ b/apps/dataset/serializers/problem_serializers.py @@ -16,12 +16,12 @@ from drf_yasg import openapi from rest_framework import serializers from common.db.search import native_search, native_page_search -from common.event import ListenerManagement, UpdateProblemArgs from common.mixins.api_mixin import ApiMixin from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet -from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id +from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id +from embedding.task import delete_embedding_by_source_ids, update_problem_embedding from smartdoc.conf import PROJECT_DIR @@ -111,7 +111,7 @@ class ProblemSerializers(ApiMixin, serializers.Serializer): source_ids = [row.id for row in problem_paragraph_mapping_list] problem_paragraph_mapping_list.delete() QuerySet(Problem).filter(id__in=problem_id_list).delete() - ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids) + delete_embedding_by_source_ids(source_ids) return True class Operate(serializers.Serializer): @@ -146,7 +146,7 @@ class ProblemSerializers(ApiMixin, serializers.Serializer): source_ids = [row.id for row in problem_paragraph_mapping_list] problem_paragraph_mapping_list.delete() QuerySet(Problem).filter(id=self.data.get('problem_id')).delete() - ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids) + delete_embedding_by_source_ids(source_ids) return True @transaction.atomic @@ -161,5 +161,5 @@ class ProblemSerializers(ApiMixin, serializers.Serializer): QuerySet(DataSet).filter(id=dataset_id) problem.content = content problem.save() - model = get_embedding_model_by_dataset_id(dataset_id) - ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model)) + model_id = get_embedding_model_id_by_dataset_id(dataset_id) + update_problem_embedding(problem_id, content, model_id) diff --git a/apps/dataset/task/__init__.py b/apps/dataset/task/__init__.py new file mode 100644 index 000000000..de3f96538 --- /dev/null +++ b/apps/dataset/task/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/8/21 9:57 + @desc: +""" +from .sync import * diff --git a/apps/dataset/task/sync.py b/apps/dataset/task/sync.py new file mode 100644 index 000000000..ee4c03430 --- /dev/null +++ b/apps/dataset/task/sync.py @@ -0,0 +1,42 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: sync.py + @date:2024/8/20 21:37 + @desc: +""" + +import logging +import traceback +from typing import List + +from celery_once import QueueOnce + +from common.util.fork import ForkManage, Fork +from dataset.task.tools import get_save_handler, get_sync_web_document_handler + +from ops import celery_app + +max_kb_error = logging.getLogger("max_kb_error") +max_kb = logging.getLogger("max_kb") + + +@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, 
name='celery:sync_web_dataset')
+def sync_web_dataset(dataset_id: str, url: str, selector: str):
+    try:
+        max_kb.info(f"开始--->同步web知识库:{dataset_id}")
+        ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(),
+                                                                                 get_save_handler(dataset_id,
+                                                                                                  selector))
+        max_kb.info(f"结束--->同步web知识库:{dataset_id}")
+    except Exception as e:
+        max_kb_error.error(f'同步web知识库:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
+
+
+@celery_app.task(name='celery:sync_web_document')
+def sync_web_document(dataset_id, source_url_list: List[str], selector: str):
+    handler = get_sync_web_document_handler(dataset_id)
+    for source_url in source_url_list:
+        result = Fork(base_fork_url=source_url, selector_list=selector.split(' ')).fork()
+        handler(source_url, selector, result)
diff --git a/apps/dataset/task/tools.py b/apps/dataset/task/tools.py
new file mode 100644
index 000000000..5eb44084e
--- /dev/null
+++ b/apps/dataset/task/tools.py
@@ -0,0 +1,62 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: tools.py
+    @date:2024/8/20 21:48
+    @desc:
+"""
+
+import logging
+import traceback
+
+from common.util.fork import ChildLink, Fork
+from common.util.split_model import get_split_model
+from dataset.models import Type, Document, Status
+
+max_kb_error = logging.getLogger("max_kb_error")
+max_kb = logging.getLogger("max_kb")
+
+
+def get_save_handler(dataset_id, selector):
+    from dataset.serializers.document_serializers import DocumentSerializers
+
+    def handler(child_link: ChildLink, response: Fork.Response):
+        if response.status == 200:
+            try:
+                document_name = child_link.tag.text if child_link.tag is not None and len(
+                    child_link.tag.text.strip()) > 0 else child_link.url
+                paragraphs = get_split_model('web.md').parse(response.content)
+                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
+                    {'name': document_name, 'paragraphs': paragraphs,
+                     'meta': {'source_url': child_link.url, 'selector': selector},
+                     'type': Type.web}, with_valid=True)
+            except Exception as e:
+                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
+
+    return handler
+
+
+def get_sync_web_document_handler(dataset_id):
+    from dataset.serializers.document_serializers import DocumentSerializers
+
+    def handler(source_url: str, selector, response: Fork.Response):
+        if response.status == 200:
+            try:
+                paragraphs = get_split_model('web.md').parse(response.content)
+                # 插入
+                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
+                    {'name': source_url[0:128], 'paragraphs': paragraphs,
+                     'meta': {'source_url': source_url, 'selector': selector},
+                     'type': Type.web}, with_valid=True)
+            except Exception as e:
+                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
+        else:
+            Document(name=source_url[0:128],
+                     dataset_id=dataset_id,
+                     meta={'source_url': source_url, 'selector': selector},
+                     type=Type.web,
+                     char_length=0,
+                     status=Status.error).save()
+
+    return handler
diff --git a/apps/embedding/task/__init__.py b/apps/embedding/task/__init__.py
new file mode 100644
index 000000000..e5e7dd3b4
--- /dev/null
+++ b/apps/embedding/task/__init__.py
@@ -0,0 +1 @@
+from .embedding import *
diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py
new file mode 100644
index 000000000..46bf81a4e
--- /dev/null
+++ b/apps/embedding/task/embedding.py
@@ -0,0 +1,218 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: embedding.py
+    @date:2024/8/19 14:13
+    @desc:
+"""
+
+import logging
+import traceback
+from typing import List
+
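+# The tasks below run with base=QueueOnce (celery_once): a task whose once-keys
+# (paragraph / document / dataset id) are already pending cannot be queued a
+# second time until the first run finishes.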
+from celery_once import QueueOnce
+from django.db.models import QuerySet
+
+from common.config.embedding_config import ModelManage
+from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \
+    UpdateEmbeddingDocumentIdArgs
+from dataset.models import Document
+from ops import celery_app
+from setting.models import Model
+from setting.models_provider import get_model
+
+max_kb_error = logging.getLogger("max_kb_error")
+max_kb = logging.getLogger("max_kb")
+
+
+def get_embedding_model(model_id):
+    model = QuerySet(Model).filter(id=model_id).first()
+    embedding_model = ModelManage.get_model(model_id,
+                                            lambda _id: get_model(model))
+    return embedding_model
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id']}, name='celery:embedding_by_paragraph')
+def embedding_by_paragraph(paragraph_id, model_id):
+    embedding_model = get_embedding_model(model_id)
+    ListenerManagement.embedding_by_paragraph(paragraph_id, embedding_model)
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_data_list')
+def embedding_by_paragraph_data_list(data_list, paragraph_id_list, model_id):
+    embedding_model = get_embedding_model(model_id)
+    ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model)
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_list')
+def embedding_by_paragraph_list(paragraph_id_list, model_id):
+    embedding_model = get_embedding_model(model_id)
+    ListenerManagement.embedding_by_paragraph_list(paragraph_id_list, embedding_model)
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
+def embedding_by_document(document_id, model_id):
+    """
+    向量化文档
+    @param document_id: 文档id
+    @param model_id: 向量模型id
+    :return: None
+    """
+    embedding_model = get_embedding_model(model_id)
+    ListenerManagement.embedding_by_document(document_id, embedding_model)
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:embedding_by_dataset')
+def embedding_by_dataset(dataset_id, model_id):
+    """
+    向量化知识库
+    @param dataset_id: 知识库id
+    @param model_id: 向量模型id
+    :return: None
+    """
+    max_kb.info(f"开始--->向量化数据集:{dataset_id}")
+    try:
+        ListenerManagement.delete_embedding_by_dataset(dataset_id)
+        document_list = QuerySet(Document).filter(dataset_id=dataset_id)
+        max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
+        for document in document_list:
+            try:
+                embedding_by_document.delay(document.id, model_id)
+            except Exception:
+                pass
+    except Exception as e:
+        max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
+    finally:
+        max_kb.info(f"结束--->向量化数据集:{dataset_id}")
+
+
+def embedding_by_problem(args, model_id):
+    """
+    向量化问题
+    @param args: 问题对象
+    @param model_id: 模型id
+    @return:
+    """
+    embedding_model = get_embedding_model(model_id)
+    ListenerManagement.embedding_by_problem(args, embedding_model)
+
+
+def delete_embedding_by_document(document_id):
+    """
+    删除指定文档id的向量
+    @param document_id: 文档id
+    @return: None
+    """
+
+    ListenerManagement.delete_embedding_by_document(document_id)
+
+
+def delete_embedding_by_document_list(document_id_list: List[str]):
+    """
+    删除指定文档列表的向量数据
+    @param document_id_list: 文档id列表
+    @return: None
+    """
+    ListenerManagement.delete_embedding_by_document_list(document_id_list)
+
+
+def delete_embedding_by_dataset(dataset_id):
+    """
+    删除指定数据集向量数据
+    @param dataset_id: 数据集id
+    @return: None
+    """
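+    # Synchronous pass-through: unlike the @celery_app.task entry points above,
+    # the delete/update helpers from here on execute in the calling process.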
+    ListenerManagement.delete_embedding_by_dataset(dataset_id)
+
+
+def delete_embedding_by_paragraph(paragraph_id):
+    """
+    删除指定段落的向量数据
+    @param paragraph_id: 段落id
+    @return: None
+    """
+    ListenerManagement.delete_embedding_by_paragraph(paragraph_id)
+
+
+def delete_embedding_by_source(source_id):
+    """
+    删除指定资源id的向量数据
+    @param source_id: 资源id
+    @return: None
+    """
+    ListenerManagement.delete_embedding_by_source(source_id)
+
+
+def disable_embedding_by_paragraph(paragraph_id):
+    """
+    禁用某个段落id的向量
+    @param paragraph_id: 段落id
+    @return: None
+    """
+    ListenerManagement.disable_embedding_by_paragraph(paragraph_id)
+
+
+def enable_embedding_by_paragraph(paragraph_id):
+    """
+    开启某个段落id的向量数据
+    @param paragraph_id: 段落id
+    @return: None
+    """
+    ListenerManagement.enable_embedding_by_paragraph(paragraph_id)
+
+
+def delete_embedding_by_source_ids(source_ids: List[str]):
+    """
+    根据source_id列表删除向量数据
+    @param source_ids:
+    @return:
+    """
+    ListenerManagement.delete_embedding_by_source_ids(source_ids)
+
+
+def update_problem_embedding(problem_id: str, problem_content: str, model_id):
+    """
+    更新问题
+    @param problem_id:
+    @param problem_content:
+    @param model_id:
+    @return:
+    """
+    model = get_embedding_model(model_id)
+    ListenerManagement.update_problem(UpdateProblemArgs(problem_id, problem_content, model))
+
+
+def update_embedding_dataset_id(paragraph_id_list, target_dataset_id, target_embedding_model_id=None):
+    """
+    修改向量数据到指定知识库
+    @param paragraph_id_list: 指定段落的向量数据
+    @param target_dataset_id: 目标知识库id
+    @param target_embedding_model_id: 目标向量模型id
+    @return:
+    """
+    target_embedding_model = get_embedding_model(
+        target_embedding_model_id) if target_embedding_model_id is not None else None
+    ListenerManagement.update_embedding_dataset_id(
+        UpdateEmbeddingDatasetIdArgs(paragraph_id_list, target_dataset_id, target_embedding_model))
+
+
+def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
+    """
+    删除指定段落列表的向量数据
+    @param paragraph_ids: 段落列表
+    @return: None
+    """
+    ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_ids)
+
+
+def update_embedding_document_id(paragraph_id_list, target_document_id, target_dataset_id,
+                                 target_embedding_model_id=None):
+    target_embedding_model = get_embedding_model(
+        target_embedding_model_id) if target_embedding_model_id is not None else None
+    ListenerManagement.update_embedding_document_id(
+        UpdateEmbeddingDocumentIdArgs(paragraph_id_list, target_document_id, target_dataset_id,
+                                      target_embedding_model))
+
+
+def delete_embedding_by_dataset_id_list(dataset_id_list):
+    ListenerManagement.delete_embedding_by_dataset_id_list(dataset_id_list)
diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py
index fd9ca3391..f9ef23d72 100644
--- a/apps/embedding/vector/base_vector.py
+++ b/apps/embedding/vector/base_vector.py
@@ -127,7 +127,7 @@ class BaseVectorStore(ABC):
             return []
         embedding_query = embedding.embed_query(query_text)
         result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list,
-                            is_active, 1, 0.65)
+                            is_active, 1, 3, 0.65)
         return result[0]

     @abstractmethod
@@ -169,7 +169,7 @@ class BaseVectorStore(ABC):
         pass

     @abstractmethod
-    def delete_bu_document_id_list(self, document_id_list: List[str]):
+    def delete_by_document_id_list(self, document_id_list: List[str]):
         pass

     @abstractmethod
diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py
index 1866971d5..8cd2146ad 100644
--- a/apps/embedding/vector/pg_vector.py
+++ b/apps/embedding/vector/pg_vector.py
@@ -27,6 +27,8 @@ from smartdoc.conf import PROJECT_DIR

 class PGVector(BaseVectorStore):

     def delete_by_source_ids(self, source_ids: List[str], source_type: str):
+        if len(source_ids) == 0:
+            return
         QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()

     def update_by_source_ids(self, source_ids: List[str], instance: Dict):
@@ -67,7 +69,7 @@ class PGVector(BaseVectorStore):
                                source_type=text_list[index].get('source_type'),
                                embedding=embeddings[index],
                                search_vector=to_ts_vector(text_list[index]['text'])) for index in
-                     range(0, len(text_list))]
+                     range(0, len(texts))]
         if is_save_function():
             QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
         return True
@@ -124,7 +126,9 @@ class PGVector(BaseVectorStore):
         QuerySet(Embedding).filter(document_id=document_id).delete()
         return True

-    def delete_bu_document_id_list(self, document_id_list: List[str]):
+    def delete_by_document_id_list(self, document_id_list: List[str]):
+        if len(document_id_list) == 0:
+            return True
         return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()

     def delete_by_source_id(self, source_id: str, source_type: str):
diff --git a/apps/function_lib/task/__init__.py b/apps/function_lib/task/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/apps/ops/__init__.py b/apps/ops/__init__.py
new file mode 100644
index 000000000..a02f13af3
--- /dev/null
+++ b/apps/ops/__init__.py
@@ -0,0 +1,9 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: __init__.py
+    @date:2024/8/16 14:47
+    @desc:
+"""
+from .celery import app as celery_app
diff --git a/apps/ops/celery/__init__.py b/apps/ops/celery/__init__.py
new file mode 100644
index 000000000..55e727bf2
--- /dev/null
+++ b/apps/ops/celery/__init__.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+import os
+
+from celery import Celery
+from celery.schedules import crontab
+from kombu import Exchange, Queue
+from smartdoc import settings
+from .heatbeat import *
+
+# set the default Django settings module for the 'celery' program.
+os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings')
+
+app = Celery('MaxKB')
+
+configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')}
+configs['worker_concurrency'] = 5
+# Using a string here means the worker will not have to
+# pickle the object when using Windows.
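+# Instead of app.config_from_object, the CELERY_* values collected from the
+# Django settings module above are normalized and applied via app.conf.update below.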
+# app.config_from_object('django.conf:settings', namespace='CELERY')

+configs["task_queues"] = [
+    Queue("celery", Exchange("celery"), routing_key="celery"),
+    Queue("model", Exchange("model"), routing_key="model")
+]
+app.namespace = 'CELERY'
+app.conf.update(
+    {key.replace('CELERY_', '') if key.replace('CELERY_', '').lower() == key.replace('CELERY_',
+                                                                                     '') else key: configs.get(
+        key) for
+        key
+        in configs.keys()})
+app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])
diff --git a/apps/ops/celery/const.py b/apps/ops/celery/const.py
new file mode 100644
index 000000000..2f887023f
--- /dev/null
+++ b/apps/ops/celery/const.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+#
+
+CELERY_LOG_MAGIC_MARK = b'\x00\x00\x00\x00\x00'
\ No newline at end of file
diff --git a/apps/ops/celery/decorator.py b/apps/ops/celery/decorator.py
new file mode 100644
index 000000000..317a7f7ae
--- /dev/null
+++ b/apps/ops/celery/decorator.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+#
+from functools import wraps
+
+_need_registered_period_tasks = []
+_after_app_ready_start_tasks = []
+_after_app_shutdown_clean_periodic_tasks = []
+
+
+def add_register_period_task(task):
+    _need_registered_period_tasks.append(task)
+
+
+def get_register_period_tasks():
+    return _need_registered_period_tasks
+
+
+def add_after_app_shutdown_clean_task(name):
+    _after_app_shutdown_clean_periodic_tasks.append(name)
+
+
+def get_after_app_shutdown_clean_tasks():
+    return _after_app_shutdown_clean_periodic_tasks
+
+
+def add_after_app_ready_task(name):
+    _after_app_ready_start_tasks.append(name)
+
+
+def get_after_app_ready_tasks():
+    return _after_app_ready_start_tasks
+
+
+def register_as_period_task(
+        crontab=None, interval=None, name=None,
+        args=(), kwargs=None,
+        description=''):
+    """
+    Warning: the task must not take any args or kwargs
+    :param crontab: "* * * * *"
+    :param interval: 60*60*60
+    :param args: ()
+    :param kwargs: {}
+    :param description: ""
+    :param name: ""
+    :return:
+    """
+    if crontab is None and interval is None:
+        raise SyntaxError("Must set either crontab or interval")
+
+    def decorate(func):
+        if crontab is None and interval is None:
+            raise SyntaxError("Must set either interval or crontab")
+
+        # Because the task has not been created when this decorator runs,
+        # we can't use func.name
+        task = '{func.__module__}.{func.__name__}'.format(func=func)
+        _name = name if name else task
+        add_register_period_task({
+            _name: {
+                'task': task,
+                'interval': interval,
+                'crontab': crontab,
+                'args': args,
+                'kwargs': kwargs if kwargs else {},
+                'description': description
+            }
+        })
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorate
+
+
+def after_app_ready_start(func):
+    # Because the task has not been created when this decorator runs,
+    # we can't use func.name
+    name = '{func.__module__}.{func.__name__}'.format(func=func)
+    if name not in _after_app_ready_start_tasks:
+        add_after_app_ready_task(name)
+
+    @wraps(func)
+    def decorate(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return decorate
+
+
+def after_app_shutdown_clean_periodic(func):
+    # Because the task has not been created when this decorator runs,
+    # we can't use func.name
+    name = '{func.__module__}.{func.__name__}'.format(func=func)
+    if name not in _after_app_shutdown_clean_periodic_tasks:
+        add_after_app_shutdown_clean_task(name)
+
+    @wraps(func)
+    def decorate(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return decorate
diff --git 
a/apps/ops/celery/heatbeat.py b/apps/ops/celery/heatbeat.py new file mode 100644 index 000000000..339a3c60a --- /dev/null +++ b/apps/ops/celery/heatbeat.py @@ -0,0 +1,25 @@ +from pathlib import Path + +from celery.signals import heartbeat_sent, worker_ready, worker_shutdown + + +@heartbeat_sent.connect +def heartbeat(sender, **kwargs): + worker_name = sender.eventer.hostname.split('@')[0] + heartbeat_path = Path('/tmp/worker_heartbeat_{}'.format(worker_name)) + heartbeat_path.touch() + + +@worker_ready.connect +def worker_ready(sender, **kwargs): + worker_name = sender.hostname.split('@')[0] + ready_path = Path('/tmp/worker_ready_{}'.format(worker_name)) + ready_path.touch() + + +@worker_shutdown.connect +def worker_shutdown(sender, **kwargs): + worker_name = sender.hostname.split('@')[0] + for signal in ['ready', 'heartbeat']: + path = Path('/tmp/worker_{}_{}'.format(signal, worker_name)) + path.unlink(missing_ok=True) diff --git a/apps/ops/celery/logger.py b/apps/ops/celery/logger.py new file mode 100644 index 000000000..bdadc5685 --- /dev/null +++ b/apps/ops/celery/logger.py @@ -0,0 +1,223 @@ +from logging import StreamHandler +from threading import get_ident + +from celery import current_task +from celery.signals import task_prerun, task_postrun +from django.conf import settings +from kombu import Connection, Exchange, Queue, Producer +from kombu.mixins import ConsumerMixin + +from .utils import get_celery_task_log_path +from .const import CELERY_LOG_MAGIC_MARK + +routing_key = 'celery_log' +celery_log_exchange = Exchange('celery_log_exchange', type='direct') +celery_log_queue = [Queue('celery_log', celery_log_exchange, routing_key=routing_key)] + + +class CeleryLoggerConsumer(ConsumerMixin): + def __init__(self): + self.connection = Connection(settings.CELERY_LOG_BROKER_URL) + + def get_consumers(self, Consumer, channel): + return [Consumer(queues=celery_log_queue, + accept=['pickle', 'json'], + callbacks=[self.process_task]) + ] + + def handle_task_start(self, task_id, message): + pass + + def handle_task_end(self, task_id, message): + pass + + def handle_task_log(self, task_id, msg, message): + pass + + def process_task(self, body, message): + action = body.get('action') + task_id = body.get('task_id') + msg = body.get('msg') + if action == CeleryLoggerProducer.ACTION_TASK_LOG: + self.handle_task_log(task_id, msg, message) + elif action == CeleryLoggerProducer.ACTION_TASK_START: + self.handle_task_start(task_id, message) + elif action == CeleryLoggerProducer.ACTION_TASK_END: + self.handle_task_end(task_id, message) + + +class CeleryLoggerProducer: + ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3) + + def __init__(self): + self.connection = Connection(settings.CELERY_LOG_BROKER_URL) + + @property + def producer(self): + return Producer(self.connection) + + def publish(self, payload): + self.producer.publish( + payload, serializer='json', exchange=celery_log_exchange, + declare=[celery_log_exchange], routing_key=routing_key + ) + + def log(self, task_id, msg): + payload = {'task_id': task_id, 'msg': msg, 'action': self.ACTION_TASK_LOG} + return self.publish(payload) + + def read(self): + pass + + def flush(self): + pass + + def task_end(self, task_id): + payload = {'task_id': task_id, 'action': self.ACTION_TASK_END} + return self.publish(payload) + + def task_start(self, task_id): + payload = {'task_id': task_id, 'action': self.ACTION_TASK_START} + return self.publish(payload) + + +class CeleryTaskLoggerHandler(StreamHandler): + terminator = '\r\n' + + def 
__init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + task_prerun.connect(self.on_task_start) + task_postrun.connect(self.on_start_end) + + @staticmethod + def get_current_task_id(): + if not current_task: + return + task_id = current_task.request.root_id + return task_id + + def on_task_start(self, sender, task_id, **kwargs): + return self.handle_task_start(task_id) + + def on_start_end(self, sender, task_id, **kwargs): + return self.handle_task_end(task_id) + + def after_task_publish(self, sender, body, **kwargs): + pass + + def emit(self, record): + task_id = self.get_current_task_id() + if not task_id: + return + try: + self.write_task_log(task_id, record) + self.flush() + except Exception: + self.handleError(record) + + def write_task_log(self, task_id, msg): + pass + + def handle_task_start(self, task_id): + pass + + def handle_task_end(self, task_id): + pass + + +class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler): + @staticmethod + def get_current_thread_id(): + return str(get_ident()) + + def emit(self, record): + thread_id = self.get_current_thread_id() + try: + self.write_thread_task_log(thread_id, record) + self.flush() + except ValueError: + self.handleError(record) + + def write_thread_task_log(self, thread_id, msg): + pass + + def handle_task_start(self, task_id): + pass + + def handle_task_end(self, task_id): + pass + + def handleError(self, record) -> None: + pass + + +class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler): + def __init__(self): + self.producer = CeleryLoggerProducer() + super().__init__(stream=None) + + def write_task_log(self, task_id, record): + msg = self.format(record) + self.producer.log(task_id, msg) + + def flush(self): + self.producer.flush() + + +class CeleryTaskFileHandler(CeleryTaskLoggerHandler): + def __init__(self, *args, **kwargs): + self.f = None + super().__init__(*args, **kwargs) + + def emit(self, record): + msg = self.format(record) + if not self.f or self.f.closed: + return + self.f.write(msg) + self.f.write(self.terminator) + self.flush() + + def flush(self): + self.f and self.f.flush() + + def handle_task_start(self, task_id): + log_path = get_celery_task_log_path(task_id) + self.f = open(log_path, 'a') + + def handle_task_end(self, task_id): + self.f and self.f.close() + + +class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler): + def __init__(self, *args, **kwargs): + self.thread_id_fd_mapper = {} + self.task_id_thread_id_mapper = {} + super().__init__(*args, **kwargs) + + def write_thread_task_log(self, thread_id, record): + f = self.thread_id_fd_mapper.get(thread_id, None) + if not f: + raise ValueError('Not found thread task file') + msg = self.format(record) + f.write(msg.encode()) + f.write(self.terminator.encode()) + f.flush() + + def flush(self): + for f in self.thread_id_fd_mapper.values(): + f.flush() + + def handle_task_start(self, task_id): + log_path = get_celery_task_log_path(task_id) + thread_id = self.get_current_thread_id() + self.task_id_thread_id_mapper[task_id] = thread_id + f = open(log_path, 'ab') + self.thread_id_fd_mapper[thread_id] = f + + def handle_task_end(self, task_id): + ident_id = self.task_id_thread_id_mapper.get(task_id, '') + f = self.thread_id_fd_mapper.pop(ident_id, None) + if f and not f.closed: + f.write(CELERY_LOG_MAGIC_MARK) + f.close() + self.task_id_thread_id_mapper.pop(task_id, None) diff --git a/apps/ops/celery/signal_handler.py b/apps/ops/celery/signal_handler.py new file mode 100644 index 000000000..90ed62405 --- /dev/null +++ 
b/apps/ops/celery/signal_handler.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# +import logging +import os + +from celery import subtask +from celery.signals import ( + worker_ready, worker_shutdown, after_setup_logger +) +from django.core.cache import cache +from django_celery_beat.models import PeriodicTask + +from .decorator import get_after_app_ready_tasks, get_after_app_shutdown_clean_tasks +from .logger import CeleryThreadTaskFileHandler + +logger = logging.getLogger(__file__) +safe_str = lambda x: x + + +@worker_ready.connect +def on_app_ready(sender=None, headers=None, **kwargs): + if cache.get("CELERY_APP_READY", 0) == 1: + return + cache.set("CELERY_APP_READY", 1, 10) + tasks = get_after_app_ready_tasks() + logger.debug("Work ready signal recv") + logger.debug("Start need start task: [{}]".format(", ".join(tasks))) + for task in tasks: + periodic_task = PeriodicTask.objects.filter(task=task).first() + if periodic_task and not periodic_task.enabled: + logger.debug("Periodic task [{}] is disabled!".format(task)) + continue + subtask(task).delay() + + +def delete_files(directory): + if os.path.isdir(directory): + for filename in os.listdir(directory): + file_path = os.path.join(directory, filename) + if os.path.isfile(file_path): + os.remove(file_path) + + +@worker_shutdown.connect +def after_app_shutdown_periodic_tasks(sender=None, **kwargs): + if cache.get("CELERY_APP_SHUTDOWN", 0) == 1: + return + cache.set("CELERY_APP_SHUTDOWN", 1, 10) + tasks = get_after_app_shutdown_clean_tasks() + logger.debug("Worker shutdown signal recv") + logger.debug("Clean period tasks: [{}]".format(', '.join(tasks))) + PeriodicTask.objects.filter(name__in=tasks).delete() + + +@after_setup_logger.connect +def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=None, **kwargs): + if not logger: + return + task_handler = CeleryThreadTaskFileHandler() + task_handler.setLevel(loglevel) + formatter = logging.Formatter(format) + task_handler.setFormatter(formatter) + logger.addHandler(task_handler) diff --git a/apps/ops/celery/utils.py b/apps/ops/celery/utils.py new file mode 100644 index 000000000..288089f6f --- /dev/null +++ b/apps/ops/celery/utils.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# +import logging +import os +import uuid + +from django.conf import settings +from django_celery_beat.models import ( + PeriodicTasks +) + +from smartdoc.const import PROJECT_DIR + +logger = logging.getLogger(__file__) + + +def disable_celery_periodic_task(task_name): + from django_celery_beat.models import PeriodicTask + PeriodicTask.objects.filter(name=task_name).update(enabled=False) + PeriodicTasks.update_changed() + + +def delete_celery_periodic_task(task_name): + from django_celery_beat.models import PeriodicTask + PeriodicTask.objects.filter(name=task_name).delete() + PeriodicTasks.update_changed() + + +def get_celery_periodic_task(task_name): + from django_celery_beat.models import PeriodicTask + task = PeriodicTask.objects.filter(name=task_name).first() + return task + + +def make_dirs(name, mode=0o755, exist_ok=False): + """ 默认权限设置为 0o755 """ + return os.makedirs(name, mode=mode, exist_ok=exist_ok) + + +def get_task_log_path(base_path, task_id, level=2): + task_id = str(task_id) + try: + uuid.UUID(task_id) + except: + return os.path.join(PROJECT_DIR, 'data', 'caution.txt') + + rel_path = os.path.join(*task_id[:level], task_id + '.log') + path = os.path.join(base_path, rel_path) + make_dirs(os.path.dirname(path), exist_ok=True) + return path + + +def get_celery_task_log_path(task_id): 
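+    # Task logs live under CELERY_LOG_DIR, sharded into subdirectories by the
+    # leading characters of the task id (see get_task_log_path above).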
+    return get_task_log_path(settings.CELERY_LOG_DIR, task_id)
+
+
+def get_celery_status():
+    from . import app
+    i = app.control.inspect()
+    ping_data = i.ping() or {}
+    active_nodes = [k for k, v in ping_data.items() if v.get('ok') == 'pong']
+    active_queue_worker = set([n.split('@')[0] for n in active_nodes if n])
+    # Celery Worker 数量: 2
+    if len(active_queue_worker) < 2:
+        print("Not all celery workers are running")
+        return False
+    else:
+        return True
diff --git a/apps/setting/models_provider/__init__.py b/apps/setting/models_provider/__init__.py
index 4197eb042..7f573ec5e 100644
--- a/apps/setting/models_provider/__init__.py
+++ b/apps/setting/models_provider/__init__.py
@@ -13,18 +13,22 @@ from common.util.rsa_util import rsa_long_decrypt
 from setting.models_provider.constants.model_provider_constants import ModelProvideConstants


-def get_model_(provider, model_type, model_name, credential, **kwargs):
+def get_model_(provider, model_type, model_name, credential, model_id, use_local=False, **kwargs):
     """
     获取模型实例
     @param provider: 供应商
     @param model_type: 模型类型
     @param model_name: 模型名称
     @param credential: 认证信息
+    @param model_id: 模型id
+    @param use_local: 是否调用本地模型 只适用于本地供应商
     @return: 模型实例
     """
     model = get_provider(provider).get_model(model_type, model_name, json.loads(
         rsa_long_decrypt(credential)),
+                                             model_id=model_id,
+                                             use_local=use_local,
                                              streaming=True, **kwargs)
     return model
@@ -35,7 +39,7 @@ def get_model(model, **kwargs):
     @param model: model 数据库Model实例对象
     @return: 模型实例
     """
-    return get_model_(model.provider, model.model_type, model.model_name, model.credential, **kwargs)
+    return get_model_(model.provider, model.model_type, model.model_name, model.credential, str(model.id), **kwargs)


 def get_provider(provider):
diff --git a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py
index 92cdd7390..dd4918909 100644
--- a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py
+++ b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py
@@ -6,17 +6,52 @@
     @date:2024/7/11 14:06
     @desc:
 """
-from typing import Dict
+from typing import Dict, List

+import requests
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel
 from langchain_huggingface import HuggingFaceEmbeddings

 from setting.models_provider.base_model_provider import MaxKBBaseModel


-class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
+class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):

     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
-                              model_kwargs={'device': model_credential.get('device')},
-                              encode_kwargs={'normalize_embeddings': True},
-                              )
+        pass
+
+    model_id: str = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.model_id = kwargs.get('model_id', None)
+
+    def embed_query(self, text: str) -> List[float]:
+        res = requests.post(f'http://127.0.0.1:5432/api/model/{self.model_id}/embed_query', {'text': text})
+        result = res.json()
+        if result.get('code', 500) == 200:
+            return result.get('data')
+        raise Exception(result.get('msg'))
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        res = requests.post(f'http://127.0.0.1:5432/api/model/{self.model_id}/embed_documents', {'texts': texts})
+        result = res.json()
+        if result.get('code', 500) == 200:
+            return 
result.get('data') + raise Exception(result.get('msg')) + + +class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings): + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + if model_kwargs.get('use_local', True): + return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True} + ) + return WebLocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True}, + **model_kwargs) diff --git a/apps/setting/serializers/model_apply_serializers.py b/apps/setting/serializers/model_apply_serializers.py new file mode 100644 index 000000000..2177b5fe6 --- /dev/null +++ b/apps/setting/serializers/model_apply_serializers.py @@ -0,0 +1,53 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: model_apply_serializers.py + @date:2024/8/20 20:39 + @desc: +""" +from django.db.models import QuerySet +from rest_framework import serializers + +from common.config.embedding_config import ModelManage +from common.util.field_message import ErrMessage +from setting.models import Model +from setting.models_provider import get_model + + +def get_embedding_model(model_id): + model = QuerySet(Model).filter(id=model_id).first() + embedding_model = ModelManage.get_model(model_id, + lambda _id: get_model(model, use_local=True)) + return embedding_model + + +class EmbedDocuments(serializers.Serializer): + texts = serializers.ListField(required=True, child=serializers.CharField(required=True, + error_messages=ErrMessage.char( + "向量文本")), + error_messages=ErrMessage.list("向量文本列表")) + + +class EmbedQuery(serializers.Serializer): + text = serializers.CharField(required=True, error_messages=ErrMessage.char("向量文本")) + + +class ModelApplySerializers(serializers.Serializer): + model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) + + def embed_documents(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + EmbedDocuments(data=instance).is_valid(raise_exception=True) + + model = get_embedding_model(self.data.get('model_id')) + return model.embed_documents(instance.getlist('texts')) + + def embed_query(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + EmbedQuery(data=instance).is_valid(raise_exception=True) + + model = get_embedding_model(self.data.get('model_id')) + return model.embed_query(instance.get('text')) diff --git a/apps/setting/urls.py b/apps/setting/urls.py index 650e2a2b3..865ce4088 100644 --- a/apps/setting/urls.py +++ b/apps/setting/urls.py @@ -1,3 +1,5 @@ +import os + from django.urls import path from . 
import views

@@ -22,3 +24,10 @@ urlpatterns = [
     path('valid//', views.Valid.as_view())
 ]
+if os.environ.get('SERVER_NAME', 'web') == 'local_model':
+    urlpatterns += [
+        path('model/<str:model_id>/embed_documents', views.ModelApply.EmbedDocuments.as_view(),
+             name='model/embed_documents'),
+        path('model/<str:model_id>/embed_query', views.ModelApply.EmbedQuery.as_view(),
+             name='model/embed_query'),
+    ]
diff --git a/apps/setting/views/__init__.py b/apps/setting/views/__init__.py
index d222d52aa..4fe505635 100644
--- a/apps/setting/views/__init__.py
+++ b/apps/setting/views/__init__.py
@@ -10,3 +10,4 @@ from .Team import *
 from .model import *
 from .system_setting import *
 from .valid import *
+from .model_apply import *
diff --git a/apps/setting/views/model_apply.py b/apps/setting/views/model_apply.py
new file mode 100644
index 000000000..4a4e6139c
--- /dev/null
+++ b/apps/setting/views/model_apply.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: model_apply.py
+    @date:2024/8/20 20:38
+    @desc:
+"""
+from rest_framework.request import Request
+
+from drf_yasg.utils import swagger_auto_schema
+from rest_framework.decorators import action
+from rest_framework.views import APIView
+
+from common.response import result
+from setting.serializers.model_apply_serializers import ModelApplySerializers
+
+
+class ModelApply(APIView):
+    class EmbedDocuments(APIView):
+        @action(methods=['POST'], detail=False)
+        @swagger_auto_schema(operation_summary="向量化文档",
+                             operation_id="向量化文档",
+                             responses=result.get_default_response(),
+                             tags=["模型"])
+        def post(self, request: Request, model_id):
+            return result.success(
+                ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data))
+
+    class EmbedQuery(APIView):
+        @action(methods=['POST'], detail=False)
+        @swagger_auto_schema(operation_summary="向量化查询",
+                             operation_id="向量化查询",
+                             responses=result.get_default_response(),
+                             tags=["模型"])
+        def post(self, request: Request, model_id):
+            return result.success(
+                ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data))
diff --git a/apps/smartdoc/settings/__init__.py b/apps/smartdoc/settings/__init__.py
index dd08e45de..4e7ea78e3 100644
--- a/apps/smartdoc/settings/__init__.py
+++ b/apps/smartdoc/settings/__init__.py
@@ -9,3 +9,4 @@ from .base import *
 from .logging import *
 from .auth import *
+from .lib import *
diff --git a/apps/smartdoc/settings/base.py b/apps/smartdoc/settings/base.py
index f01e0587f..97b976545 100644
--- a/apps/smartdoc/settings/base.py
+++ b/apps/smartdoc/settings/base.py
@@ -42,7 +42,8 @@ INSTALLED_APPS = [
     'django_filters',  # 条件过滤
     'django_apscheduler',
     'common',
-    'function_lib'
+    'function_lib',
+    'django_celery_beat'
 ]

@@ -60,6 +61,7 @@ JWT_AUTH = {
     'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=60 * 60 * 2)  # <-- 设置token有效时间
 }

+APPS_DIR = os.path.join(PROJECT_DIR, 'apps')
 ROOT_URLCONF = 'smartdoc.urls'
 # FORCE_SCRIPT_NAME
 TEMPLATES = [
diff --git a/apps/smartdoc/settings/lib.py b/apps/smartdoc/settings/lib.py
new file mode 100644
index 000000000..e7b6d39dd
--- /dev/null
+++ b/apps/smartdoc/settings/lib.py
@@ -0,0 +1,40 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: lib.py
+    @date:2024/8/16 17:12
+    @desc:
+"""
+import os
+
+from smartdoc.const import CONFIG, PROJECT_DIR
+
+# celery相关配置
+celery_data_dir = os.path.join(PROJECT_DIR, 'data', 'celery_task')
+if not os.path.exists(celery_data_dir) or not os.path.isdir(celery_data_dir):
+    os.makedirs(celery_data_dir)
+broker_path = os.path.join(celery_data_dir, "celery_db.sqlite3")
+backend_path = 
os.path.join(celery_data_dir, "celery_results.sqlite3")
+# 使用 SQLite 当做 broker 和结果后端
+CELERY_BROKER_URL = f'sqla+sqlite:///{broker_path}'
+CELERY_result_backend = f'db+sqlite:///{backend_path}'
+CELERY_timezone = CONFIG.TIME_ZONE
+CELERY_ENABLE_UTC = False
+CELERY_task_serializer = 'pickle'
+CELERY_result_serializer = 'pickle'
+CELERY_accept_content = ['json', 'pickle']
+CELERY_RESULT_EXPIRES = 600
+CELERY_WORKER_TASK_LOG_FORMAT = '%(asctime).19s %(message)s'
+CELERY_WORKER_LOG_FORMAT = '%(asctime).19s %(message)s'
+CELERY_TASK_EAGER_PROPAGATES = True
+CELERY_WORKER_REDIRECT_STDOUTS = True
+CELERY_WORKER_REDIRECT_STDOUTS_LEVEL = "INFO"
+CELERY_TASK_SOFT_TIME_LIMIT = 3600
+CELERY_WORKER_CANCEL_LONG_RUNNING_TASKS_ON_CONNECTION_LOSS = True
+CELERY_ONCE = {
+    'backend': 'celery_once.backends.File',
+    'settings': {'location': os.path.join(celery_data_dir, "celery_once")}
+}
+CELERY_BROKER_CONNECTION_RETRY_ON_STARTUP = True
+CELERY_LOG_DIR = os.path.join(PROJECT_DIR, 'logs', 'celery')
diff --git a/apps/smartdoc/settings/logging.py b/apps/smartdoc/settings/logging.py
index 2627f1201..e4fafa348 100644
--- a/apps/smartdoc/settings/logging.py
+++ b/apps/smartdoc/settings/logging.py
@@ -114,6 +114,14 @@ LOGGING = {
             'level': LOG_LEVEL,
             'propagate': False,
         },
+        'common.event': {
+            'handlers': ['console', 'file'],
+            'level': "DEBUG",
+            'propagate': False,
+        },
+        'sqlalchemy': {'handlers': ['console', 'file'],
+                       'level': "ERROR",
+                       'propagate': False, }
     }
 }
diff --git a/apps/smartdoc/wsgi.py b/apps/smartdoc/wsgi.py
index 2c11c5d9d..6c7c68115 100644
--- a/apps/smartdoc/wsgi.py
+++ b/apps/smartdoc/wsgi.py
@@ -21,7 +21,6 @@ def post_handler():
     from common import job
     from common.models.db_model_manage import DBModelManage
     event.run()
-    event.ListenerManagement.init_embedding_model_signal.send()
     job.run()
     DBModelManage.init()

diff --git a/apps/users/apps.py b/apps/users/apps.py
index 1ea7bf62f..8e0856152 100644
--- a/apps/users/apps.py
+++ b/apps/users/apps.py
@@ -5,3 +5,5 @@ class UsersConfig(AppConfig):
     default_auto_field = 'django.db.models.BigAutoField'
     name = 'users'

+    def ready(self):
+        from ops.celery import signal_handler
diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py
index 4cb1e7358..5401b8a33 100644
--- a/apps/users/serializers/user_serializers.py
+++ b/apps/users/serializers/user_serializers.py
@@ -26,7 +26,6 @@ from common.constants.authentication_type import AuthenticationType
 from common.constants.exception_code_constants import ExceptionCodeConstants
 from common.constants.permission_constants import RoleConstants, get_permission_list_by_role
 from common.db.search import page_search
-from common.event import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
 from common.response.result import get_api_response
@@ -34,6 +33,7 @@ from common.util.common import valid_license
 from common.util.field_message import ErrMessage
 from common.util.lock import lock
 from dataset.models import DataSet, Document, Paragraph, Problem, ProblemParagraphMapping
+from embedding.task import delete_embedding_by_dataset_id_list
 from setting.models import Team, SystemSetting, SettingType, Model, TeamMember, TeamMemberPermission
 from smartdoc.conf import PROJECT_DIR
 from users.models.user import User, password_encrypt, get_user_dynamics_permission
@@ -497,7 +497,8 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer):

 class UserInstanceSerializer(ApiMixin, serializers.ModelSerializer):
     class Meta:
         model = User
-        fields = ['id', 'username', 'email', 'phone', 'is_active', 'role', 'nick_name', 'create_time', 'update_time', 'source']
+        fields = ['id', 'username', 'email', 'phone', 'is_active', 'role', 'nick_name', 'create_time', 'update_time',
+                  'source']

     @staticmethod
     def get_response_body_api():
@@ -643,7 +644,8 @@ class UserManageSerializer(serializers.Serializer):

     def is_valid(self, *, user_id=None, raise_exception=False):
         super().is_valid(raise_exception=True)
-        if self.data.get('email') is not None and QuerySet(User).filter(email=self.data.get('email')).exclude(id=user_id).exists():
+        if self.data.get('email') is not None and QuerySet(User).filter(email=self.data.get('email')).exclude(
+                id=user_id).exists():
             raise AppApiException(1004, "邮箱已经被使用")

     @staticmethod
@@ -738,7 +740,7 @@ class UserManageSerializer(serializers.Serializer):
             QuerySet(Paragraph).filter(dataset_id__in=dataset_id_list).delete()
             QuerySet(ProblemParagraphMapping).filter(dataset_id__in=dataset_id_list).delete()
             QuerySet(Problem).filter(dataset_id__in=dataset_id_list).delete()
-            ListenerManagement.delete_embedding_by_dataset_id_list_signal.send(dataset_id_list)
+            delete_embedding_by_dataset_id_list(dataset_id_list)
             dataset_list.delete()
             # 删除团队
             QuerySet(Team).filter(user_id=self.data.get('id')).delete()
diff --git a/apps/users/task/__init__.py b/apps/users/task/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/main.py b/main.py
index 5f3df7965..1ae33fcb4 100644
--- a/main.py
+++ b/main.py
@@ -2,6 +2,7 @@ import argparse
 import logging
 import os
 import sys
+import time

 import django
 from django.core import management
@@ -43,11 +44,38 @@ def perform_db_migrate():


 def start_services():
-    management.call_command('gunicorn')
+    services = args.services if isinstance(args.services, list) else [args.services]
+    start_args = []
+    if args.daemon:
+        start_args.append('--daemon')
+    if args.force:
+        start_args.append('--force')
+    if args.worker:
+        start_args.extend(['--worker', str(args.worker)])
+    else:
+        worker = os.environ.get('CORE_WORKER')
+        if isinstance(worker, str) and worker.isdigit():
+            start_args.extend(['--worker', worker])
+
+    try:
+        management.call_command(action, *services, *start_args)
+    except KeyboardInterrupt:
+        logging.info('Cancel ...')
+        time.sleep(2)
+    except Exception as exc:
+        logging.error("Start service error {}: {}".format(services, exc))
+        time.sleep(2)


-def runserver():
-    management.call_command('runserver', "0.0.0.0:8080")
+def dev():
+    services = args.services if isinstance(args.services, list) else [args.services]
+    if 'web' in services:
+        management.call_command('runserver', "0.0.0.0:8080")
+    elif 'celery' in services:
+        management.call_command('celery', 'celery')
+    elif 'local_model' in services:
+        os.environ.setdefault('SERVER_NAME', 'local_model')
+        management.call_command('runserver', "127.0.0.1:5432")


 if __name__ == '__main__':
@@ -66,8 +94,17 @@ if __name__ == '__main__':
         choices=("start", "dev", "upgrade_db", "collect_static"),
         help="Action to run"
     )
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()

+    parser.add_argument(
+        "services", type=str, default='all' if args.action == 'start' else 'web', nargs="*",
+        choices=("all", "web", "task") if args.action == 'start' else ("web", "celery", 'local_model'),
+        help="The service to start",
+    )
+    parser.add_argument('-d', '--daemon', nargs="?", const=True)
+    parser.add_argument('-w', '--worker', type=int, nargs="?")
+    parser.add_argument('-f', '--force', 
nargs="?", const=True) + args = parser.parse_args() action = args.action if action == "upgrade_db": perform_db_migrate() @@ -76,7 +113,7 @@ if __name__ == '__main__': elif action == 'dev': collect_static() perform_db_migrate() - runserver() + dev() else: collect_static() perform_db_migrate() diff --git a/pyproject.toml b/pyproject.toml index 62620afd3..9a27bdf16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ pillow = "^10.2.0" filetype = "^1.2.0" torch = "2.2.1" sentence-transformers = "^2.2.2" -blinker = "^1.6.3" openai = "^1.13.3" tiktoken = "^0.7.0" qianfan = "^0.3.6.1" @@ -55,6 +54,11 @@ boto3 = "^1.34.151" langchain-aws = "^0.1.13" tencentcloud-sdk-python = "^3.0.1205" xinference-client = "^0.14.0.post1" +psutil = "^6.0.0" +celery = { extras = ["sqlalchemy"], version = "^5.4.0" } +eventlet = "^0.36.1" +django-celery-beat = "^2.6.0" +celery-once = "^3.0.1" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"
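-- 
A minimal sketch of how the task entry points introduced by this patch are
meant to be driven (the ids and URL below are placeholders, not values from
the patch):

    # python main.py start all       -> start the web, celery and local_model services
    # python main.py dev celery      -> run only the celery worker in dev mode
    from dataset.task import sync_web_dataset
    from embedding.task import embedding_by_document

    embedding_by_document.delay('<document_id>', '<embedding_model_id>')
    sync_web_dataset.delay('<dataset_id>', 'https://example.com', 'body')
    # Tasks built on celery_once's QueueOnce raise AlreadyQueued when the same
    # once-keys are already pending; document_serializers maps this to
    # AppApiException(500, "任务正在执行中,请勿重复下发").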