feat: separate tasks

This commit is contained in:
zhangshaohu 2024-08-21 14:46:11 +08:00 committed by shaohuzhang1
parent 8ead0e2b6b
commit 7c5957e0a3
66 changed files with 2074 additions and 293 deletions

View File

@ -14,6 +14,7 @@ import uuid
from functools import reduce
from typing import Dict, List
from django.conf import settings
from django.contrib.postgres.fields import ArrayField
from django.core import cache, validators
from django.core import signing
@ -46,7 +47,6 @@ from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
from django.conf import settings
chat_cache = cache.caches['chat_cache']

View File

@ -37,8 +37,10 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id, \
get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from embedding.task import embedding_by_paragraph
from smartdoc.conf import PROJECT_DIR
chat_cache = caches['chat_cache']
@ -516,9 +518,9 @@ class ChatRecordSerializer(serializers.Serializer):
@staticmethod
def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
# Send the embedding event
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
embedding_by_paragraph(paragraph_id, model_id)
return chat_record
@post(post_function=post_embedding_paragraph)

View File

@ -6,10 +6,13 @@
@date: 2023/10/23 16:03
@desc:
"""
import threading
import time
from common.cache.mem_cache import MemCache
lock = threading.Lock()
class ModelManage:
cache = MemCache('model', {})
@ -17,15 +20,21 @@ class ModelManage:
@staticmethod
def get_model(_id, get_model):
model_instance = ModelManage.cache.get(_id)
if model_instance is None or not model_instance.is_cache_model():
model_instance = get_model(_id)
ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
# Acquire the lock
lock.acquire()
try:
model_instance = ModelManage.cache.get(_id)
if model_instance is None or not model_instance.is_cache_model():
model_instance = get_model(_id)
ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
return model_instance
# Renew the cache TTL
ModelManage.cache.touch(_id, timeout=60 * 30)
ModelManage.clear_timeout_cache()
return model_instance
# Renew the cache TTL
ModelManage.cache.touch(_id, timeout=60 * 30)
ModelManage.clear_timeout_cache()
return model_instance
finally:
# Release the lock
lock.release()
@staticmethod
def clear_timeout_cache():
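
The lock added above serializes ModelManage.get_model so that concurrent requests cannot instantiate the same model twice on a cache miss. A minimal caller sketch, mirroring the get_embedding_model helper this commit adds in apps/embedding/task/embedding.py (shown further down):

from django.db.models import QuerySet
from common.config.embedding_config import ModelManage
from setting.models import Model
from setting.models_provider import get_model

def get_embedding_model(model_id):
    # The loader lambda only runs when the id is missing from (or expired in) the cache;
    # ModelManage keeps the instance for 30 minutes and renews it on each hit.
    model = QuerySet(Model).filter(id=model_id).first()
    return ModelManage.get_model(model_id, lambda _id: get_model(model))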

View File

@ -12,7 +12,6 @@ from .listener_manage import *
def run():
listener_manage.ListenerManagement().run()
QuerySet(Document).filter(status__in=[Status.embedding, Status.queue_up]).update(**{'status': Status.error})
QuerySet(Model).filter(status=setting.models.Status.DOWNLOAD).update(status=setting.models.Status.ERROR,
meta={'message': "The download was interrupted, please retry"})

View File

@ -13,22 +13,20 @@ import traceback
from typing import List
import django.db.models
from blinker import signal
from django.db.models import QuerySet
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy, embedding_poxy
from common.event.common import embedding_poxy
from common.util.file_util import get_file_content
from common.util.fork import ForkManage, Fork
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
from embedding.models import SourceType
from embedding.models import SourceType, SearchMode
from smartdoc.conf import PROJECT_DIR
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
max_kb_error = logging.getLogger(__file__)
max_kb = logging.getLogger(__file__)
class SyncWebDatasetArgs:
@ -70,23 +68,6 @@ class UpdateEmbeddingDocumentIdArgs:
class ListenerManagement:
embedding_by_problem_signal = signal("embedding_by_problem")
embedding_by_paragraph_signal = signal("embedding_by_paragraph")
embedding_by_dataset_signal = signal("embedding_by_dataset")
embedding_by_document_signal = signal("embedding_by_document")
delete_embedding_by_document_signal = signal("delete_embedding_by_document")
delete_embedding_by_document_list_signal = signal("delete_embedding_by_document_list")
delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset")
delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph")
delete_embedding_by_source_signal = signal("delete_embedding_by_source")
enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
init_embedding_model_signal = signal('init_embedding_model')
sync_web_dataset_signal = signal('sync_web_dataset')
sync_web_document_signal = signal('sync_web_document')
update_problem_signal = signal('update_problem')
delete_embedding_by_source_ids_signal = signal('delete_embedding_by_source_ids')
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
@staticmethod
def embedding_by_problem(args, embedding_model: Embeddings):
@ -160,7 +141,6 @@ class ListenerManagement:
max_kb.info(f'Finished ---> embedding paragraph: {paragraph_id}')
@staticmethod
@embedding_poxy
def embedding_by_document(document_id, embedding_model: Embeddings):
"""
Embed the document
@ -227,7 +207,7 @@ class ListenerManagement:
@staticmethod
def delete_embedding_by_document_list(document_id_list: List[str]):
VectorStore.get_embedding_vector().delete_bu_document_id_list(document_id_list)
VectorStore.get_embedding_vector().delete_by_document_id_list(document_id_list)
@staticmethod
def delete_embedding_by_dataset(dataset_id):
@ -249,25 +229,6 @@ class ListenerManagement:
def enable_embedding_by_paragraph(paragraph_id):
VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})
@staticmethod
@poxy
def sync_web_document(args: SyncWebDocumentArgs):
for source_url in args.source_url_list:
result = Fork(base_fork_url=source_url, selector_list=args.selector.split(' ')).fork()
args.handler(source_url, args.selector, result)
@staticmethod
@poxy
def sync_web_dataset(args: SyncWebDatasetArgs):
if try_lock('sync_web_dataset' + args.lock_key):
try:
ForkManage(args.url, args.selector.split(" ") if args.selector is not None else []).fork(2, set(),
args.handler)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
finally:
un_lock('sync_web_dataset' + args.lock_key)
@staticmethod
def update_problem(args: UpdateProblemArgs):
problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id)
@ -281,6 +242,9 @@ class ListenerManagement:
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'dataset_id': args.target_dataset_id})
else:
# Delete the existing vector data
ListenerManagement.delete_embedding_by_paragraph_ids(args.paragraph_id_list)
# Re-embed the vector data
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
embedding_model=args.target_embedding_model)
@ -306,38 +270,10 @@ class ListenerManagement:
def delete_embedding_by_dataset_id_list(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
def run(self):
# Add vectors by problem id
ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
# Add vectors by paragraph id
ListenerManagement.embedding_by_paragraph_signal.connect(self.embedding_by_paragraph)
# Add vectors by dataset id
ListenerManagement.embedding_by_dataset_signal.connect(
self.embedding_by_dataset)
# Add vectors by document id
ListenerManagement.embedding_by_document_signal.connect(
self.embedding_by_document)
# Delete vectors by document
ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document)
# Delete vectors by document id list
ListenerManagement.delete_embedding_by_document_list_signal.connect(self.delete_embedding_by_document_list)
# Delete vectors by dataset id
ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset)
# Delete vectors by paragraph id
ListenerManagement.delete_embedding_by_paragraph_signal.connect(
self.delete_embedding_by_paragraph)
# Delete vectors by source id
ListenerManagement.delete_embedding_by_source_signal.connect(self.delete_embedding_by_source)
# Disable paragraph embedding
ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
# Enable paragraph embedding
ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
# Sync web site dataset
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
# Sync web site document
ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
# Update problem vectors
ListenerManagement.update_problem_signal.connect(self.update_problem)
ListenerManagement.delete_embedding_by_source_ids_signal.connect(self.delete_embedding_by_source_ids)
ListenerManagement.delete_embedding_by_dataset_id_list_signal.connect(self.delete_embedding_by_dataset_id_list)
@staticmethod
def hit_test(query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float,
search_mode: SearchMode,
embedding: Embeddings):
return VectorStore.get_embedding_vector().hit_test(query_text, dataset_id, exclude_document_id_list, top_number,
similarity, search_mode, embedding)

View File

@ -13,9 +13,11 @@ from django.db.models import QuerySet
from django_apscheduler.jobstores import DjangoJobStore
from application.models.api_key_model import ApplicationPublicAccessClient
from common.lock.impl.file_lock import FileLock
scheduler = BackgroundScheduler()
scheduler.add_jobstore(DjangoJobStore(), "default")
lock = FileLock()
def client_access_num_reset_job():
@ -25,9 +27,13 @@ def client_access_num_reset_job():
def run():
scheduler.start()
access_num_reset = scheduler.get_job(job_id='access_num_reset')
if access_num_reset is not None:
access_num_reset.remove()
scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0',
id='access_num_reset')
if lock.try_lock('client_access_num_reset_job', 30 * 30):
try:
scheduler.start()
access_num_reset = scheduler.get_job(job_id='access_num_reset')
if access_num_reset is not None:
access_num_reset.remove()
scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0',
id='access_num_reset')
finally:
lock.un_lock('client_access_num_reset_job')

View File

@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: base_lock.py
@date: 2024/8/20 10:33
@desc:
"""
from abc import ABC, abstractmethod
class BaseLock(ABC):
@abstractmethod
def try_lock(self, key, timeout):
pass
@abstractmethod
def un_lock(self, key):
pass

View File

@ -0,0 +1,77 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: file_lock.py
@date: 2024/8/20 10:48
@desc:
"""
import errno
import hashlib
import os
import time
import six
from common.lock.base_lock import BaseLock
from smartdoc.const import PROJECT_DIR
def key_to_lock_name(key):
"""
Combine part of a key with its hash to prevent very long filenames
"""
MAX_LENGTH = 50
key_hash = hashlib.md5(six.b(key)).hexdigest()
lock_name = key[:MAX_LENGTH - len(key_hash) - 1] + '_' + key_hash
return lock_name
class FileLock(BaseLock):
"""
File locking backend.
"""
def __init__(self, settings=None):
if settings is None:
settings = {}
self.location = settings.get('location')
if self.location is None:
self.location = os.path.join(PROJECT_DIR, 'data', 'lock')
try:
os.makedirs(self.location)
except OSError as error:
# Directory exists?
if error.errno != errno.EEXIST:
# Re-raise unexpected OSError
raise
def _get_lock_path(self, key):
lock_name = key_to_lock_name(key)
return os.path.join(self.location, lock_name)
def try_lock(self, key, timeout):
lock_path = self._get_lock_path(key)
try:
# Create the lock file; if it cannot be created, the lock is not acquired
fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
except OSError as error:
if error.errno == errno.EEXIST:
# File already exists, check its modification time
mtime = os.path.getmtime(lock_path)
ttl = mtime + timeout - time.time()
if ttl > 0:
return False
else:
# If the timeout has already expired, take over the lock and continue
os.utime(lock_path, None)
return True
else:
return False
else:
os.close(fd)
return True
def un_lock(self, key):
lock_path = self._get_lock_path(key)
os.remove(lock_path)
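
A sketch of the intended usage, matching the scheduler guard elsewhere in this commit (the key name here is illustrative). Note that try_lock treats an existing lock file older than timeout seconds as stale and takes it over:

from common.lock.impl.file_lock import FileLock

lock = FileLock()
if lock.try_lock('my_job', timeout=30 * 60):
    try:
        ...  # critical section: one process per shared data directory gets here
    finally:
        lock.un_lock('my_job')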

View File

@ -0,0 +1,44 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: celery.py
@date: 2024/8/19 11:57
@desc:
"""
import os
import subprocess
from django.core.management.base import BaseCommand
from smartdoc.const import BASE_DIR
class Command(BaseCommand):
help = 'celery'
def add_arguments(self, parser):
parser.add_argument(
'service', nargs='+', type=str, choices=("celery", "model"), help='Service',
)
def handle(self, *args, **options):
service = options.get('service')
os.environ.setdefault('CELERY_NAME', ','.join(service))
server_hostname = os.environ.get("SERVER_HOSTNAME")
if not server_hostname:
server_hostname = '%h'
cmd = [
'celery',
'-A', 'ops',
'worker',
'-P', 'threads',
'-l', 'info',
'-c', '10',
'-Q', ','.join(service),
'--heartbeat-interval', '10',
'-n', f'{",".join(service)}@{server_hostname}',
'--without-mingle',
]
kwargs = {'cwd': BASE_DIR}
subprocess.run(cmd, **kwargs)
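
Assuming this file sits in a Django management/commands package (it subclasses BaseCommand), workers for both queues can also be launched programmatically; a sketch:

from django.core.management import call_command

# Runs in the foreground and execs roughly:
#   celery -A ops worker -P threads -l info -c 10 -Q celery,model --without-mingle ...
call_command('celery', 'celery', 'model')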

View File

@ -1,47 +0,0 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: gunicorn.py
@date: 2024/7/19 17:43
@desc:
"""
import subprocess
from django.core.management.base import BaseCommand
from smartdoc.const import BASE_DIR
class Command(BaseCommand):
help = 'My custom command'
# Argument definitions
def add_arguments(self, parser):
parser.add_argument('-b', nargs='+', type=str, help="Bind address, e.g. 0.0.0.0:8080")  # 0.0.0.0:8080
parser.add_argument('-k', nargs='?', type=str,
help="Worker class, e.g. gevent")  # uvicorn.workers.UvicornWorker
parser.add_argument('-w', type=str, help='Number of worker processes')  # process count
parser.add_argument('--threads', type=str, help='Number of threads')  # thread count
parser.add_argument('--worker-connections', type=str, help="Coroutines per thread")  # 10240
parser.add_argument('--max-requests', type=str, help="Max requests")  # 10240
parser.add_argument('--max-requests-jitter', type=str)
parser.add_argument('--access-logformat', type=str)  # %(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s
def handle(self, *args, **options):
log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
cmd = [
'gunicorn', 'smartdoc.wsgi:application',
'-b', options.get('b') if options.get('b') is not None else '0.0.0.0:8080',
'-k', options.get('k') if options.get('k') is not None else 'gthread',
'--threads', options.get('threads') if options.get('threads') is not None else '200',
'-w', options.get('w') if options.get('w') is not None else '1',
'--max-requests', options.get('max_requests') if options.get('max_requests') is not None else '10240',
'--max-requests-jitter',
options.get('max_requests_jitter') if options.get('max_requests_jitter') is not None else '2048',
'--access-logformat',
options.get('access_logformat') if options.get('access_logformat') is not None else log_format,
'--access-logfile', '-'
]
kwargs = {'cwd': BASE_DIR}
subprocess.run(cmd, **kwargs)

View File

@ -0,0 +1,6 @@
from .services.command import BaseActionCommand, Action
class Command(BaseActionCommand):
help = 'Restart services'
action = Action.restart.value

View File

@ -0,0 +1,130 @@
from django.core.management.base import BaseCommand
from django.db.models import TextChoices
from .hands import *
from .utils import ServicesUtil
class Services(TextChoices):
gunicorn = 'gunicorn', 'gunicorn'
celery_default = 'celery_default', 'celery_default'
local_model = 'local_model', 'local_model'
web = 'web', 'web'
celery = 'celery', 'celery'
celery_model = 'celery_model', 'celery_model'
task = 'task', 'task'
all = 'all', 'all'
@classmethod
def get_service_object_class(cls, name):
from . import services
services_map = {
cls.gunicorn.value: services.GunicornService,
cls.celery_default: services.CeleryDefaultService,
cls.local_model: services.GunicornLocalModelService
}
return services_map.get(name)
@classmethod
def web_services(cls):
return [cls.gunicorn, cls.local_model]
@classmethod
def celery_services(cls):
return [cls.celery_default, cls.celery_model]
@classmethod
def task_services(cls):
return cls.celery_services()
@classmethod
def all_services(cls):
return cls.web_services() + cls.task_services()
@classmethod
def export_services_values(cls):
return [cls.all.value, cls.web.value, cls.task.value] + [s.value for s in cls.all_services()]
@classmethod
def get_service_objects(cls, service_names, **kwargs):
services = set()
for name in service_names:
method_name = f'{name}_services'
if hasattr(cls, method_name):
_services = getattr(cls, method_name)()
elif hasattr(cls, name):
_services = [getattr(cls, name)]
else:
continue
services.update(set(_services))
service_objects = []
for s in services:
service_class = cls.get_service_object_class(s.value)
if not service_class:
continue
kwargs.update({
'name': s.value
})
service_object = service_class(**kwargs)
service_objects.append(service_object)
return service_objects
class Action(TextChoices):
start = 'start', 'start'
status = 'status', 'status'
stop = 'stop', 'stop'
restart = 'restart', 'restart'
class BaseActionCommand(BaseCommand):
help = 'Service Base Command'
action = None
util = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def add_arguments(self, parser):
parser.add_argument(
'services', nargs='+', choices=Services.export_services_values(), help='Service',
)
parser.add_argument('-d', '--daemon', nargs="?", const=True)
parser.add_argument('-w', '--worker', type=int, nargs="?", default=4)
parser.add_argument('-f', '--force', nargs="?", const=True)
def initial_util(self, *args, **options):
service_names = options.get('services')
service_kwargs = {
'worker_gunicorn': options.get('worker')
}
services = Services.get_service_objects(service_names=service_names, **service_kwargs)
kwargs = {
'services': services,
'run_daemon': options.get('daemon', False),
'stop_daemon': self.action == Action.stop.value and Services.all.value in service_names,
'force_stop': options.get('force') or False,
}
self.util = ServicesUtil(**kwargs)
def handle(self, *args, **options):
self.initial_util(*args, **options)
assert self.action in Action.values, f'The action {self.action} is not in the optional list'
_handle = getattr(self, f'_handle_{self.action}', lambda: None)
_handle()
def _handle_start(self):
self.util.start_and_watch()
os._exit(0)
def _handle_stop(self):
self.util.stop()
def _handle_restart(self):
self.util.restart()
def _handle_status(self):
self.util.show_status()
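
The aggregate names resolve through the Services classmethods above: 'web' expands to gunicorn + local_model, 'task' to the celery services, and 'all' to their union; names with no entry in get_service_object_class (currently celery_model) are skipped. A sketch of what 'python manage.py start all -d' works out to:

objects = Services.get_service_objects(['all'], worker_gunicorn=4)
# -> GunicornService, GunicornLocalModelService and CeleryDefaultService instances
util = ServicesUtil(services=objects, run_daemon=True, stop_daemon=False, force_stop=False)
util.start_and_watch()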

View File

@ -0,0 +1,28 @@
import logging
import os
import sys
from django.conf import settings
from smartdoc.const import CONFIG, PROJECT_DIR
try:
from apps.smartdoc import const
__version__ = const.VERSION
except ImportError as e:
print("Not found __version__: {}".format(e))
print("Python is: ")
logging.info(sys.executable)
__version__ = 'Unknown'
sys.exit(1)
HTTP_HOST = CONFIG.HTTP_BIND_HOST or '127.0.0.1'
HTTP_PORT = CONFIG.HTTP_LISTEN_PORT or 8080
DEBUG = CONFIG.DEBUG or False
LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'logs')
APPS_DIR = os.path.join(PROJECT_DIR, 'apps')
TMP_DIR = os.path.join(PROJECT_DIR, 'tmp')
if not os.path.exists(TMP_DIR):
os.makedirs(TMP_DIR)

View File

@ -0,0 +1,3 @@
from .celery_default import *
from .gunicorn import *
from .local_model import *

View File

@ -0,0 +1,207 @@
import abc
import time
import shutil
import psutil
import datetime
import threading
import subprocess
from ..hands import *
class BaseService(object):
def __init__(self, **kwargs):
self.name = kwargs['name']
self._process = None
self.STOP_TIMEOUT = 10
self.max_retry = 0
self.retry = 3
self.LOG_KEEP_DAYS = 7
self.EXIT_EVENT = threading.Event()
@property
@abc.abstractmethod
def cmd(self):
return []
@property
@abc.abstractmethod
def cwd(self):
return ''
@property
def is_running(self):
if self.pid == 0:
return False
try:
os.kill(self.pid, 0)
except (OSError, ProcessLookupError):
return False
else:
return True
def show_status(self):
if self.is_running:
msg = f'{self.name} is running: {self.pid}.'
else:
msg = f'{self.name} is stopped.'
if DEBUG:
msg = '\033[31m{} is stopped.\033[0m\nYou can manual start it to find the error: \n' \
' $ cd {}\n' \
' $ {}'.format(self.name, self.cwd, ' '.join(self.cmd))
print(msg)
# -- log --
@property
def log_filename(self):
return f'{self.name}.log'
@property
def log_filepath(self):
return os.path.join(LOG_DIR, self.log_filename)
@property
def log_file(self):
return open(self.log_filepath, 'a')
@property
def log_dir(self):
return os.path.dirname(self.log_filepath)
# -- end log --
# -- pid --
@property
def pid_filepath(self):
return os.path.join(TMP_DIR, f'{self.name}.pid')
@property
def pid(self):
if not os.path.isfile(self.pid_filepath):
return 0
with open(self.pid_filepath) as f:
try:
pid = int(f.read().strip())
except ValueError:
pid = 0
return pid
def write_pid(self):
with open(self.pid_filepath, 'w') as f:
f.write(str(self.process.pid))
def remove_pid(self):
if os.path.isfile(self.pid_filepath):
os.unlink(self.pid_filepath)
# -- end pid --
# -- process --
@property
def process(self):
if not self._process:
try:
self._process = psutil.Process(self.pid)
except:
pass
return self._process
# -- end process --
# -- action --
def open_subprocess(self):
kwargs = {'cwd': self.cwd, 'stderr': self.log_file, 'stdout': self.log_file}
self._process = subprocess.Popen(self.cmd, **kwargs)
def start(self):
if self.is_running:
self.show_status()
return
self.remove_pid()
self.open_subprocess()
self.write_pid()
self.start_other()
def start_other(self):
pass
def stop(self, force=False):
if not self.is_running:
self.show_status()
# self.remove_pid()
return
print(f'Stop service: {self.name}', end='')
sig = 9 if force else 15
os.kill(self.pid, sig)
if self.process is None:
print("\033[31m No process found\033[0m")
return
try:
self.process.wait(1)
except:
pass
for i in range(self.STOP_TIMEOUT):
if i == self.STOP_TIMEOUT - 1:
print("\033[31m Error\033[0m")
if not self.is_running:
print("\033[32m Ok\033[0m")
self.remove_pid()
break
else:
continue
def watch(self):
self._check()
if not self.is_running:
self._restart()
self._rotate_log()
def _check(self):
now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"{now} Check service status: {self.name} -> ", end='')
if self.process:
try:
self.process.wait(1)  # without wait() the child process may never be reaped
except:
pass
if self.is_running:
print(f'running at {self.pid}')
else:
print(f'stopped at {self.pid}')
def _restart(self):
if self.retry > self.max_retry:
logging.info("Service start failed, exit: {}".format(self.name))
self.EXIT_EVENT.set()
return
self.retry += 1
logging.info(f'> Find {self.name} stopped, retry {self.retry}, {self.pid}')
self.start()
def _rotate_log(self):
now = datetime.datetime.now()
_time = now.strftime('%H:%M')
if _time != '23:59':
return
backup_date = now.strftime('%Y-%m-%d')
backup_log_dir = os.path.join(self.log_dir, backup_date)
if not os.path.exists(backup_log_dir):
os.mkdir(backup_log_dir)
backup_log_path = os.path.join(backup_log_dir, self.log_filename)
if os.path.isfile(self.log_filepath) and not os.path.isfile(backup_log_path):
logging.info(f'Rotate log file: {self.log_filepath} => {backup_log_path}')
shutil.copy(self.log_filepath, backup_log_path)
with open(self.log_filepath, 'w') as f:
pass
to_delete_date = now - datetime.timedelta(days=self.LOG_KEEP_DAYS)
to_delete_dir = os.path.join(LOG_DIR, to_delete_date.strftime('%Y-%m-%d'))
if os.path.exists(to_delete_dir):
logging.info(f'Remove old log: {to_delete_dir}')
shutil.rmtree(to_delete_dir, ignore_errors=True)
# -- end action --

View File

@ -0,0 +1,43 @@
from .base import BaseService
from ..hands import *
class CeleryBaseService(BaseService):
def __init__(self, queue, num=10, **kwargs):
super().__init__(**kwargs)
self.queue = queue
self.num = num
@property
def cmd(self):
print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize()))
os.environ.setdefault('LC_ALL', 'C.UTF-8')
os.environ.setdefault('PYTHONOPTIMIZE', '1')
os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True')
os.environ.setdefault('PYTHONPATH', settings.APPS_DIR)
if os.getuid() == 0:
os.environ.setdefault('C_FORCE_ROOT', '1')
server_hostname = os.environ.get("SERVER_HOSTNAME")
if not server_hostname:
server_hostname = '%h'
cmd = [
'celery',
'-A', 'ops',
'worker',
'-P', 'threads',
'-l', 'error',
'-c', str(self.num),
'-Q', self.queue,
'--heartbeat-interval', '10',
'-n', f'{self.queue}@{server_hostname}',
'--without-mingle',
]
return cmd
@property
def cwd(self):
return APPS_DIR

View File

@ -0,0 +1,10 @@
from .celery_base import CeleryBaseService
__all__ = ['CeleryDefaultService']
class CeleryDefaultService(CeleryBaseService):
def __init__(self, **kwargs):
kwargs['queue'] = 'celery'
super().__init__(**kwargs)

View File

@ -0,0 +1,36 @@
from .base import BaseService
from ..hands import *
__all__ = ['GunicornService']
class GunicornService(BaseService):
def __init__(self, **kwargs):
self.worker = kwargs['worker_gunicorn']
super().__init__(**kwargs)
@property
def cmd(self):
print("\n- Start Gunicorn WSGI HTTP Server")
log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
bind = f'{HTTP_HOST}:{HTTP_PORT}'
cmd = [
'gunicorn', 'smartdoc.wsgi:application',
'-b', bind,
'-k', 'gthread',
'--threads', '200',
'-w', str(self.worker),
'--max-requests', '10240',
'--max-requests-jitter', '2048',
'--access-logformat', log_format,
'--access-logfile', '-'
]
if DEBUG:
cmd.append('--reload')
return cmd
@property
def cwd(self):
return APPS_DIR

View File

@ -0,0 +1,44 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: local_model.py
@date: 2024/8/21 13:28
@desc:
"""
from .base import BaseService
from ..hands import *
__all__ = ['GunicornLocalModelService']
class GunicornLocalModelService(BaseService):
def __init__(self, **kwargs):
self.worker = kwargs['worker_gunicorn']
super().__init__(**kwargs)
@property
def cmd(self):
print("\n- Start Gunicorn Local Model WSGI HTTP Server")
os.environ.setdefault('SERVER_NAME', 'local_model')
log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
bind = f'127.0.0.1:5432'
cmd = [
'gunicorn', 'smartdoc.wsgi:application',
'-b', bind,
'-k', 'gthread',
'--threads', '200',
'-w', "1",
'--max-requests', '10240',
'--max-requests-jitter', '2048',
'--access-logformat', log_format,
'--access-logfile', '-'
]
if DEBUG:
cmd.append('--reload')
return cmd
@property
def cwd(self):
return APPS_DIR

View File

@ -0,0 +1,140 @@
import threading
import signal
import time
import daemon
from daemon import pidfile
from .hands import *
from .hands import __version__
from .services.base import BaseService
class ServicesUtil(object):
def __init__(self, services, run_daemon=False, force_stop=False, stop_daemon=False):
self._services = services
self.run_daemon = run_daemon
self.force_stop = force_stop
self.stop_daemon = stop_daemon
self.EXIT_EVENT = threading.Event()
self.check_interval = 30
self.files_preserve_map = {}
def restart(self):
self.stop()
time.sleep(5)
self.start_and_watch()
def start_and_watch(self):
logging.info(time.ctime())
logging.info(f'MaxKB version {__version__}, more see https://www.jumpserver.org')
self.start()
if self.run_daemon:
self.show_status()
with self.daemon_context:
self.watch()
else:
self.watch()
def start(self):
for service in self._services:
service: BaseService
service.start()
self.files_preserve_map[service.name] = service.log_file
time.sleep(1)
def stop(self):
for service in self._services:
service: BaseService
service.stop(force=self.force_stop)
if self.stop_daemon:
self._stop_daemon()
# -- watch --
def watch(self):
while not self.EXIT_EVENT.is_set():
try:
_exit = self._watch()
if _exit:
break
time.sleep(self.check_interval)
except KeyboardInterrupt:
print('Start stop services')
break
self.clean_up()
def _watch(self):
for service in self._services:
service: BaseService
service.watch()
if service.EXIT_EVENT.is_set():
self.EXIT_EVENT.set()
return True
return False
# -- end watch --
def clean_up(self):
if not self.EXIT_EVENT.is_set():
self.EXIT_EVENT.set()
self.stop()
def show_status(self):
for service in self._services:
service: BaseService
service.show_status()
# -- daemon --
def _stop_daemon(self):
if self.daemon_pid and self.daemon_is_running:
os.kill(self.daemon_pid, 15)
self.remove_daemon_pid()
def remove_daemon_pid(self):
if os.path.isfile(self.daemon_pid_filepath):
os.unlink(self.daemon_pid_filepath)
@property
def daemon_pid(self):
if not os.path.isfile(self.daemon_pid_filepath):
return 0
with open(self.daemon_pid_filepath) as f:
try:
pid = int(f.read().strip())
except ValueError:
pid = 0
return pid
@property
def daemon_is_running(self):
try:
os.kill(self.daemon_pid, 0)
except (OSError, ProcessLookupError):
return False
else:
return True
@property
def daemon_pid_filepath(self):
return os.path.join(TMP_DIR, 'mk.pid')
@property
def daemon_log_filepath(self):
return os.path.join(LOG_DIR, 'mk.log')
@property
def daemon_context(self):
daemon_log_file = open(self.daemon_log_filepath, 'a')
context = daemon.DaemonContext(
pidfile=pidfile.TimeoutPIDLockFile(self.daemon_pid_filepath),
signal_map={
signal.SIGTERM: lambda x, y: self.clean_up(),
signal.SIGHUP: 'terminate',
},
stdout=daemon_log_file,
stderr=daemon_log_file,
files_preserve=list(self.files_preserve_map.values()),
detach_process=True,
)
return context
# -- end daemon --

View File

@ -0,0 +1,6 @@
from .services.command import BaseActionCommand, Action
class Command(BaseActionCommand):
help = 'Start services'
action = Action.start.value

View File

@ -0,0 +1,6 @@
from .services.command import BaseActionCommand, Action
class Command(BaseActionCommand):
help = 'Show services status'
action = Action.status.value

View File

@ -0,0 +1,6 @@
from .services.command import BaseActionCommand, Action
class Command(BaseActionCommand):
help = 'Stop services'
action = Action.stop.value

View File

@ -18,6 +18,7 @@ TOKEN_KEY = 'solomon_world_token'
TOKEN_SALT = 'solomonwanc@gmail.com'
TIME_OUT = 30 * 60
# Encrypt
def encrypt(obj):
value = signing.dumps(obj, key=TOKEN_KEY, salt=TOKEN_SALT)
@ -29,7 +30,6 @@ def encrypt(obj):
def decrypt(src):
src = signing.b64_decode(src.encode()).decode()
raw = signing.loads(src, key=TOKEN_KEY, salt=TOKEN_SALT)
print(type(raw))
return raw
@ -74,5 +74,3 @@ def check_token(token):
if last_token:
return last_token == token
return False

View File

@ -151,3 +151,17 @@ def get_embedding_model_by_dataset_id(dataset_id: str):
def get_embedding_model_by_dataset(dataset):
return ModelManage.get_model(str(dataset.embedding_mode_id), lambda _id: get_model(dataset.embedding_mode))
def get_embedding_model_id_by_dataset_id(dataset_id):
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
return str(dataset.embedding_mode_id)
def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
raise Exception("知识库未向量模型不一致")
if len(dataset_list) == 0:
raise Exception("知识库设置错误,请重新设置知识库")
return str(dataset_list[0].embedding_mode_id)

View File

@ -27,7 +27,6 @@ from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post, flat_map, valid_license
@ -37,9 +36,11 @@ from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from dataset.task import sync_web_dataset
from embedding.models import SearchMode
from embedding.task import embedding_by_dataset, delete_embedding_by_dataset
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR
@ -363,9 +364,9 @@ class DataSetSerializers(serializers.ModelSerializer):
@staticmethod
def post_embedding_dataset(document_list, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
# Send the embedding event
ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
embedding_by_dataset.delay(dataset_id, model_id)
return document_list
def save_qa(self, instance: Dict, with_valid=True):
@ -435,23 +436,6 @@ class DataSetSerializers(serializers.ModelSerializer):
else:
return parsed_url.path.split("/")[-1]
@staticmethod
def get_save_handler(dataset_id, selector):
def handler(child_link: ChildLink, response: Fork.Response):
if response.status == 200:
try:
document_name = child_link.tag.text if child_link.tag is not None and len(
child_link.tag.text.strip()) > 0 else child_link.url
paragraphs = get_split_model('web.md').parse(response.content)
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
{'name': document_name, 'paragraphs': paragraphs,
'meta': {'source_url': child_link.url, 'selector': selector},
'type': Type.web}, with_valid=True)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
return handler
def save_web(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
@ -467,9 +451,7 @@ class DataSetSerializers(serializers.ModelSerializer):
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector'),
'embedding_mode_id': instance.get('embedding_mode_id')}})
dataset.save()
ListenerManagement.sync_web_dataset_signal.send(
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
self.get_save_handler(dataset_id, instance.get('selector'))))
sync_web_dataset.delay(str(dataset_id), instance.get('source_url'), instance.get('selector'))
return {**DataSetSerializers(dataset).data,
'document_list': []}
@ -642,9 +624,7 @@ class DataSetSerializers(serializers.ModelSerializer):
"""
url = dataset.meta.get('source_url')
selector = dataset.meta.get('selector') if 'selector' in dataset.meta else None
ListenerManagement.sync_web_dataset_signal.send(
SyncWebDatasetArgs(str(dataset.id), url, selector,
self.get_sync_handler(dataset)))
sync_web_dataset.delay(str(dataset.id), url, selector)
def complete_sync(self, dataset):
"""
@ -658,7 +638,7 @@ class DataSetSerializers(serializers.ModelSerializer):
# Delete the paragraphs
QuerySet(Paragraph).filter(dataset=dataset).delete()
# Delete the vectors
ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
delete_embedding_by_dataset(self.data.get('id'))
# Sync
self.replace_sync(dataset)
@ -740,16 +720,17 @@ class DataSetSerializers(serializers.ModelSerializer):
QuerySet(Paragraph).filter(dataset=dataset).delete()
QuerySet(Problem).filter(dataset=dataset).delete()
dataset.delete()
ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
delete_embedding_by_dataset(self.data.get('id'))
return True
def re_embedding(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
model = get_embedding_model_by_dataset_id(self.data.get('id'))
QuerySet(Document).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up})
QuerySet(Paragraph).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up})
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model)
embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id'))
embedding_by_dataset.delay(self.data.get('id'), embedding_model_id)
def list_application(self, with_valid=True):
if with_valid:

View File

@ -15,6 +15,7 @@ from functools import reduce
from typing import List, Dict
import xlwt
from celery_once import AlreadyQueued
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
@ -25,7 +26,6 @@ from xlwt import Utils
from common.db.search import native_search, native_page_search
from common.event.common import work_thread_pool
from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs, UpdateEmbeddingDatasetIdArgs
from common.exception.app_exception import AppApiException
from common.handle.impl.doc_split_handle import DocSplitHandle
from common.handle.impl.html_split_handle import HTMLSplitHandle
@ -42,8 +42,11 @@ from common.util.fork import Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from dataset.task import sync_web_document
from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
delete_embedding_by_document, update_embedding_dataset_id
from smartdoc.conf import PROJECT_DIR
parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]
@ -235,17 +238,15 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
meta={})
else:
document_list.update(dataset_id=target_dataset_id)
model = None
model_id = None
if dataset.embedding_mode_id != target_dataset.embedding_mode_id:
model = get_embedding_model_by_dataset_id(target_dataset_id)
model_id = get_embedding_model_id_by_dataset_id(target_dataset_id)
pid_list = [paragraph.id for paragraph in paragraph_list]
# Update the paragraph info
paragraph_list.update(dataset_id=target_dataset_id)
# Update the vector info
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
pid_list,
target_dataset_id, model))
update_embedding_dataset_id(pid_list, target_dataset_id, model_id)
@staticmethod
def get_target_dataset_problem(target_dataset_id: str,
@ -377,7 +378,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
# Delete the problems
QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
# Delete from the vector store
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
delete_embedding_by_document(document_id)
paragraphs = get_split_model('web.md').parse(result.content)
document.char_length = reduce(lambda x, y: x + y,
[len(p.get('content')) for p in paragraphs],
@ -398,8 +399,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
problem_paragraph_mapping_list) > 0 else None
# Embedding
if with_embedding:
model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
embedding_model_id = get_embedding_model_id_by_dataset_id(document.dataset_id)
embedding_by_document.delay(document_id, embedding_model_id)
else:
document.status = Status.error
document.save()
@ -538,10 +539,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get("document_id")
model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id'))
QuerySet(Document).filter(id=document_id).update(**{'status': Status.queue_up})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.queue_up})
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
try:
embedding_by_document.delay(document_id, embedding_model_id)
except AlreadyQueued as e:
raise AppApiException(500, "The task is already running, do not submit it again")
@transaction.atomic
def delete(self):
@ -552,7 +556,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
# Delete the problems
QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
# Delete from the vector store
ListenerManagement.delete_embedding_by_document_signal.send(document_id)
delete_embedding_by_document(document_id)
return True
@staticmethod
@ -611,8 +615,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
@staticmethod
def post_embedding(result, document_id, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
embedding_by_document.delay(document_id, model_id)
return result
@staticmethod
@ -660,29 +664,6 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True), document_id, dataset_id
@staticmethod
def get_sync_handler(dataset_id):
def handler(source_url: str, selector, response: Fork.Response):
if response.status == 200:
try:
paragraphs = get_split_model('web.md').parse(response.content)
# Insert
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
{'name': source_url[0:128], 'paragraphs': paragraphs,
'meta': {'source_url': source_url, 'selector': selector},
'type': Type.web}, with_valid=True)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
else:
Document(name=source_url[0:128],
dataset_id=dataset_id,
meta={'source_url': source_url, 'selector': selector},
type=Type.web,
char_length=0,
status=Status.error).save()
return handler
def save_web(self, instance: Dict, with_valid=True):
if with_valid:
DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
@ -690,8 +671,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
dataset_id = self.data.get('dataset_id')
source_url_list = instance.get('source_url_list')
selector = instance.get('selector')
args = SyncWebDocumentArgs(source_url_list, selector, self.get_sync_handler(dataset_id))
ListenerManagement.sync_web_document_signal.send(args)
sync_web_document.delay(dataset_id, source_url_list, selector)
@staticmethod
def get_paragraph_model(document_model, paragraph_list: List):
@ -818,8 +798,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
@staticmethod
def post_embedding(document_list, dataset_id):
for document_dict in document_list:
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model)
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
embedding_by_document(document_dict.get('id'), model_id)
return document_list
@post(post_function=post_embedding)
@ -887,7 +867,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
QuerySet(Paragraph).filter(document_id__in=document_id_list).delete()
QuerySet(ProblemParagraphMapping).filter(document_id__in=document_id_list).delete()
# Delete from the vector store
ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list)
delete_embedding_by_document_list(document_id_list)
return True
def batch_edit_hit_handling(self, instance: Dict, with_valid=True):

View File

@ -15,16 +15,18 @@ from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import page_search
from common.event.listener_manage import ListenerManagement, UpdateEmbeddingDocumentIdArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset
ProblemParagraphManage, get_embedding_model_id_by_dataset_id
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType
from embedding.task.embedding import embedding_by_problem as embedding_by_problem_task, embedding_by_problem, \
delete_embedding_by_source, enable_embedding_by_paragraph, disable_embedding_by_paragraph, embedding_by_paragraph, \
delete_embedding_by_paragraph, delete_embedding_by_paragraph_ids, update_embedding_document_id
class ParagraphSerializer(serializers.ModelSerializer):
@ -113,7 +115,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
QuerySet(Problem).filter(id__in=[row.problem_id for row in problem_paragraph_mapping])]
@transaction.atomic
def save(self, instance: Dict, with_valid=True, with_embedding=True):
def save(self, instance: Dict, with_valid=True, with_embedding=True, embedding_by_problem=None):
if with_valid:
self.is_valid()
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
@ -132,16 +134,16 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
paragraph_id=self.data.get('paragraph_id'),
dataset_id=self.data.get('dataset_id'))
problem_paragraph_mapping.save()
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
model_id = get_embedding_model_id_by_dataset_id(self.data.get('dataset_id'))
if with_embedding:
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
'source_id': problem_paragraph_mapping.id,
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
}, embedding_model=model)
embedding_by_problem_task({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
'source_id': problem_paragraph_mapping.id,
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
}, model_id)
return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'),
@ -228,15 +230,15 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
problem_id=problem.id)
problem_paragraph_mapping.save()
if with_embedding:
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
'source_id': problem_paragraph_mapping.id,
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
}, embedding_model=model)
model_id = get_embedding_model_id_by_dataset_id(self.data.get('dataset_id'))
embedding_by_problem({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
'source_id': problem_paragraph_mapping.id,
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
}, model_id)
def un_association(self, with_valid=True):
if with_valid:
@ -248,7 +250,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
'problem_id')).first()
problem_paragraph_mapping_id = problem_paragraph_mapping.id
problem_paragraph_mapping.delete()
ListenerManagement.delete_embedding_by_source_signal.send(problem_paragraph_mapping_id)
delete_embedding_by_source(problem_paragraph_mapping_id)
return True
@staticmethod
@ -289,7 +291,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
QuerySet(ProblemParagraphMapping).filter(paragraph_id__in=paragraph_id_list).delete()
update_document_char_length(self.data.get('document_id'))
# Delete from the vector store
ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_id_list)
delete_embedding_by_paragraph_ids(paragraph_id_list)
return True
class Migrate(ApiMixin, serializers.Serializer):
@ -338,11 +340,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改mapping
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['document_id'])
# Update the paragraph info in the vector store
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id, target_embedding_model=None))
update_embedding_document_id([paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id, None)
# Update the paragraph info
paragraph_list.update(document_id=target_document_id)
# Migration across datasets
@ -371,16 +370,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
['problem_id', 'dataset_id', 'document_id'])
target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
embedding_model = None
embedding_model_id = None
if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
embedding_model = get_embedding_model_by_dataset(target_dataset)
embedding_model_id = str(target_dataset.embedding_mode_id)
pid_list = [paragraph.id for paragraph in paragraph_list]
# Update the paragraph info
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
# Update the paragraph info in the vector store
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
pid_list,
target_document_id, target_dataset_id, target_embedding_model=embedding_model))
update_embedding_document_id(pid_list, target_document_id, target_dataset_id, embedding_model_id)
update_document_char_length(document_id)
update_document_char_length(target_document_id)
@ -466,12 +463,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
@staticmethod
def post_embedding(paragraph, instance, dataset_id):
if 'is_active' in instance and instance.get('is_active') is not None:
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
s.send(paragraph.get('id'))
(enable_embedding_by_paragraph if instance.get(
'is_active') else disable_embedding_by_paragraph)(paragraph.get('id'))
else:
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model)
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
embedding_by_paragraph(paragraph.get('id'), model_id)
return paragraph
@post(post_embedding)
@ -543,7 +540,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
QuerySet(Paragraph).filter(id=paragraph_id).delete()
QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
update_document_char_length(self.data.get('document_id'))
ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)
delete_embedding_by_paragraph(paragraph_id)
@staticmethod
def get_request_body_api():
@ -593,8 +590,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# Update the char length
update_document_char_length(document_id)
if with_embedding:
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model)
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
embedding_by_paragraph(str(paragraph.id), model_id)
return ParagraphSerializers.Operate(
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True)

View File

@ -16,12 +16,12 @@ from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import native_search, native_page_search
from common.event import ListenerManagement, UpdateProblemArgs
from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding
from smartdoc.conf import PROJECT_DIR
@ -111,7 +111,7 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
source_ids = [row.id for row in problem_paragraph_mapping_list]
problem_paragraph_mapping_list.delete()
QuerySet(Problem).filter(id__in=problem_id_list).delete()
ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids)
delete_embedding_by_source_ids(source_ids)
return True
class Operate(serializers.Serializer):
@ -146,7 +146,7 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
source_ids = [row.id for row in problem_paragraph_mapping_list]
problem_paragraph_mapping_list.delete()
QuerySet(Problem).filter(id=self.data.get('problem_id')).delete()
ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids)
delete_embedding_by_source_ids(source_ids)
return True
@transaction.atomic
@ -161,5 +161,5 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
QuerySet(DataSet).filter(id=dataset_id)
problem.content = content
problem.save()
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model))
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
update_problem_embedding(problem_id, content, model_id)

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: __init__.py
@date: 2024/8/21 9:57
@desc:
"""
from .sync import *

apps/dataset/task/sync.py (new file)
View File

@ -0,0 +1,42 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: sync.py
@date: 2024/8/20 21:37
@desc:
"""
import logging
import traceback
from typing import List
from celery_once import QueueOnce
from common.util.fork import ForkManage, Fork
from dataset.task.tools import get_save_handler, get_sync_web_document_handler
from ops import celery_app
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:sync_web_dataset')
def sync_web_dataset(dataset_id: str, url: str, selector: str):
try:
max_kb.info(f"开始--->开始同步web知识库:{dataset_id}")
ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(),
get_save_handler(dataset_id,
selector))
max_kb.info(f"结束--->结束同步web知识库:{dataset_id}")
except Exception as e:
max_kb_error.error(f'Error syncing web dataset {dataset_id}: {str(e)}{traceback.format_exc()}')
@celery_app.task(name='celery:sync_web_document')
def sync_web_document(dataset_id, source_url_list: List[str], selector: str):
handler = get_sync_web_document_handler(dataset_id)
for source_url in source_url_list:
result = Fork(base_fork_url=source_url, selector_list=selector.split(' ')).fork()
handler(source_url, selector, result)

View File

@ -0,0 +1,62 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: tools.py
@date: 2024/8/20 21:48
@desc:
"""
import logging
import traceback
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models import Type, Document, Status
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
def get_save_handler(dataset_id, selector):
from dataset.serializers.document_serializers import DocumentSerializers
def handler(child_link: ChildLink, response: Fork.Response):
if response.status == 200:
try:
document_name = child_link.tag.text if child_link.tag is not None and len(
child_link.tag.text.strip()) > 0 else child_link.url
paragraphs = get_split_model('web.md').parse(response.content)
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
{'name': document_name, 'paragraphs': paragraphs,
'meta': {'source_url': child_link.url, 'selector': selector},
'type': Type.web}, with_valid=True)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
return handler
def get_sync_web_document_handler(dataset_id):
from dataset.serializers.document_serializers import DocumentSerializers
def handler(source_url: str, selector, response: Fork.Response):
if response.status == 200:
try:
paragraphs = get_split_model('web.md').parse(response.content)
# Insert
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
{'name': source_url[0:128], 'paragraphs': paragraphs,
'meta': {'source_url': source_url, 'selector': selector},
'type': Type.web}, with_valid=True)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
else:
Document(name=source_url[0:128],
dataset_id=dataset_id,
meta={'source_url': source_url, 'selector': selector},
type=Type.web,
char_length=0,
status=Status.error).save()
return handler

View File

@ -0,0 +1 @@
from .embedding import *

View File

@ -0,0 +1,218 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date: 2024/8/19 14:13
@desc:
"""
import logging
import traceback
from typing import List
from celery_once import QueueOnce
from django.db.models import QuerySet
from common.config.embedding_config import ModelManage
from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \
UpdateEmbeddingDocumentIdArgs
from dataset.models import Document
from ops import celery_app
from setting.models import Model
from setting.models_provider import get_model
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
def get_embedding_model(model_id):
model = QuerySet(Model).filter(id=model_id).first()
embedding_model = ModelManage.get_model(model_id,
lambda _id: get_model(model))
return embedding_model
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id']}, name='celery:embedding_by_paragraph')
def embedding_by_paragraph(paragraph_id, model_id):
embedding_model = get_embedding_model(model_id)
ListenerManagement.embedding_by_paragraph(paragraph_id, embedding_model)
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_data_list')
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, model_id):
embedding_model = get_embedding_model(model_id)
ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model)
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_list')
def embedding_by_paragraph_list(paragraph_id_list, model_id):
embedding_model = get_embedding_model(model_id)
ListenerManagement.embedding_by_paragraph_list(paragraph_id_list, embedding_model)
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
def embedding_by_document(document_id, model_id):
"""
Embed a document
@param document_id: document id
@param model_id: embedding model id
:return: None
"""
embedding_model = get_embedding_model(model_id)
ListenerManagement.embedding_by_document(document_id, embedding_model)
@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:embedding_by_dataset')
def embedding_by_dataset(dataset_id, model_id):
"""
Embed a whole dataset (knowledge base)
@param dataset_id: dataset id
@param model_id: embedding model id
:return: None
"""
max_kb.info(f"开始--->向量化数据集:{dataset_id}")
try:
ListenerManagement.delete_embedding_by_dataset(dataset_id)
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
for document in document_list:
try:
embedding_by_document.delay(document.id, model_id)
except Exception as e:
pass
except Exception as e:
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
finally:
max_kb.info(f"结束--->向量化数据集:{dataset_id}")
def embedding_by_problem(args, model_id):
"""
Embed a problem
@param args: problem object
@param model_id: model id
@return:
"""
embedding_model = get_embedding_model(model_id)
ListenerManagement.embedding_by_problem(args, embedding_model)
def delete_embedding_by_document(document_id):
"""
Delete vectors for the given document id
@param document_id: document id
@return: None
"""
ListenerManagement.delete_embedding_by_document(document_id)
def delete_embedding_by_document_list(document_id_list: List[str]):
"""
Delete vectors for the given list of documents
@param document_id_list: list of document ids
@return: None
"""
ListenerManagement.delete_embedding_by_document_list(document_id_list)
def delete_embedding_by_dataset(dataset_id):
"""
Delete vectors for the given dataset
@param dataset_id: dataset id
@return: None
"""
ListenerManagement.delete_embedding_by_dataset(dataset_id)
def delete_embedding_by_paragraph(paragraph_id):
"""
Delete vectors for the given paragraph
@param paragraph_id: paragraph id
@return: None
"""
ListenerManagement.delete_embedding_by_paragraph(paragraph_id)
def delete_embedding_by_source(source_id):
"""
Delete vectors for the given source id
@param source_id: source id
@return: None
"""
ListenerManagement.delete_embedding_by_source(source_id)
def disable_embedding_by_paragraph(paragraph_id):
"""
Disable the vectors of the given paragraph id
@param paragraph_id: paragraph id
@return: None
"""
ListenerManagement.disable_embedding_by_paragraph(paragraph_id)
def enable_embedding_by_paragraph(paragraph_id):
"""
Enable the vectors of the given paragraph id
@param paragraph_id: paragraph id
@return: None
"""
ListenerManagement.enable_embedding_by_paragraph(paragraph_id)
def delete_embedding_by_source_ids(source_ids: List[str]):
"""
Delete vectors by a list of source ids
@param source_ids:
@return:
"""
ListenerManagement.delete_embedding_by_source_ids(source_ids)
def update_problem_embedding(problem_id: str, problem_content: str, model_id):
"""
Update a problem's embedding
@param problem_id:
@param problem_content:
@param model_id:
@return:
"""
model = get_embedding_model(model_id)
ListenerManagement.update_problem(UpdateProblemArgs(problem_id, problem_content, model))
def update_embedding_dataset_id(paragraph_id_list, target_dataset_id, target_embedding_model_id=None):
"""
Move vector data to the given dataset
@param paragraph_id_list: paragraph ids whose vector data should be moved
@param target_dataset_id: target dataset id
@param target_embedding_model_id: target embedding model id
@return:
"""
target_embedding_model = get_embedding_model(
target_embedding_model_id) if target_embedding_model_id is not None else None
ListenerManagement.update_embedding_dataset_id(
UpdateEmbeddingDatasetIdArgs(paragraph_id_list, target_dataset_id, target_embedding_model))
def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
"""
Delete vectors for the given list of paragraphs
@param paragraph_ids: paragraph id list
@return: None
"""
ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_ids)
def update_embedding_document_id(paragraph_id_list, target_document_id, target_dataset_id,
target_embedding_model_id=None):
target_embedding_model = get_embedding_model(
target_embedding_model_id) if target_embedding_model_id is not None else None
ListenerManagement.update_embedding_document_id(
UpdateEmbeddingDocumentIdArgs(paragraph_id_list, target_document_id, target_dataset_id, target_embedding_model))
def delete_embedding_by_dataset_id_list(dataset_id_list):
ListenerManagement.delete_embedding_by_dataset_id_list(dataset_id_list)
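A dispatch sketch under the same assumptions (ids are illustrative placeholders):

# Hypothetical dispatch of the embedding tasks above.
from embedding.task import embedding_by_dataset, embedding_by_document

embedding_by_dataset.delay('<dataset-uuid>', '<model-uuid>')    # re-embed a whole dataset
embedding_by_document.delay('<document-uuid>', '<model-uuid>')  # re-embed one document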

View File

@ -127,7 +127,7 @@ class BaseVectorStore(ABC):
return []
embedding_query = embedding.embed_query(query_text)
result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list,
is_active, 1, 0.65)
is_active, 1, 3, 0.65)
return result[0]
@abstractmethod
@ -169,7 +169,7 @@ class BaseVectorStore(ABC):
pass
@abstractmethod
def delete_bu_document_id_list(self, document_id_list: List[str]):
def delete_by_document_id_list(self, document_id_list: List[str]):
pass
@abstractmethod

View File

@ -27,6 +27,8 @@ from smartdoc.conf import PROJECT_DIR
class PGVector(BaseVectorStore):
def delete_by_source_ids(self, source_ids: List[str], source_type: str):
if len(source_ids) == 0:
return
QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
@ -67,7 +69,7 @@ class PGVector(BaseVectorStore):
source_type=text_list[index].get('source_type'),
embedding=embeddings[index],
search_vector=to_ts_vector(text_list[index]['text'])) for index in
range(0, len(text_list))]
range(0, len(texts))]
if is_save_function():
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True
@ -124,7 +126,9 @@ class PGVector(BaseVectorStore):
QuerySet(Embedding).filter(document_id=document_id).delete()
return True
def delete_bu_document_id_list(self, document_id_list: List[str]):
def delete_by_document_id_list(self, document_id_list: List[str]):
if len(document_id_list) == 0:
return True
return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()
def delete_by_source_id(self, source_id: str, source_type: str):

View File

apps/ops/__init__.py Normal file (9 lines)
View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date: 2024/8/16 14:47
@desc:
"""
from .celery import app as celery_app

View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
import os
from celery import Celery
from celery.schedules import crontab
from kombu import Exchange, Queue
from smartdoc import settings
from .heatbeat import *
# set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings')
app = Celery('MaxKB')
configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')}
configs['worker_concurrency'] = 5
# Using a string here means the worker will not have to
# pickle the object when using Windows.
# app.config_from_object('django.conf:settings', namespace='CELERY')
configs["task_queues"] = [
Queue("celery", Exchange("celery"), routing_key="celery"),
Queue("model", Exchange("model"), routing_key="model")
]
app.namespace = 'CELERY'
app.conf.update({
key.replace('CELERY_', '') if key.replace('CELERY_', '').lower() == key.replace('CELERY_', '')
else key: configs.get(key)
for key in configs.keys()
})
app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])
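The update above normalises setting names: the CELERY_ prefix is stripped only when the remainder is already lower-case, so lib.py can mix new-style and old-style names. A worked example (illustrative values):

# Illustrative input/output of the key normalisation above.
configs = {'CELERY_result_backend': 'db+sqlite:///r', 'CELERY_BROKER_URL': 'sqla+sqlite:///b'}
normalised = {key.replace('CELERY_', '') if key.replace('CELERY_', '').lower() == key.replace('CELERY_', '')
              else key: configs[key] for key in configs}
assert normalised == {'result_backend': 'db+sqlite:///r', 'CELERY_BROKER_URL': 'sqla+sqlite:///b'}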

apps/ops/celery/const.py Normal file (4 lines)
View File

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
#
CELERY_LOG_MAGIC_MARK = b'\x00\x00\x00\x00\x00'

View File

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
#
from functools import wraps
_need_registered_period_tasks = []
_after_app_ready_start_tasks = []
_after_app_shutdown_clean_periodic_tasks = []
def add_register_period_task(task):
_need_registered_period_tasks.append(task)
def get_register_period_tasks():
return _need_registered_period_tasks
def add_after_app_shutdown_clean_task(name):
_after_app_shutdown_clean_periodic_tasks.append(name)
def get_after_app_shutdown_clean_tasks():
return _after_app_shutdown_clean_periodic_tasks
def add_after_app_ready_task(name):
_after_app_ready_start_tasks.append(name)
def get_after_app_ready_tasks():
return _after_app_ready_start_tasks
def register_as_period_task(
crontab=None, interval=None, name=None,
args=(), kwargs=None,
description=''):
"""
Warning: the task must not take any args or kwargs
:param crontab: "* * * * *"
:param interval: 60*60*60
:param args: ()
:param kwargs: {}
:param description: ""
:param name: ""
:return:
"""
if crontab is None and interval is None:
raise SyntaxError("Must set crontab or interval one")
def decorate(func):
if crontab is None and interval is None:
raise SyntaxError("Interval and crontab must set one")
# The task is not created yet when this decorator runs,
# so we can't use func.name
task = '{func.__module__}.{func.__name__}'.format(func=func)
_name = name if name else task
add_register_period_task({
_name: {
'task': task,
'interval': interval,
'crontab': crontab,
'args': args,
'kwargs': kwargs if kwargs else {},
'description': description
}
})
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
return decorate
def after_app_ready_start(func):
# The task is not created yet when this decorator runs,
# so we can't use func.name
name = '{func.__module__}.{func.__name__}'.format(func=func)
if name not in _after_app_ready_start_tasks:
add_after_app_ready_task(name)
@wraps(func)
def decorate(*args, **kwargs):
return func(*args, **kwargs)
return decorate
def after_app_shutdown_clean_periodic(func):
# The task is not created yet when this decorator runs,
# so we can't use func.name
name = '{func.__module__}.{func.__name__}'.format(func=func)
if name not in _after_app_shutdown_clean_periodic_tasks:
add_after_app_shutdown_clean_task(name)
@wraps(func)
def decorate(*args, **kwargs):
return func(*args, **kwargs)
return decorate
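A usage sketch for register_as_period_task (the task body is hypothetical):

# Hypothetical periodic task registered with the decorator above.
from ops import celery_app
from ops.celery.decorator import register_as_period_task

@celery_app.task(name='celery:cleanup_celery_logs')
@register_as_period_task(interval=3600, name='celery:cleanup_celery_logs')
def cleanup_celery_logs():
    pass  # must take no args/kwargs, per the docstring above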

View File

@ -0,0 +1,25 @@
from pathlib import Path
from celery.signals import heartbeat_sent, worker_ready, worker_shutdown
@heartbeat_sent.connect
def heartbeat(sender, **kwargs):
worker_name = sender.eventer.hostname.split('@')[0]
heartbeat_path = Path('/tmp/worker_heartbeat_{}'.format(worker_name))
heartbeat_path.touch()
@worker_ready.connect
def worker_ready(sender, **kwargs):
worker_name = sender.hostname.split('@')[0]
ready_path = Path('/tmp/worker_ready_{}'.format(worker_name))
ready_path.touch()
@worker_shutdown.connect
def worker_shutdown(sender, **kwargs):
worker_name = sender.hostname.split('@')[0]
for signal in ['ready', 'heartbeat']:
path = Path('/tmp/worker_{}_{}'.format(signal, worker_name))
path.unlink(missing_ok=True)
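The touch-files above exist so an external health check can verify worker liveness; a minimal probe sketch (the staleness threshold is an assumption):

# Hypothetical liveness probe against the heartbeat file touched above.
import time
from pathlib import Path

def worker_alive(worker_name: str, max_age_seconds: int = 60) -> bool:
    beat = Path('/tmp/worker_heartbeat_{}'.format(worker_name))
    return beat.exists() and time.time() - beat.stat().st_mtime < max_age_seconds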

apps/ops/celery/logger.py Normal file (223 lines)
View File

@ -0,0 +1,223 @@
from logging import StreamHandler
from threading import get_ident
from celery import current_task
from celery.signals import task_prerun, task_postrun
from django.conf import settings
from kombu import Connection, Exchange, Queue, Producer
from kombu.mixins import ConsumerMixin
from .utils import get_celery_task_log_path
from .const import CELERY_LOG_MAGIC_MARK
routing_key = 'celery_log'
celery_log_exchange = Exchange('celery_log_exchange', type='direct')
celery_log_queue = [Queue('celery_log', celery_log_exchange, routing_key=routing_key)]
class CeleryLoggerConsumer(ConsumerMixin):
def __init__(self):
self.connection = Connection(settings.CELERY_LOG_BROKER_URL)
def get_consumers(self, Consumer, channel):
return [Consumer(queues=celery_log_queue,
accept=['pickle', 'json'],
callbacks=[self.process_task])
]
def handle_task_start(self, task_id, message):
pass
def handle_task_end(self, task_id, message):
pass
def handle_task_log(self, task_id, msg, message):
pass
def process_task(self, body, message):
action = body.get('action')
task_id = body.get('task_id')
msg = body.get('msg')
if action == CeleryLoggerProducer.ACTION_TASK_LOG:
self.handle_task_log(task_id, msg, message)
elif action == CeleryLoggerProducer.ACTION_TASK_START:
self.handle_task_start(task_id, message)
elif action == CeleryLoggerProducer.ACTION_TASK_END:
self.handle_task_end(task_id, message)
class CeleryLoggerProducer:
ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3)
def __init__(self):
self.connection = Connection(settings.CELERY_LOG_BROKER_URL)
@property
def producer(self):
return Producer(self.connection)
def publish(self, payload):
self.producer.publish(
payload, serializer='json', exchange=celery_log_exchange,
declare=[celery_log_exchange], routing_key=routing_key
)
def log(self, task_id, msg):
payload = {'task_id': task_id, 'msg': msg, 'action': self.ACTION_TASK_LOG}
return self.publish(payload)
def read(self):
pass
def flush(self):
pass
def task_end(self, task_id):
payload = {'task_id': task_id, 'action': self.ACTION_TASK_END}
return self.publish(payload)
def task_start(self, task_id):
payload = {'task_id': task_id, 'action': self.ACTION_TASK_START}
return self.publish(payload)
class CeleryTaskLoggerHandler(StreamHandler):
terminator = '\r\n'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
task_prerun.connect(self.on_task_start)
task_postrun.connect(self.on_start_end)
@staticmethod
def get_current_task_id():
if not current_task:
return
task_id = current_task.request.root_id
return task_id
def on_task_start(self, sender, task_id, **kwargs):
return self.handle_task_start(task_id)
def on_start_end(self, sender, task_id, **kwargs):
return self.handle_task_end(task_id)
def after_task_publish(self, sender, body, **kwargs):
pass
def emit(self, record):
task_id = self.get_current_task_id()
if not task_id:
return
try:
self.write_task_log(task_id, record)
self.flush()
except Exception:
self.handleError(record)
def write_task_log(self, task_id, msg):
pass
def handle_task_start(self, task_id):
pass
def handle_task_end(self, task_id):
pass
class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler):
@staticmethod
def get_current_thread_id():
return str(get_ident())
def emit(self, record):
thread_id = self.get_current_thread_id()
try:
self.write_thread_task_log(thread_id, record)
self.flush()
except ValueError:
self.handleError(record)
def write_thread_task_log(self, thread_id, msg):
pass
def handle_task_start(self, task_id):
pass
def handle_task_end(self, task_id):
pass
def handleError(self, record) -> None:
pass
class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler):
def __init__(self):
self.producer = CeleryLoggerProducer()
super().__init__(stream=None)
def write_task_log(self, task_id, record):
msg = self.format(record)
self.producer.log(task_id, msg)
def flush(self):
self.producer.flush()
class CeleryTaskFileHandler(CeleryTaskLoggerHandler):
def __init__(self, *args, **kwargs):
self.f = None
super().__init__(*args, **kwargs)
def emit(self, record):
msg = self.format(record)
if not self.f or self.f.closed:
return
self.f.write(msg)
self.f.write(self.terminator)
self.flush()
def flush(self):
self.f and self.f.flush()
def handle_task_start(self, task_id):
log_path = get_celery_task_log_path(task_id)
self.f = open(log_path, 'a')
def handle_task_end(self, task_id):
self.f and self.f.close()
class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
def __init__(self, *args, **kwargs):
self.thread_id_fd_mapper = {}
self.task_id_thread_id_mapper = {}
super().__init__(*args, **kwargs)
def write_thread_task_log(self, thread_id, record):
f = self.thread_id_fd_mapper.get(thread_id, None)
if not f:
raise ValueError('Not found thread task file')
msg = self.format(record)
f.write(msg.encode())
f.write(self.terminator.encode())
f.flush()
def flush(self):
for f in self.thread_id_fd_mapper.values():
f.flush()
def handle_task_start(self, task_id):
log_path = get_celery_task_log_path(task_id)
thread_id = self.get_current_thread_id()
self.task_id_thread_id_mapper[task_id] = thread_id
f = open(log_path, 'ab')
self.thread_id_fd_mapper[thread_id] = f
def handle_task_end(self, task_id):
ident_id = self.task_id_thread_id_mapper.get(task_id, '')
f = self.thread_id_fd_mapper.pop(ident_id, None)
if f and not f.closed:
f.write(CELERY_LOG_MAGIC_MARK)
f.close()
self.task_id_thread_id_mapper.pop(task_id, None)
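CELERY_LOG_MAGIC_MARK, written when a task's log is closed, can serve as an end-of-task sentinel for whatever reads these files back; a reading sketch (an assumption about intended use, not code from this commit):

# Hypothetical reader: return a task's log, trimmed at the end-of-task mark.
from ops.celery.const import CELERY_LOG_MAGIC_MARK
from ops.celery.utils import get_celery_task_log_path

def read_task_log(task_id: str) -> bytes:
    with open(get_celery_task_log_path(task_id), 'rb') as f:
        return f.read().split(CELERY_LOG_MAGIC_MARK, 1)[0]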

View File

@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
#
import logging
import os
from celery import subtask
from celery.signals import (
worker_ready, worker_shutdown, after_setup_logger
)
from django.core.cache import cache
from django_celery_beat.models import PeriodicTask
from .decorator import get_after_app_ready_tasks, get_after_app_shutdown_clean_tasks
from .logger import CeleryThreadTaskFileHandler
logger = logging.getLogger(__file__)
safe_str = lambda x: x
@worker_ready.connect
def on_app_ready(sender=None, headers=None, **kwargs):
if cache.get("CELERY_APP_READY", 0) == 1:
return
cache.set("CELERY_APP_READY", 1, 10)
tasks = get_after_app_ready_tasks()
logger.debug("Work ready signal recv")
logger.debug("Start need start task: [{}]".format(", ".join(tasks)))
for task in tasks:
periodic_task = PeriodicTask.objects.filter(task=task).first()
if periodic_task and not periodic_task.enabled:
logger.debug("Periodic task [{}] is disabled!".format(task))
continue
subtask(task).delay()
def delete_files(directory):
if os.path.isdir(directory):
for filename in os.listdir(directory):
file_path = os.path.join(directory, filename)
if os.path.isfile(file_path):
os.remove(file_path)
@worker_shutdown.connect
def after_app_shutdown_periodic_tasks(sender=None, **kwargs):
if cache.get("CELERY_APP_SHUTDOWN", 0) == 1:
return
cache.set("CELERY_APP_SHUTDOWN", 1, 10)
tasks = get_after_app_shutdown_clean_tasks()
logger.debug("Worker shutdown signal recv")
logger.debug("Clean period tasks: [{}]".format(', '.join(tasks)))
PeriodicTask.objects.filter(name__in=tasks).delete()
@after_setup_logger.connect
def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=None, **kwargs):
if not logger:
return
task_handler = CeleryThreadTaskFileHandler()
task_handler.setLevel(loglevel)
formatter = logging.Formatter(format)
task_handler.setFormatter(formatter)
logger.addHandler(task_handler)

apps/ops/celery/utils.py Normal file (68 lines)
View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
#
import logging
import os
import uuid
from django.conf import settings
from django_celery_beat.models import (
PeriodicTasks
)
from smartdoc.const import PROJECT_DIR
logger = logging.getLogger(__file__)
def disable_celery_periodic_task(task_name):
from django_celery_beat.models import PeriodicTask
PeriodicTask.objects.filter(name=task_name).update(enabled=False)
PeriodicTasks.update_changed()
def delete_celery_periodic_task(task_name):
from django_celery_beat.models import PeriodicTask
PeriodicTask.objects.filter(name=task_name).delete()
PeriodicTasks.update_changed()
def get_celery_periodic_task(task_name):
from django_celery_beat.models import PeriodicTask
task = PeriodicTask.objects.filter(name=task_name).first()
return task
def make_dirs(name, mode=0o755, exist_ok=False):
""" 默认权限设置为 0o755 """
return os.makedirs(name, mode=mode, exist_ok=exist_ok)
def get_task_log_path(base_path, task_id, level=2):
task_id = str(task_id)
try:
uuid.UUID(task_id)
except ValueError:
return os.path.join(PROJECT_DIR, 'data', 'caution.txt')
rel_path = os.path.join(*task_id[:level], task_id + '.log')
path = os.path.join(base_path, rel_path)
make_dirs(os.path.dirname(path), exist_ok=True)
return path
def get_celery_task_log_path(task_id):
return get_task_log_path(settings.CELERY_LOG_DIR, task_id)
def get_celery_status():
from . import app
i = app.control.inspect()
ping_data = i.ping() or {}
active_nodes = [k for k, v in ping_data.items() if v.get('ok') == 'pong']
active_queue_worker = set([n.split('@')[0] for n in active_nodes if n])
# expected number of Celery workers: 2
if len(active_queue_worker) < 2:
print("Not all celery workers are running")
return False
else:
return True
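get_task_log_path shards log files by the first characters of the task id (level=2 gives two nested directories); for example:

# Illustrative sharding, matching get_task_log_path above (level=2).
import os
task_id = '7c5957e0-aaaa-bbbb-cccc-ddddeeeeffff'
rel_path = os.path.join(*task_id[:2], task_id + '.log')
assert rel_path == os.path.join('7', 'c', task_id + '.log')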

View File

@ -13,18 +13,22 @@ from common.util.rsa_util import rsa_long_decrypt
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def get_model_(provider, model_type, model_name, credential, **kwargs):
def get_model_(provider, model_type, model_name, credential, model_id, use_local=False, **kwargs):
"""
Get a model instance
@param provider: provider
@param model_type: model type
@param model_name: model name
@param credential: credential info
@param model_id: model id
@param use_local: whether to use the in-process local model (local providers only)
@return: model instance
"""
model = get_provider(provider).get_model(model_type, model_name,
json.loads(
rsa_long_decrypt(credential)),
model_id=model_id,
use_local=use_local,
streaming=True, **kwargs)
return model
@ -35,7 +39,7 @@ def get_model(model, **kwargs):
@param model: database Model instance
@return: model instance
"""
return get_model_(model.provider, model.model_type, model.model_name, model.credential, **kwargs)
return get_model_(model.provider, model.model_type, model.model_name, model.credential, str(model.id), **kwargs)
def get_provider(provider):

View File

@ -6,17 +6,52 @@
@date: 2024/7/11 14:06
@desc:
"""
from typing import Dict
from typing import Dict, List
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True},
)
pass
model_id: str = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model_id = kwargs.get('model_id', None)
def embed_query(self, text: str) -> List[float]:
res = requests.post(f'http://127.0.0.1:5432/api/model/{self.model_id}/embed_query', {'text': text})
result = res.json()
if result.get('code', 500) == 200:
return result.get('data')
raise Exception(result.get('msg'))
def embed_documents(self, texts: List[str]) -> List[List[float]]:
res = requests.post(f'http://127.0.0.1:5432/api/model/{self.model_id}/embed_documents', {'texts': texts})
result = res.json()
if result.get('code', 500) == 200:
return result.get('data')
raise Exception(result.get('msg'))
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
if model_kwargs.get('use_local', True):
return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True}
)
return WebLocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True},
**model_kwargs)
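So in the web process (use_local false) embeddings are proxied over HTTP to the local_model server started in main.py, which serves the endpoints registered later in this commit. Roughly, one embed_query call amounts to (sketch; the model id is a placeholder):

# Rough HTTP equivalent of WebLocalEmbedding.embed_query above (sketch).
import requests

res = requests.post('http://127.0.0.1:5432/api/model/<model-uuid>/embed_query',
                    {'text': 'hello'})  # form-encoded body, as in the class above
vector = res.json().get('data')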

View File

@ -0,0 +1,53 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file model_apply_serializers.py
@date: 2024/8/20 20:39
@desc:
"""
from django.db.models import QuerySet
from rest_framework import serializers
from common.config.embedding_config import ModelManage
from common.util.field_message import ErrMessage
from setting.models import Model
from setting.models_provider import get_model
def get_embedding_model(model_id):
model = QuerySet(Model).filter(id=model_id).first()
embedding_model = ModelManage.get_model(model_id,
lambda _id: get_model(model, use_local=True))
return embedding_model
class EmbedDocuments(serializers.Serializer):
texts = serializers.ListField(required=True, child=serializers.CharField(required=True,
error_messages=ErrMessage.char(
"向量文本")),
error_messages=ErrMessage.list("向量文本列表"))
class EmbedQuery(serializers.Serializer):
text = serializers.CharField(required=True, error_messages=ErrMessage.char("向量文本"))
class ModelApplySerializers(serializers.Serializer):
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
def embed_documents(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
EmbedDocuments(data=instance).is_valid(raise_exception=True)
model = get_embedding_model(self.data.get('model_id'))
return model.embed_documents(instance.getlist('texts'))
def embed_query(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
EmbedQuery(data=instance).is_valid(raise_exception=True)
model = get_embedding_model(self.data.get('model_id'))
return model.embed_query(instance.get('text'))

View File

@ -1,3 +1,5 @@
import os
from django.urls import path
from . import views
@ -22,3 +24,10 @@ urlpatterns = [
path('valid/<str:valid_type>/<int:valid_count>', views.Valid.as_view())
]
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
urlpatterns += [
path('model/<str:model_id>/embed_documents', views.ModelApply.EmbedDocuments.as_view(),
name='model/embed_documents'),
path('model/<str:model_id>/embed_query', views.ModelApply.EmbedQuery.as_view(),
name='model/embed_query'),
]

View File

@ -10,3 +10,4 @@ from .Team import *
from .model import *
from .system_setting import *
from .valid import *
from .model_apply import *

View File

@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file model_apply.py
@date: 2024/8/20 20:38
@desc:
"""
from rest_framework.request import Request
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.views import APIView
from common.response import result
from setting.serializers.model_apply_serializers import ModelApplySerializers
class ModelApply(APIView):
class EmbedDocuments(APIView):
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="Embed documents",
operation_id="Embed documents",
responses=result.get_default_response(),
tags=["Model"])
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data))
class EmbedQuery(APIView):
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="Embed query",
operation_id="Embed query",
responses=result.get_default_response(),
tags=["Model"])
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data))

View File

@ -9,3 +9,4 @@
from .base import *
from .logging import *
from .auth import *
from .lib import *

View File

@ -42,7 +42,8 @@ INSTALLED_APPS = [
'django_filters', # queryset filtering
'django_apscheduler',
'common',
'function_lib'
'function_lib',
'django_celery_beat'
]
@ -60,6 +61,7 @@ JWT_AUTH = {
'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=60 * 60 * 2) # <-- token lifetime
}
APPS_DIR = os.path.join(PROJECT_DIR, 'apps')
ROOT_URLCONF = 'smartdoc.urls'
# FORCE_SCRIPT_NAME
TEMPLATES = [

View File

@ -0,0 +1,40 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file lib.py
@date: 2024/8/16 17:12
@desc:
"""
import os
from smartdoc.const import CONFIG, PROJECT_DIR
# Celery-related configuration
celery_data_dir = os.path.join(PROJECT_DIR, 'data', 'celery_task')
if not os.path.exists(celery_data_dir) or not os.path.isdir(celery_data_dir):
os.makedirs(celery_data_dir)
broker_path = os.path.join(celery_data_dir, "celery_db.sqlite3")
backend_path = os.path.join(celery_data_dir, "celery_results.sqlite3")
# use SQLite as both the broker and the result backend
CELERY_BROKER_URL = f'sqla+sqlite:///{broker_path}'
CELERY_result_backend = f'db+sqlite:///{backend_path}'
CELERY_timezone = CONFIG.TIME_ZONE
CELERY_ENABLE_UTC = False
CELERY_task_serializer = 'pickle'
CELERY_result_serializer = 'pickle'
CELERY_accept_content = ['json', 'pickle']
CELERY_RESULT_EXPIRES = 600
CELERY_WORKER_TASK_LOG_FORMAT = '%(asctime).19s %(message)s'
CELERY_WORKER_LOG_FORMAT = '%(asctime).19s %(message)s'
CELERY_TASK_EAGER_PROPAGATES = True
CELERY_WORKER_REDIRECT_STDOUTS = True
CELERY_WORKER_REDIRECT_STDOUTS_LEVEL = "INFO"
CELERY_TASK_SOFT_TIME_LIMIT = 3600
CELERY_WORKER_CANCEL_LONG_RUNNING_TASKS_ON_CONNECTION_LOSS = True
CELERY_ONCE = {
'backend': 'celery_once.backends.File',
'settings': {'location': os.path.join(celery_data_dir, "celery_once")}
}
CELERY_BROKER_CONNECTION_RETRY_ON_STARTUP = True
CELERY_LOG_DIR = os.path.join(PROJECT_DIR, 'logs', 'celery')
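With the file-based CELERY_ONCE backend above, QueueOnce tasks are deduplicated across processes; a behavioural sketch (dataset id and URL are placeholders):

# Hypothetical duplicate-enqueue behaviour under the CELERY_ONCE backend above.
from celery_once import AlreadyQueued
from dataset.task.sync import sync_web_dataset

sync_web_dataset.delay('<dataset-uuid>', 'https://example.com', None)
try:
    sync_web_dataset.delay('<dataset-uuid>', 'https://example.com', None)
except AlreadyQueued:
    pass  # a sync for this dataset_id is already queued or running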

View File

@ -114,6 +114,14 @@ LOGGING = {
'level': LOG_LEVEL,
'propagate': False,
},
'common.event': {
'handlers': ['console', 'file'],
'level': "DEBUG",
'propagate': False,
},
'sqlalchemy': {'handlers': ['console', 'file'],
'level': "ERROR",
'propagate': False, }
}
}

View File

@ -21,7 +21,6 @@ def post_handler():
from common import job
from common.models.db_model_manage import DBModelManage
event.run()
event.ListenerManagement.init_embedding_model_signal.send()
job.run()
DBModelManage.init()

View File

@ -5,3 +5,5 @@ class UsersConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'users'
def ready(self):
from ops.celery import signal_handler
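# imported for its side effect only: it registers the Celery worker signal handlers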

View File

@ -26,7 +26,6 @@ from common.constants.authentication_type import AuthenticationType
from common.constants.exception_code_constants import ExceptionCodeConstants
from common.constants.permission_constants import RoleConstants, get_permission_list_by_role
from common.db.search import page_search
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.response.result import get_api_response
@ -34,6 +33,7 @@ from common.util.common import valid_license
from common.util.field_message import ErrMessage
from common.util.lock import lock
from dataset.models import DataSet, Document, Paragraph, Problem, ProblemParagraphMapping
from embedding.task import delete_embedding_by_dataset_id_list
from setting.models import Team, SystemSetting, SettingType, Model, TeamMember, TeamMemberPermission
from smartdoc.conf import PROJECT_DIR
from users.models.user import User, password_encrypt, get_user_dynamics_permission
@ -497,7 +497,8 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer):
class UserInstanceSerializer(ApiMixin, serializers.ModelSerializer):
class Meta:
model = User
fields = ['id', 'username', 'email', 'phone', 'is_active', 'role', 'nick_name', 'create_time', 'update_time', 'source']
fields = ['id', 'username', 'email', 'phone', 'is_active', 'role', 'nick_name', 'create_time', 'update_time',
'source']
@staticmethod
def get_response_body_api():
@ -643,7 +644,8 @@ class UserManageSerializer(serializers.Serializer):
def is_valid(self, *, user_id=None, raise_exception=False):
super().is_valid(raise_exception=True)
if self.data.get('email') is not None and QuerySet(User).filter(email=self.data.get('email')).exclude(id=user_id).exists():
if self.data.get('email') is not None and QuerySet(User).filter(email=self.data.get('email')).exclude(
id=user_id).exists():
raise AppApiException(1004, "Email address is already in use")
@staticmethod
@ -738,7 +740,7 @@ class UserManageSerializer(serializers.Serializer):
QuerySet(Paragraph).filter(dataset_id__in=dataset_id_list).delete()
QuerySet(ProblemParagraphMapping).filter(dataset_id__in=dataset_id_list).delete()
QuerySet(Problem).filter(dataset_id__in=dataset_id_list).delete()
ListenerManagement.delete_embedding_by_dataset_id_list_signal.send(dataset_id_list)
delete_embedding_by_dataset_id_list(dataset_id_list)
dataset_list.delete()
# delete the team
QuerySet(Team).filter(user_id=self.data.get('id')).delete()

View File

main.py (47 lines changed)
View File

@ -2,6 +2,7 @@ import argparse
import logging
import os
import sys
import time
import django
from django.core import management
@ -43,11 +44,38 @@ def perform_db_migrate():
def start_services():
management.call_command('gunicorn')
services = args.services if isinstance(args.services, list) else [args.services]
start_args = []
if args.daemon:
start_args.append('--daemon')
if args.force:
start_args.append('--force')
if args.worker:
start_args.extend(['--worker', str(args.worker)])
else:
worker = os.environ.get('CORE_WORKER')
if isinstance(worker, str) and worker.isdigit():
start_args.extend(['--worker', worker])
try:
management.call_command(action, *services, *start_args)
except KeyboardInterrupt:
logging.info('Cancel ...')
time.sleep(2)
except Exception as exc:
logging.error("Start service error {}: {}".format(services, exc))
time.sleep(2)
def runserver():
management.call_command('runserver', "0.0.0.0:8080")
def dev():
services = args.services if isinstance(args.services, list) else [args.services]
if 'web' in services:
management.call_command('runserver', "0.0.0.0:8080")
elif 'celery' in services:
management.call_command('celery', 'celery')
elif 'local_model' in services:
os.environ.setdefault('SERVER_NAME', 'local_model')
management.call_command('runserver', "127.0.0.1:5432")
if __name__ == '__main__':
@ -66,8 +94,17 @@ if __name__ == '__main__':
choices=("start", "dev", "upgrade_db", "collect_static"),
help="Action to run"
)
args = parser.parse_args()
args, e = parser.parse_known_args()
parser.add_argument(
"services", type=str, default='all' if args.action == 'start' else 'web', nargs="*",
choices=("all", "web", "task") if args.action == 'start' else ("web", "celery", 'local_model'),
help="The service to start",
)
parser.add_argument('-d', '--daemon', nargs="?", const=True)
parser.add_argument('-w', '--worker', type=int, nargs="?")
parser.add_argument('-f', '--force', nargs="?", const=True)
args = parser.parse_args()
action = args.action
if action == "upgrade_db":
perform_db_migrate()
@ -76,7 +113,7 @@ if __name__ == '__main__':
elif action == 'dev':
collect_static()
perform_db_migrate()
runserver()
dev()
else:
collect_static()
perform_db_migrate()
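The two-phase parse above (parse_known_args first, parse_args after adding 'services') is what lets the valid services depend on the action already on the command line; a minimal standalone reproduction:

# Minimal reproduction of the two-phase argparse pattern used above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('action', choices=('start', 'dev'))
args, _ = parser.parse_known_args(['dev', 'celery'])
parser.add_argument('services', nargs='*',
                    choices=('all', 'web', 'task') if args.action == 'start'
                    else ('web', 'celery', 'local_model'))
args = parser.parse_args(['dev', 'celery'])
assert args.services == ['celery']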

View File

@ -21,7 +21,6 @@ pillow = "^10.2.0"
filetype = "^1.2.0"
torch = "2.2.1"
sentence-transformers = "^2.2.2"
blinker = "^1.6.3"
openai = "^1.13.3"
tiktoken = "^0.7.0"
qianfan = "^0.3.6.1"
@ -55,6 +54,11 @@ boto3 = "^1.34.151"
langchain-aws = "^0.1.13"
tencentcloud-sdk-python = "^3.0.1205"
xinference-client = "^0.14.0.post1"
psutil = "^6.0.0"
celery = { extras = ["sqlalchemy"], version = "^5.4.0" }
eventlet = "^0.36.1"
django-celery-beat = "^2.6.0"
celery-once = "^3.0.1"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"