From 9e80a652c413d3a103e1922ee729161781b803a6 Mon Sep 17 00:00:00 2001 From: CaptainB Date: Thu, 17 Jul 2025 13:24:59 +0800 Subject: [PATCH] refactor: replace try_lock and un_lock with RedisLock for improved locking mechanism --- apps/chat/serializers/chat_record.py | 7 +- apps/common/event/__init__.py | 7 +- apps/common/event/listener_manage.py | 7 +- apps/common/job/clean_chat_job.py | 7 +- apps/common/job/clean_debug_file_job.py | 10 +-- apps/common/job/client_access_num_job.py | 7 +- .../commands/services/services/gunicorn.py | 14 ++++ apps/common/utils/lock.py | 69 ++++++++++++------- apps/maxkb/wsgi.py | 6 +- 9 files changed, 89 insertions(+), 45 deletions(-) diff --git a/apps/chat/serializers/chat_record.py b/apps/chat/serializers/chat_record.py index a32240790..838f55d62 100644 --- a/apps/chat/serializers/chat_record.py +++ b/apps/chat/serializers/chat_record.py @@ -19,7 +19,7 @@ from application.serializers.application_chat_record import ChatRecordSerializer ApplicationChatRecordQuerySerializers from common.db.search import page_search from common.exception.app_exception import AppApiException -from common.utils.lock import try_lock, un_lock +from common.utils.lock import RedisLock class VoteRequest(serializers.Serializer): @@ -48,7 +48,8 @@ class VoteSerializer(serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) VoteRequest(data=instance).is_valid(raise_exception=True) - if not try_lock(self.data.get('chat_record_id')): + rlock = RedisLock() + if not rlock.try_lock(self.data.get('chat_record_id')): raise AppApiException(500, gettext( "Voting on the current session minutes, please do not send repeated requests")) @@ -75,7 +76,7 @@ class VoteSerializer(serializers.Serializer): else: raise AppApiException(500, gettext("Already voted, please cancel first and then vote again")) finally: - un_lock(self.data.get('chat_record_id')) + rlock.un_lock(self.data.get('chat_record_id')) ChatCountSerializer(data={'chat_id': self.data.get('chat_id')}).update_chat() return True diff --git a/apps/common/event/__init__.py b/apps/common/event/__init__.py index 5af1b4ba2..59a61e9ab 100644 --- a/apps/common/event/__init__.py +++ b/apps/common/event/__init__.py @@ -12,6 +12,7 @@ from django.utils.translation import gettext as _ from .listener_manage import * from ..constants.cache_version import Cache_Version from ..db.sql_execute import update_execute +from ..utils.lock import RedisLock update_document_status_sql = """ UPDATE "public"."document" @@ -22,8 +23,8 @@ update_document_status_sql = """ def run(): from models_provider.models import Model, Status - - if try_lock('event_init', 30 * 30): + rlock = RedisLock() + if rlock.try_lock('event_init', 30 * 30): try: # 修改Model状态为ERROR QuerySet(Model).filter( @@ -36,4 +37,4 @@ def run(): version, get_key = Cache_Version.SYSTEM.value cache.delete(get_key(key='rsa_key'), version=version) finally: - un_lock('event_init') + rlock.un_lock('event_init') diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index f570fdfb6..040fad65a 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -20,7 +20,7 @@ from langchain_core.embeddings import Embeddings from common.config.embedding_config import VectorStore from common.db.search import native_search, get_dynamics_model, native_update from common.utils.common import get_file_content -from common.utils.lock import try_lock, un_lock +from common.utils.lock import RedisLock from common.utils.logger import maxkb_logger from common.utils.page_utils import page_desc from knowledge.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State,SourceType, SearchMode @@ -253,7 +253,8 @@ class ListenerManagement: """ if state_list is None: state_list = [State.PENDING, State.SUCCESS, State.FAILURE, State.REVOKE, State.REVOKED] - if not try_lock('embedding:' + str(document_id)): + rlock = RedisLock() + if not rlock.try_lock('embedding:' + str(document_id)): return try: def is_the_task_interrupted(): @@ -290,7 +291,7 @@ class ListenerManagement: ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING) ListenerManagement.get_aggregation_document_status(document_id)() maxkb_logger.info(_('End--->Embedding document: {document_id}').format(document_id=document_id)) - un_lock('embedding:' + str(document_id)) + rlock.un_lock('embedding:' + str(document_id)) @staticmethod def embedding_by_knowledge(knowledge_id, embedding_model: Embeddings): diff --git a/apps/common/job/clean_chat_job.py b/apps/common/job/clean_chat_job.py index cbbf26684..c45a95bde 100644 --- a/apps/common/job/clean_chat_job.py +++ b/apps/common/job/clean_chat_job.py @@ -8,7 +8,7 @@ from django.utils import timezone from application.models import Application, Chat, ChatRecord from common.job.scheduler import scheduler -from common.utils.lock import try_lock, un_lock, lock +from common.utils.lock import lock, RedisLock from common.utils.logger import maxkb_logger from knowledge.models import File @@ -70,7 +70,8 @@ def clean_chat_log_job_lock(): def run(): - if try_lock('clean_chat_log_job', 30 * 30): + rlock = RedisLock() + if rlock.try_lock('clean_chat_log_job', 30 * 30): try: maxkb_logger.debug('get lock clean_chat_log_job') @@ -79,4 +80,4 @@ def run(): existing_job.remove() scheduler.add_job(clean_chat_log_job, 'cron', hour='0', minute='5', id='clean_chat_log') finally: - un_lock('clean_chat_log_job') + rlock.un_lock('clean_chat_log_job') diff --git a/apps/common/job/clean_debug_file_job.py b/apps/common/job/clean_debug_file_job.py index d55898e0b..9167cadba 100644 --- a/apps/common/job/clean_debug_file_job.py +++ b/apps/common/job/clean_debug_file_job.py @@ -5,7 +5,7 @@ from django.db.models import Q from django.utils import timezone from common.job.scheduler import scheduler -from common.utils.lock import un_lock, try_lock, lock +from common.utils.lock import lock, RedisLock from common.utils.logger import maxkb_logger from knowledge.models import File, FileSourceType @@ -25,12 +25,14 @@ def clean_debug_file_lock(): File.objects.filter( Q(create_time__lt=one_days_ago, source_type=FileSourceType.TEMPORARY_1_DAY.value) | Q(create_time__lt=two_hours_ago, source_type=FileSourceType.TEMPORARY_120_MINUTE.value) | - Q(create_time__lt=minutes_30_ago, source_type=FileSourceType.TEMPORARY_30_MINUTE.value)).delete() + Q(create_time__lt=minutes_30_ago, source_type=FileSourceType.TEMPORARY_30_MINUTE.value) + ).delete() maxkb_logger.debug(_('end clean debug file')) def run(): - if try_lock('clean_debug_file', 30 * 30): + rlock = RedisLock() + if rlock.try_lock('clean_debug_file', 30 * 30): try: maxkb_logger.debug('get lock clean_debug_file') @@ -39,4 +41,4 @@ def run(): clean_debug_file_job.remove() scheduler.add_job(clean_debug_file, 'cron', hour='*', minute='*/30', second='0', id='clean_debug_file') finally: - un_lock('clean_debug_file') + rlock.un_lock('clean_debug_file') diff --git a/apps/common/job/client_access_num_job.py b/apps/common/job/client_access_num_job.py index c5c4a52d6..5588c96f0 100644 --- a/apps/common/job/client_access_num_job.py +++ b/apps/common/job/client_access_num_job.py @@ -11,7 +11,7 @@ from django.db.models import QuerySet from application.models import ApplicationChatUserStats from common.job.scheduler import scheduler -from common.utils.lock import try_lock, un_lock, lock +from common.utils.lock import lock, RedisLock from common.utils.logger import maxkb_logger @@ -28,7 +28,8 @@ def client_access_num_reset_job_lock(): def run(): - if try_lock('access_num_reset', 30 * 30): + rlock = RedisLock() + if rlock.try_lock('access_num_reset', 30 * 30): try: maxkb_logger.debug('get lock access_num_reset') @@ -38,4 +39,4 @@ def run(): scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0', id='access_num_reset') finally: - un_lock('access_num_reset') + rlock.un_lock('access_num_reset') diff --git a/apps/common/management/commands/services/services/gunicorn.py b/apps/common/management/commands/services/services/gunicorn.py index 4e5beed54..fa8b0d056 100644 --- a/apps/common/management/commands/services/services/gunicorn.py +++ b/apps/common/management/commands/services/services/gunicorn.py @@ -1,3 +1,5 @@ +import subprocess + from .base import BaseService from ..hands import * @@ -35,3 +37,15 @@ class GunicornService(BaseService): @property def cwd(self): return APPS_DIR + + def open_subprocess(self): + # 复制当前环境变量,并设置 ENABLE_SCHEDULER=1 + env = os.environ.copy() + env['ENABLE_SCHEDULER'] = '1' + kwargs = { + 'cwd': self.cwd, + 'stderr': self.log_file, + 'stdout': self.log_file, + 'env': env + } + self._process = subprocess.Popen(self.cmd, **kwargs) \ No newline at end of file diff --git a/apps/common/utils/lock.py b/apps/common/utils/lock.py index 2866738ba..a57b45a5e 100644 --- a/apps/common/utils/lock.py +++ b/apps/common/utils/lock.py @@ -6,32 +6,47 @@ @date:2023/9/11 11:45 @desc: """ -from datetime import timedelta +from functools import wraps +import uuid_utils.compat as uuid from django.core.cache import caches +from django_redis import get_redis_connection memory_cache = caches['default'] +class RedisLock(): + def __init__(self): + self.lock_value = None -def try_lock(key: str, timeout=None): - """ - 获取锁 - :param key: 获取锁 key - :param timeout 超时时间 - :return: 是否获取到锁 - """ - if timeout is None: - timeout = 3600 # 默认超时时间为3600秒 - return memory_cache.add(key, 'lock', timeout=timeout) + def try_lock(self, key: str, timeout=None): + """ + 获取锁 + :param key: 获取锁 key + :param timeout 超时时间 + :return: 是否获取到锁 + """ + redis_client = get_redis_connection("default") + if timeout is None: + timeout = 3600 # 默认超时时间为3600秒 + self.lock_value = str(uuid.uuid7()) + return redis_client.set(key, self.lock_value, nx=True, ex=timeout) -def un_lock(key: str): - """ - 解锁 - :param key: 解锁 key - :return: 是否解锁成功 - """ - return memory_cache.delete(key) + def un_lock(self, key: str): + """ + 解锁 + :param key: 解锁 key + :return: 是否解锁成功 + """ + redis_client = get_redis_connection("default") + unlock_script = """ + if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("del", KEYS[1]) + else + return 0 + end + """ + redis_client.eval(unlock_script, 1, key, self.lock_value) def lock(lock_key, timeout=None): @@ -43,15 +58,19 @@ def lock(lock_key, timeout=None): """ - def inner(func): - def run(*args, **kwargs): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): key = lock_key(*args, **kwargs) if callable(lock_key) else lock_key + rlock = RedisLock() + if not rlock.try_lock(key, timeout): + # 获取锁失败,可自定义异常或返回 + return None try: - if try_lock(key=key, timeout=timeout): - return func(*args, **kwargs) + return func(*args, **kwargs) finally: - un_lock(key=key) + rlock.un_lock(key) - return run + return wrapper - return inner + return decorator diff --git a/apps/maxkb/wsgi.py b/apps/maxkb/wsgi.py index dde8db615..9b361a669 100644 --- a/apps/maxkb/wsgi.py +++ b/apps/maxkb/wsgi.py @@ -26,4 +26,8 @@ def post_handler(): job.run() DatabaseModelManage.init() -post_handler() +# 仅在web中启动定时任务,local_model celery 不需要 +if os.environ.get('ENABLE_SCHEDULER') == '1': + post_handler() + +