feat: 分离任务 (separate tasks)
This commit is contained in:
parent 8ead0e2b6b
commit 7c5957e0a3

@@ -14,6 +14,7 @@ import uuid
from functools import reduce
from typing import Dict, List

from django.conf import settings
from django.contrib.postgres.fields import ArrayField
from django.core import cache, validators
from django.core import signing
@@ -46,7 +47,6 @@ from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
from django.conf import settings

chat_cache = cache.caches['chat_cache']

@@ -37,8 +37,10 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id, \
    get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from embedding.task import embedding_by_paragraph
from smartdoc.conf import PROJECT_DIR

chat_cache = caches['chat_cache']
@@ -516,9 +518,9 @@ class ChatRecordSerializer(serializers.Serializer):

    @staticmethod
    def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
        model = get_embedding_model_by_dataset_id(dataset_id)
        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        # send the embedding event
        ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
        embedding_by_paragraph(paragraph_id, model_id)
        return chat_record

    @post(post_function=post_embedding_paragraph)

@@ -6,10 +6,13 @@
@date:2023/10/23 16:03
@desc:
"""
import threading
import time

from common.cache.mem_cache import MemCache

lock = threading.Lock()


class ModelManage:
    cache = MemCache('model', {})
@@ -17,15 +20,21 @@ class ModelManage:

    @staticmethod
    def get_model(_id, get_model):
        model_instance = ModelManage.cache.get(_id)
        if model_instance is None or not model_instance.is_cache_model():
            model_instance = get_model(_id)
            ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
        # acquire the lock
        lock.acquire()
        try:
            model_instance = ModelManage.cache.get(_id)
            if model_instance is None or not model_instance.is_cache_model():
                model_instance = get_model(_id)
                ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
                return model_instance
            # renew the cache entry (extend its TTL)
            ModelManage.cache.touch(_id, timeout=60 * 30)
            ModelManage.clear_timeout_cache()
            return model_instance
        # renew the cache entry (extend its TTL)
        ModelManage.cache.touch(_id, timeout=60 * 30)
        ModelManage.clear_timeout_cache()
        return model_instance
        finally:
            # release the lock
            lock.release()

    @staticmethod
    def clear_timeout_cache():

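The patched get_model above takes a module-level threading.Lock around the whole lookup, so concurrent callers can no longer all miss the cache and each rebuild the same model. A minimal, self-contained sketch of that pattern (illustrative names only, not MaxKB APIs):

import threading

_lock = threading.Lock()
_cache = {}


def get_cached(key, build):
    # Serialize lookups: at most one caller misses the cache and runs the
    # expensive build(key); everyone else waits and reuses the result.
    with _lock:
        value = _cache.get(key)
        if value is None:
            value = _cache[key] = build(key)
        return value
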
@@ -12,7 +12,6 @@ from .listener_manage import *


def run():
    listener_manage.ListenerManagement().run()
    QuerySet(Document).filter(status__in=[Status.embedding, Status.queue_up]).update(**{'status': Status.error})
    QuerySet(Model).filter(status=setting.models.Status.DOWNLOAD).update(status=setting.models.Status.ERROR,
                                                                         meta={'message': "下载程序被中断,请重试"})  # "the download was interrupted, please retry"

@@ -13,22 +13,20 @@ import traceback
from typing import List

import django.db.models
from blinker import signal
from django.db.models import QuerySet
from langchain_core.embeddings import Embeddings

from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy, embedding_poxy
from common.event.common import embedding_poxy
from common.util.file_util import get_file_content
from common.util.fork import ForkManage, Fork
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
from embedding.models import SourceType
from embedding.models import SourceType, SearchMode
from smartdoc.conf import PROJECT_DIR

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
max_kb_error = logging.getLogger(__file__)
max_kb = logging.getLogger(__file__)


class SyncWebDatasetArgs:
@@ -70,23 +68,6 @@ class UpdateEmbeddingDocumentIdArgs:


class ListenerManagement:
    embedding_by_problem_signal = signal("embedding_by_problem")
    embedding_by_paragraph_signal = signal("embedding_by_paragraph")
    embedding_by_dataset_signal = signal("embedding_by_dataset")
    embedding_by_document_signal = signal("embedding_by_document")
    delete_embedding_by_document_signal = signal("delete_embedding_by_document")
    delete_embedding_by_document_list_signal = signal("delete_embedding_by_document_list")
    delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset")
    delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph")
    delete_embedding_by_source_signal = signal("delete_embedding_by_source")
    enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
    disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
    init_embedding_model_signal = signal('init_embedding_model')
    sync_web_dataset_signal = signal('sync_web_dataset')
    sync_web_document_signal = signal('sync_web_document')
    update_problem_signal = signal('update_problem')
    delete_embedding_by_source_ids_signal = signal('delete_embedding_by_source_ids')
    delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")

    @staticmethod
    def embedding_by_problem(args, embedding_model: Embeddings):
@@ -160,7 +141,6 @@ class ListenerManagement:
        max_kb.info(f'结束--->向量化段落:{paragraph_id}')  # log: "finished ---> embedding paragraph {paragraph_id}"

    @staticmethod
    @embedding_poxy
    def embedding_by_document(document_id, embedding_model: Embeddings):
        """
        向量化文档 (embed the document)
@@ -227,7 +207,7 @@ class ListenerManagement:

    @staticmethod
    def delete_embedding_by_document_list(document_id_list: List[str]):
        VectorStore.get_embedding_vector().delete_bu_document_id_list(document_id_list)
        VectorStore.get_embedding_vector().delete_by_document_id_list(document_id_list)

    @staticmethod
    def delete_embedding_by_dataset(dataset_id):
@@ -249,25 +229,6 @@ class ListenerManagement:
    def enable_embedding_by_paragraph(paragraph_id):
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})

    @staticmethod
    @poxy
    def sync_web_document(args: SyncWebDocumentArgs):
        for source_url in args.source_url_list:
            result = Fork(base_fork_url=source_url, selector_list=args.selector.split(' ')).fork()
            args.handler(source_url, args.selector, result)

    @staticmethod
    @poxy
    def sync_web_dataset(args: SyncWebDatasetArgs):
        if try_lock('sync_web_dataset' + args.lock_key):
            try:
                ForkManage(args.url, args.selector.split(" ") if args.selector is not None else []).fork(2, set(),
                                                                                                         args.handler)
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
            finally:
                un_lock('sync_web_dataset' + args.lock_key)

    @staticmethod
    def update_problem(args: UpdateProblemArgs):
        problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id)
@@ -281,6 +242,9 @@ class ListenerManagement:
            VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
                                                                       {'dataset_id': args.target_dataset_id})
        else:
            # delete the vector data
            ListenerManagement.delete_embedding_by_paragraph_ids(args.paragraph_id_list)
            # re-embed the data
            ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
                                                           embedding_model=args.target_embedding_model)

@@ -306,38 +270,10 @@ class ListenerManagement:
    def delete_embedding_by_dataset_id_list(source_ids: List[str]):
        VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)

    def run(self):
        # add vectors by problem id
        ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
        # add vectors by paragraph id
        ListenerManagement.embedding_by_paragraph_signal.connect(self.embedding_by_paragraph)
        # add vectors by dataset (knowledge base) id
        ListenerManagement.embedding_by_dataset_signal.connect(
            self.embedding_by_dataset)
        # add vectors by document id
        ListenerManagement.embedding_by_document_signal.connect(
            self.embedding_by_document)
        # delete vectors by document
        ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document)
        # delete vectors by document id list
        ListenerManagement.delete_embedding_by_document_list_signal.connect(self.delete_embedding_by_document_list)
        # delete vectors by dataset id
        ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset)
        # delete vectors by paragraph id
        ListenerManagement.delete_embedding_by_paragraph_signal.connect(
            self.delete_embedding_by_paragraph)
        # delete vectors by source id
        ListenerManagement.delete_embedding_by_source_signal.connect(self.delete_embedding_by_source)
        # disable a paragraph
        ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
        # enable paragraph vectors
        ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)

        # sync a web-site dataset
        ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
        # sync web-site documents
        ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
        # update problem vectors
        ListenerManagement.update_problem_signal.connect(self.update_problem)
        ListenerManagement.delete_embedding_by_source_ids_signal.connect(self.delete_embedding_by_source_ids)
        ListenerManagement.delete_embedding_by_dataset_id_list_signal.connect(self.delete_embedding_by_dataset_id_list)
    @staticmethod
    def hit_test(query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        return VectorStore.get_embedding_vector().hit_test(query_text, dataset_id, exclude_document_id_list, top_number,
                                                           similarity, search_mode, embedding)

@@ -13,9 +13,11 @@ from django.db.models import QuerySet
from django_apscheduler.jobstores import DjangoJobStore

from application.models.api_key_model import ApplicationPublicAccessClient
from common.lock.impl.file_lock import FileLock

scheduler = BackgroundScheduler()
scheduler.add_jobstore(DjangoJobStore(), "default")
lock = FileLock()


def client_access_num_reset_job():
@@ -25,9 +27,13 @@ def client_access_num_reset_job():


def run():
    scheduler.start()
    access_num_reset = scheduler.get_job(job_id='access_num_reset')
    if access_num_reset is not None:
        access_num_reset.remove()
    scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0',
                      id='access_num_reset')
    if lock.try_lock('client_access_num_reset_job', 30 * 30):
        try:
            scheduler.start()
            access_num_reset = scheduler.get_job(job_id='access_num_reset')
            if access_num_reset is not None:
                access_num_reset.remove()
            scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0',
                              id='access_num_reset')
        finally:
            lock.un_lock('client_access_num_reset_job')

@@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: base_lock.py
@date:2024/8/20 10:33
@desc:
"""

from abc import ABC, abstractmethod


class BaseLock(ABC):
    @abstractmethod
    def try_lock(self, key, timeout):
        pass

    @abstractmethod
    def un_lock(self, key):
        pass

@@ -0,0 +1,77 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: file_lock.py
@date:2024/8/20 10:48
@desc:
"""
import errno
import hashlib
import os
import time

import six

from common.lock.base_lock import BaseLock
from smartdoc.const import PROJECT_DIR


def key_to_lock_name(key):
    """
    Combine part of a key with its hash to prevent very long filenames
    """
    MAX_LENGTH = 50
    key_hash = hashlib.md5(six.b(key)).hexdigest()
    lock_name = key[:MAX_LENGTH - len(key_hash) - 1] + '_' + key_hash
    return lock_name


class FileLock(BaseLock):
    """
    File locking backend.
    """

    def __init__(self, settings=None):
        if settings is None:
            settings = {}
        self.location = settings.get('location')
        if self.location is None:
            self.location = os.path.join(PROJECT_DIR, 'data', 'lock')
        try:
            os.makedirs(self.location)
        except OSError as error:
            # Directory exists?
            if error.errno != errno.EEXIST:
                # Re-raise unexpected OSError
                raise

    def _get_lock_path(self, key):
        lock_name = key_to_lock_name(key)
        return os.path.join(self.location, lock_name)

    def try_lock(self, key, timeout):
        lock_path = self._get_lock_path(key)
        try:
            # create the lock file; if it cannot be created, the lock is not acquired
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
        except OSError as error:
            if error.errno == errno.EEXIST:
                # File already exists, check its modification time
                mtime = os.path.getmtime(lock_path)
                ttl = mtime + timeout - time.time()
                if ttl > 0:
                    return False
                else:
                    # the previous holder timed out: take over the lock and continue
                    os.utime(lock_path, None)
                    return True
            else:
                return False
        else:
            os.close(fd)
            return True

    def un_lock(self, key):
        lock_path = self._get_lock_path(key)
        os.remove(lock_path)

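A short usage sketch for the FileLock backend above (the key name and job body are illustrative): try_lock creates the lock file atomically with os.O_CREAT | os.O_EXCL, and a file older than the timeout is re-claimed instead of blocking forever.

def run_nightly_job():
    pass  # placeholder for the guarded work


lock = FileLock()
if lock.try_lock('nightly_job', 60 * 30):
    try:
        run_nightly_job()  # runs at most once across processes on this host
    finally:
        lock.un_lock('nightly_job')
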
@@ -0,0 +1,44 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: celery.py
@date:2024/8/19 11:57
@desc:
"""
import os
import subprocess

from django.core.management.base import BaseCommand

from smartdoc.const import BASE_DIR


class Command(BaseCommand):
    help = 'celery'

    def add_arguments(self, parser):
        parser.add_argument(
            'service', nargs='+', type=str, choices=("celery", "model"), help='Service',
        )

    def handle(self, *args, **options):
        service = options.get('service')
        os.environ.setdefault('CELERY_NAME', ','.join(service))
        server_hostname = os.environ.get("SERVER_HOSTNAME")
        if not server_hostname:
            server_hostname = '%h'
        cmd = [
            'celery',
            '-A', 'ops',
            'worker',
            '-P', 'threads',
            '-l', 'info',
            '-c', '10',
            '-Q', ','.join(service),
            '--heartbeat-interval', '10',
            '-n', f'{",".join(service)}@{server_hostname}',
            '--without-mingle',
        ]
        kwargs = {'cwd': BASE_DIR}
        subprocess.run(cmd, **kwargs)

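Given the choices declared in add_arguments, the command above would be invoked through Django's entry point, one worker per queue (a sketch of expected usage, not output from the repository):

# python manage.py celery celery        -> worker consuming the default 'celery' queue
# python manage.py celery model         -> worker consuming the 'model' queue
# python manage.py celery celery model  -> one worker consuming both queues
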
@@ -1,47 +0,0 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: gunicorn.py
@date:2024/7/19 17:43
@desc:
"""
import subprocess

from django.core.management.base import BaseCommand

from smartdoc.const import BASE_DIR


class Command(BaseCommand):
    help = 'My custom command'

    # argument definitions
    def add_arguments(self, parser):
        parser.add_argument('-b', nargs='+', type=str, help="端口:0.0.0.0:8080")  # bind address, e.g. 0.0.0.0:8080
        parser.add_argument('-k', nargs='?', type=str,
                            help="workers处理器:gevent")  # worker class, e.g. uvicorn.workers.UvicornWorker
        parser.add_argument('-w', type=str, help='worker 数量')  # number of worker processes
        parser.add_argument('--threads', type=str, help='线程数量')  # number of threads
        parser.add_argument('--worker-connections', type=str, help="每个线程的协程数量")  # connections per worker, e.g. 10240
        parser.add_argument('--max-requests', type=str, help="最大请求")  # max requests, e.g. 10240
        parser.add_argument('--max-requests-jitter', type=str)
        parser.add_argument('--access-logformat', type=str)  # %(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s

    def handle(self, *args, **options):
        log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
        cmd = [
            'gunicorn', 'smartdoc.wsgi:application',
            '-b', options.get('b') if options.get('b') is not None else '0.0.0.0:8080',
            '-k', options.get('k') if options.get('k') is not None else 'gthread',
            '--threads', options.get('threads') if options.get('threads') is not None else '200',
            '-w', options.get('w') if options.get('w') is not None else '1',
            '--max-requests', options.get('max_requests') if options.get('max_requests') is not None else '10240',
            '--max-requests-jitter',
            options.get('max_requests_jitter') if options.get('max_requests_jitter') is not None else '2048',
            '--access-logformat',
            options.get('access_logformat') if options.get('access_logformat') is not None else log_format,
            '--access-logfile', '-'
        ]
        kwargs = {'cwd': BASE_DIR}
        subprocess.run(cmd, **kwargs)

@@ -0,0 +1,6 @@
from .services.command import BaseActionCommand, Action


class Command(BaseActionCommand):
    help = 'Restart services'
    action = Action.restart.value

@@ -0,0 +1,130 @@
from django.core.management.base import BaseCommand
from django.db.models import TextChoices

from .hands import *
from .utils import ServicesUtil


class Services(TextChoices):
    gunicorn = 'gunicorn', 'gunicorn'
    celery_default = 'celery_default', 'celery_default'
    local_model = 'local_model', 'local_model'
    web = 'web', 'web'
    celery = 'celery', 'celery'
    celery_model = 'celery_model', 'celery_model'
    task = 'task', 'task'
    all = 'all', 'all'

    @classmethod
    def get_service_object_class(cls, name):
        from . import services
        services_map = {
            cls.gunicorn.value: services.GunicornService,
            cls.celery_default: services.CeleryDefaultService,
            cls.local_model: services.GunicornLocalModelService
        }
        return services_map.get(name)

    @classmethod
    def web_services(cls):
        return [cls.gunicorn, cls.local_model]

    @classmethod
    def celery_services(cls):
        return [cls.celery_default, cls.celery_model]

    @classmethod
    def task_services(cls):
        return cls.celery_services()

    @classmethod
    def all_services(cls):
        return cls.web_services() + cls.task_services()

    @classmethod
    def export_services_values(cls):
        return [cls.all.value, cls.web.value, cls.task.value] + [s.value for s in cls.all_services()]

    @classmethod
    def get_service_objects(cls, service_names, **kwargs):
        services = set()
        for name in service_names:
            method_name = f'{name}_services'
            if hasattr(cls, method_name):
                _services = getattr(cls, method_name)()
            elif hasattr(cls, name):
                _services = [getattr(cls, name)]
            else:
                continue
            services.update(set(_services))

        service_objects = []
        for s in services:
            service_class = cls.get_service_object_class(s.value)
            if not service_class:
                continue
            kwargs.update({
                'name': s.value
            })
            service_object = service_class(**kwargs)
            service_objects.append(service_object)
        return service_objects


class Action(TextChoices):
    start = 'start', 'start'
    status = 'status', 'status'
    stop = 'stop', 'stop'
    restart = 'restart', 'restart'


class BaseActionCommand(BaseCommand):
    help = 'Service Base Command'

    action = None
    util = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def add_arguments(self, parser):
        parser.add_argument(
            'services', nargs='+', choices=Services.export_services_values(), help='Service',
        )
        parser.add_argument('-d', '--daemon', nargs="?", const=True)
        parser.add_argument('-w', '--worker', type=int, nargs="?", default=4)
        parser.add_argument('-f', '--force', nargs="?", const=True)

    def initial_util(self, *args, **options):
        service_names = options.get('services')
        service_kwargs = {
            'worker_gunicorn': options.get('worker')
        }
        services = Services.get_service_objects(service_names=service_names, **service_kwargs)

        kwargs = {
            'services': services,
            'run_daemon': options.get('daemon', False),
            'stop_daemon': self.action == Action.stop.value and Services.all.value in service_names,
            'force_stop': options.get('force') or False,
        }
        self.util = ServicesUtil(**kwargs)

    def handle(self, *args, **options):
        self.initial_util(*args, **options)
        assert self.action in Action.values, f'The action {self.action} is not in the optional list'
        _handle = getattr(self, f'_handle_{self.action}', lambda: None)
        _handle()

    def _handle_start(self):
        self.util.start_and_watch()
        os._exit(0)

    def _handle_stop(self):
        self.util.stop()

    def _handle_restart(self):
        self.util.restart()

    def _handle_status(self):
        self.util.show_status()

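Combined with the start/stop/restart/status commands added elsewhere in this commit, BaseActionCommand gives a service-manager style CLI. An invocation sketch based on the arguments defined above (flags illustrative):

# python manage.py start all -d    -> start web and task services, then daemonize
# python manage.py status web      -> print gunicorn / local_model status
# python manage.py stop all -f     -> force stop (SIGKILL instead of SIGTERM)
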
@@ -0,0 +1,28 @@
import logging
import os
import sys

from django.conf import settings

from smartdoc.const import CONFIG, PROJECT_DIR

try:
    from apps.smartdoc import const

    __version__ = const.VERSION
except ImportError as e:
    print("Not found __version__: {}".format(e))
    print("Python is: ")
    logging.info(sys.executable)
    __version__ = 'Unknown'
    sys.exit(1)

HTTP_HOST = CONFIG.HTTP_BIND_HOST or '127.0.0.1'
HTTP_PORT = CONFIG.HTTP_LISTEN_PORT or 8080
DEBUG = CONFIG.DEBUG or False

LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'logs')
APPS_DIR = os.path.join(PROJECT_DIR, 'apps')
TMP_DIR = os.path.join(PROJECT_DIR, 'tmp')
if not os.path.exists(TMP_DIR):
    os.makedirs(TMP_DIR)

@@ -0,0 +1,3 @@
from .celery_default import *
from .gunicorn import *
from .local_model import *

@@ -0,0 +1,207 @@
import abc
import time
import shutil
import psutil
import datetime
import threading
import subprocess
from ..hands import *


class BaseService(object):

    def __init__(self, **kwargs):
        self.name = kwargs['name']
        self._process = None
        self.STOP_TIMEOUT = 10
        self.max_retry = 0
        self.retry = 3
        self.LOG_KEEP_DAYS = 7
        self.EXIT_EVENT = threading.Event()

    @property
    @abc.abstractmethod
    def cmd(self):
        return []

    @property
    @abc.abstractmethod
    def cwd(self):
        return ''

    @property
    def is_running(self):
        if self.pid == 0:
            return False
        try:
            os.kill(self.pid, 0)
        except (OSError, ProcessLookupError):
            return False
        else:
            return True

    def show_status(self):
        if self.is_running:
            msg = f'{self.name} is running: {self.pid}.'
        else:
            msg = f'{self.name} is stopped.'
            if DEBUG:
                msg = '\033[31m{} is stopped.\033[0m\nYou can manual start it to find the error: \n' \
                      '  $ cd {}\n' \
                      '  $ {}'.format(self.name, self.cwd, ' '.join(self.cmd))

        print(msg)

    # -- log --
    @property
    def log_filename(self):
        return f'{self.name}.log'

    @property
    def log_filepath(self):
        return os.path.join(LOG_DIR, self.log_filename)

    @property
    def log_file(self):
        return open(self.log_filepath, 'a')

    @property
    def log_dir(self):
        return os.path.dirname(self.log_filepath)
    # -- end log --

    # -- pid --
    @property
    def pid_filepath(self):
        return os.path.join(TMP_DIR, f'{self.name}.pid')

    @property
    def pid(self):
        if not os.path.isfile(self.pid_filepath):
            return 0
        with open(self.pid_filepath) as f:
            try:
                pid = int(f.read().strip())
            except ValueError:
                pid = 0
        return pid

    def write_pid(self):
        with open(self.pid_filepath, 'w') as f:
            f.write(str(self.process.pid))

    def remove_pid(self):
        if os.path.isfile(self.pid_filepath):
            os.unlink(self.pid_filepath)
    # -- end pid --

    # -- process --
    @property
    def process(self):
        if not self._process:
            try:
                self._process = psutil.Process(self.pid)
            except:
                pass
        return self._process

    # -- end process --

    # -- action --
    def open_subprocess(self):
        kwargs = {'cwd': self.cwd, 'stderr': self.log_file, 'stdout': self.log_file}
        self._process = subprocess.Popen(self.cmd, **kwargs)

    def start(self):
        if self.is_running:
            self.show_status()
            return
        self.remove_pid()
        self.open_subprocess()
        self.write_pid()
        self.start_other()

    def start_other(self):
        pass

    def stop(self, force=False):
        if not self.is_running:
            self.show_status()
            # self.remove_pid()
            return

        print(f'Stop service: {self.name}', end='')
        sig = 9 if force else 15
        os.kill(self.pid, sig)

        if self.process is None:
            print("\033[31m No process found\033[0m")
            return
        try:
            self.process.wait(1)
        except:
            pass

        for i in range(self.STOP_TIMEOUT):
            if i == self.STOP_TIMEOUT - 1:
                print("\033[31m Error\033[0m")
            if not self.is_running:
                print("\033[32m Ok\033[0m")
                self.remove_pid()
                break
            else:
                continue

    def watch(self):
        self._check()
        if not self.is_running:
            self._restart()
        self._rotate_log()

    def _check(self):
        now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(f"{now} Check service status: {self.name} -> ", end='')
        if self.process:
            try:
                self.process.wait(1)  # without wait(), the child process may never be reaped
            except:
                pass

        if self.is_running:
            print(f'running at {self.pid}')
        else:
            print(f'stopped at {self.pid}')

    def _restart(self):
        if self.retry > self.max_retry:
            logging.info("Service start failed, exit: {}".format(self.name))
            self.EXIT_EVENT.set()
            return
        self.retry += 1
        logging.info(f'> Find {self.name} stopped, retry {self.retry}, {self.pid}')
        self.start()

    def _rotate_log(self):
        now = datetime.datetime.now()
        _time = now.strftime('%H:%M')
        if _time != '23:59':
            return

        backup_date = now.strftime('%Y-%m-%d')
        backup_log_dir = os.path.join(self.log_dir, backup_date)
        if not os.path.exists(backup_log_dir):
            os.mkdir(backup_log_dir)

        backup_log_path = os.path.join(backup_log_dir, self.log_filename)
        if os.path.isfile(self.log_filepath) and not os.path.isfile(backup_log_path):
            logging.info(f'Rotate log file: {self.log_filepath} => {backup_log_path}')
            shutil.copy(self.log_filepath, backup_log_path)
            with open(self.log_filepath, 'w') as f:
                pass

        to_delete_date = now - datetime.timedelta(days=self.LOG_KEEP_DAYS)
        to_delete_dir = os.path.join(LOG_DIR, to_delete_date.strftime('%Y-%m-%d'))
        if os.path.exists(to_delete_dir):
            logging.info(f'Remove old log: {to_delete_dir}')
            shutil.rmtree(to_delete_dir, ignore_errors=True)
    # -- end action --

@@ -0,0 +1,43 @@
from .base import BaseService
from ..hands import *


class CeleryBaseService(BaseService):

    def __init__(self, queue, num=10, **kwargs):
        super().__init__(**kwargs)
        self.queue = queue
        self.num = num

    @property
    def cmd(self):
        print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize()))

        os.environ.setdefault('LC_ALL', 'C.UTF-8')
        os.environ.setdefault('PYTHONOPTIMIZE', '1')
        os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True')
        os.environ.setdefault('PYTHONPATH', settings.APPS_DIR)

        if os.getuid() == 0:
            os.environ.setdefault('C_FORCE_ROOT', '1')
        server_hostname = os.environ.get("SERVER_HOSTNAME")
        if not server_hostname:
            server_hostname = '%h'

        cmd = [
            'celery',
            '-A', 'ops',
            'worker',
            '-P', 'threads',
            '-l', 'error',
            '-c', str(self.num),
            '-Q', self.queue,
            '--heartbeat-interval', '10',
            '-n', f'{self.queue}@{server_hostname}',
            '--without-mingle',
        ]
        return cmd

    @property
    def cwd(self):
        return APPS_DIR

@@ -0,0 +1,10 @@
from .celery_base import CeleryBaseService

__all__ = ['CeleryDefaultService']


class CeleryDefaultService(CeleryBaseService):

    def __init__(self, **kwargs):
        kwargs['queue'] = 'celery'
        super().__init__(**kwargs)

@@ -0,0 +1,36 @@
from .base import BaseService
from ..hands import *

__all__ = ['GunicornService']


class GunicornService(BaseService):

    def __init__(self, **kwargs):
        self.worker = kwargs['worker_gunicorn']
        super().__init__(**kwargs)

    @property
    def cmd(self):
        print("\n- Start Gunicorn WSGI HTTP Server")

        log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
        bind = f'{HTTP_HOST}:{HTTP_PORT}'
        cmd = [
            'gunicorn', 'smartdoc.wsgi:application',
            '-b', bind,
            '-k', 'gthread',
            '--threads', '200',
            '-w', str(self.worker),
            '--max-requests', '10240',
            '--max-requests-jitter', '2048',
            '--access-logformat', log_format,
            '--access-logfile', '-'
        ]
        if DEBUG:
            cmd.append('--reload')
        return cmd

    @property
    def cwd(self):
        return APPS_DIR

@@ -0,0 +1,44 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: local_model.py
@date:2024/8/21 13:28
@desc:
"""
from .base import BaseService
from ..hands import *

__all__ = ['GunicornLocalModelService']


class GunicornLocalModelService(BaseService):

    def __init__(self, **kwargs):
        self.worker = kwargs['worker_gunicorn']
        super().__init__(**kwargs)

    @property
    def cmd(self):
        print("\n- Start Gunicorn Local Model WSGI HTTP Server")
        os.environ.setdefault('SERVER_NAME', 'local_model')
        log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
        bind = f'127.0.0.1:5432'
        cmd = [
            'gunicorn', 'smartdoc.wsgi:application',
            '-b', bind,
            '-k', 'gthread',
            '--threads', '200',
            '-w', "1",
            '--max-requests', '10240',
            '--max-requests-jitter', '2048',
            '--access-logformat', log_format,
            '--access-logfile', '-'
        ]
        if DEBUG:
            cmd.append('--reload')
        return cmd

    @property
    def cwd(self):
        return APPS_DIR

@@ -0,0 +1,140 @@
import threading
import signal
import time
import daemon
from daemon import pidfile
from .hands import *
from .hands import __version__
from .services.base import BaseService


class ServicesUtil(object):

    def __init__(self, services, run_daemon=False, force_stop=False, stop_daemon=False):
        self._services = services
        self.run_daemon = run_daemon
        self.force_stop = force_stop
        self.stop_daemon = stop_daemon
        self.EXIT_EVENT = threading.Event()
        self.check_interval = 30
        self.files_preserve_map = {}

    def restart(self):
        self.stop()
        time.sleep(5)
        self.start_and_watch()

    def start_and_watch(self):
        logging.info(time.ctime())
        logging.info(f'MaxKB version {__version__}, more see https://www.jumpserver.org')
        self.start()
        if self.run_daemon:
            self.show_status()
            with self.daemon_context:
                self.watch()
        else:
            self.watch()

    def start(self):
        for service in self._services:
            service: BaseService
            service.start()
            self.files_preserve_map[service.name] = service.log_file

            time.sleep(1)

    def stop(self):
        for service in self._services:
            service: BaseService
            service.stop(force=self.force_stop)

        if self.stop_daemon:
            self._stop_daemon()

    # -- watch --
    def watch(self):
        while not self.EXIT_EVENT.is_set():
            try:
                _exit = self._watch()
                if _exit:
                    break
                time.sleep(self.check_interval)
            except KeyboardInterrupt:
                print('Start stop services')
                break
        self.clean_up()

    def _watch(self):
        for service in self._services:
            service: BaseService
            service.watch()
            if service.EXIT_EVENT.is_set():
                self.EXIT_EVENT.set()
                return True
        return False
    # -- end watch --

    def clean_up(self):
        if not self.EXIT_EVENT.is_set():
            self.EXIT_EVENT.set()
        self.stop()

    def show_status(self):
        for service in self._services:
            service: BaseService
            service.show_status()

    # -- daemon --
    def _stop_daemon(self):
        if self.daemon_pid and self.daemon_is_running:
            os.kill(self.daemon_pid, 15)
        self.remove_daemon_pid()

    def remove_daemon_pid(self):
        if os.path.isfile(self.daemon_pid_filepath):
            os.unlink(self.daemon_pid_filepath)

    @property
    def daemon_pid(self):
        if not os.path.isfile(self.daemon_pid_filepath):
            return 0
        with open(self.daemon_pid_filepath) as f:
            try:
                pid = int(f.read().strip())
            except ValueError:
                pid = 0
        return pid

    @property
    def daemon_is_running(self):
        try:
            os.kill(self.daemon_pid, 0)
        except (OSError, ProcessLookupError):
            return False
        else:
            return True

    @property
    def daemon_pid_filepath(self):
        return os.path.join(TMP_DIR, 'mk.pid')

    @property
    def daemon_log_filepath(self):
        return os.path.join(LOG_DIR, 'mk.log')

    @property
    def daemon_context(self):
        daemon_log_file = open(self.daemon_log_filepath, 'a')
        context = daemon.DaemonContext(
            pidfile=pidfile.TimeoutPIDLockFile(self.daemon_pid_filepath),
            signal_map={
                signal.SIGTERM: lambda x, y: self.clean_up(),
                signal.SIGHUP: 'terminate',
            },
            stdout=daemon_log_file,
            stderr=daemon_log_file,
            files_preserve=list(self.files_preserve_map.values()),
            detach_process=True,
        )
        return context
    # -- end daemon --

@@ -0,0 +1,6 @@
from .services.command import BaseActionCommand, Action


class Command(BaseActionCommand):
    help = 'Start services'
    action = Action.start.value

@@ -0,0 +1,6 @@
from .services.command import BaseActionCommand, Action


class Command(BaseActionCommand):
    help = 'Show services status'
    action = Action.status.value

@@ -0,0 +1,6 @@
from .services.command import BaseActionCommand, Action


class Command(BaseActionCommand):
    help = 'Stop services'
    action = Action.stop.value

@@ -18,6 +18,7 @@ TOKEN_KEY = 'solomon_world_token'
TOKEN_SALT = 'solomonwanc@gmail.com'
TIME_OUT = 30 * 60


# encrypt
def encrypt(obj):
    value = signing.dumps(obj, key=TOKEN_KEY, salt=TOKEN_SALT)
@@ -29,7 +30,6 @@ def encrypt(obj):
def decrypt(src):
    src = signing.b64_decode(src.encode()).decode()
    raw = signing.loads(src, key=TOKEN_KEY, salt=TOKEN_SALT)
    print(type(raw))
    return raw


@@ -74,5 +74,3 @@ def check_token(token):
    if last_token:
        return last_token == token
    return False

@@ -151,3 +151,17 @@ def get_embedding_model_by_dataset_id(dataset_id: str):

def get_embedding_model_by_dataset(dataset):
    return ModelManage.get_model(str(dataset.embedding_mode_id), lambda _id: get_model(dataset.embedding_mode))


def get_embedding_model_id_by_dataset_id(dataset_id):
    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
    return str(dataset.embedding_mode_id)


def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List):
    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
        raise Exception("知识库未向量模型不一致")  # "the datasets do not share one embedding model"
    if len(dataset_list) == 0:
        raise Exception("知识库设置错误,请重新设置知识库")  # "dataset misconfigured, please reconfigure it"
    return str(dataset_list[0].embedding_mode_id)
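These helpers return the embedding model id as a plain string because the serializers in this commit now enqueue work with celery's .delay(...), whose arguments must survive broker serialization; a live model object, as the old signal handlers received, would not. A sketch of the receiving side under that assumption (the decorator and resolver names are illustrative, not the actual embedding.task definitions):

from celery import shared_task


@shared_task
def embedding_by_dataset(dataset_id, embedding_model_id):
    # Rebuild the heavyweight embedding model inside the worker process
    # from the id that crossed the broker.
    model = load_model_by_id(embedding_model_id)  # hypothetical resolver
    ...  # embed every paragraph of the dataset with `model`
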
@@ -27,7 +27,6 @@ from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post, flat_map, valid_license
@@ -37,9 +36,11 @@ from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
    get_embedding_model_by_dataset_id
    get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from dataset.task import sync_web_dataset
from embedding.models import SearchMode
from embedding.task import embedding_by_dataset, delete_embedding_by_dataset
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR

@@ -363,9 +364,9 @@ class DataSetSerializers(serializers.ModelSerializer):

    @staticmethod
    def post_embedding_dataset(document_list, dataset_id):
        model = get_embedding_model_by_dataset_id(dataset_id)
        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        # send the embedding event
        ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
        embedding_by_dataset.delay(dataset_id, model_id)
        return document_list

    def save_qa(self, instance: Dict, with_valid=True):
@@ -435,23 +436,6 @@ class DataSetSerializers(serializers.ModelSerializer):
        else:
            return parsed_url.path.split("/")[-1]

    @staticmethod
    def get_save_handler(dataset_id, selector):
        def handler(child_link: ChildLink, response: Fork.Response):
            if response.status == 200:
                try:
                    document_name = child_link.tag.text if child_link.tag is not None and len(
                        child_link.tag.text.strip()) > 0 else child_link.url
                    paragraphs = get_split_model('web.md').parse(response.content)
                    DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                        {'name': document_name, 'paragraphs': paragraphs,
                         'meta': {'source_url': child_link.url, 'selector': selector},
                         'type': Type.web}, with_valid=True)
                except Exception as e:
                    logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

        return handler

    def save_web(self, instance: Dict, with_valid=True):
        if with_valid:
            self.is_valid(raise_exception=True)
@@ -467,9 +451,7 @@ class DataSetSerializers(serializers.ModelSerializer):
                          'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector'),
                                   'embedding_mode_id': instance.get('embedding_mode_id')}})
        dataset.save()
        ListenerManagement.sync_web_dataset_signal.send(
            SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
                               self.get_save_handler(dataset_id, instance.get('selector'))))
        sync_web_dataset.delay((str(dataset_id), instance.get('source_url'), instance.get('selector')))
        return {**DataSetSerializers(dataset).data,
                'document_list': []}

@@ -642,9 +624,7 @@ class DataSetSerializers(serializers.ModelSerializer):
        """
        url = dataset.meta.get('source_url')
        selector = dataset.meta.get('selector') if 'selector' in dataset.meta else None
        ListenerManagement.sync_web_dataset_signal.send(
            SyncWebDatasetArgs(str(dataset.id), url, selector,
                               self.get_sync_handler(dataset)))
        sync_web_dataset.delay(str(dataset.id), url, selector)

    def complete_sync(self, dataset):
        """
@@ -658,7 +638,7 @@ class DataSetSerializers(serializers.ModelSerializer):
        # delete the paragraphs
        QuerySet(Paragraph).filter(dataset=dataset).delete()
        # delete the vectors
        ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
        delete_embedding_by_dataset(self.data.get('id'))
        # re-sync
        self.replace_sync(dataset)

@@ -740,16 +720,17 @@ class DataSetSerializers(serializers.ModelSerializer):
        QuerySet(Paragraph).filter(dataset=dataset).delete()
        QuerySet(Problem).filter(dataset=dataset).delete()
        dataset.delete()
        ListenerManagement.delete_embedding_by_dataset_signal.send(self.data.get('id'))
        delete_embedding_by_dataset(self.data.get('id'))
        return True

    def re_embedding(self, with_valid=True):
        if with_valid:
            self.is_valid(raise_exception=True)
        model = get_embedding_model_by_dataset_id(self.data.get('id'))

        QuerySet(Document).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up})
        QuerySet(Paragraph).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up})
        ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model)
        embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id'))
        embedding_by_dataset.delay(self.data.get('id'), embedding_model_id)

    def list_application(self, with_valid=True):
        if with_valid:

@@ -15,6 +15,7 @@ from functools import reduce
from typing import List, Dict

import xlwt
from celery_once import AlreadyQueued
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
@@ -25,7 +26,6 @@ from xlwt import Utils

from common.db.search import native_search, native_page_search
from common.event.common import work_thread_pool
from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs, UpdateEmbeddingDatasetIdArgs
from common.exception.app_exception import AppApiException
from common.handle.impl.doc_split_handle import DocSplitHandle
from common.handle.impl.html_split_handle import HTMLSplitHandle
@@ -42,8 +42,11 @@ from common.util.fork import Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
    get_embedding_model_by_dataset_id
    get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from dataset.task import sync_web_document
from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
    delete_embedding_by_document, update_embedding_dataset_id
from smartdoc.conf import PROJECT_DIR

parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]
@@ -235,17 +238,15 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                                meta={})
        else:
            document_list.update(dataset_id=target_dataset_id)
        model = None
        model_id = None
        if dataset.embedding_mode_id != target_dataset.embedding_mode_id:
            model = get_embedding_model_by_dataset_id(target_dataset_id)
            model_id = get_embedding_model_by_dataset_id(target_dataset_id)

        pid_list = [paragraph.id for paragraph in paragraph_list]
        # update the paragraph rows
        paragraph_list.update(dataset_id=target_dataset_id)
        # update the vector rows
        ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
            pid_list,
            target_dataset_id, model))
        update_embedding_dataset_id(pid_list, target_dataset_id, model_id)

    @staticmethod
    def get_target_dataset_problem(target_dataset_id: str,
@@ -377,7 +378,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
            # delete the problems
            QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
            # delete from the vector store
            ListenerManagement.delete_embedding_by_document_signal.send(document_id)
            delete_embedding_by_document(document_id)
            paragraphs = get_split_model('web.md').parse(result.content)
            document.char_length = reduce(lambda x, y: x + y,
                                          [len(p.get('content')) for p in paragraphs],
@@ -398,8 +399,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                problem_paragraph_mapping_list) > 0 else None
            # embed
            if with_embedding:
                model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id)
                ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
                embedding_model_id = get_embedding_model_id_by_dataset_id(document.dataset_id)
                embedding_by_document.delay(document_id, embedding_model_id)
        else:
            document.status = Status.error
            document.save()
@@ -538,10 +539,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
        if with_valid:
            self.is_valid(raise_exception=True)
        document_id = self.data.get("document_id")
        model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id'))
        QuerySet(Document).filter(id=document_id).update(**{'status': Status.queue_up})
        QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.queue_up})
        ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
        embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
        try:
            embedding_by_document.delay(document_id, embedding_model_id)
        except AlreadyQueued as e:
            raise AppApiException(500, "任务正在执行中,请勿重复下发")  # "the task is already running, do not resubmit it"

    @transaction.atomic
    def delete(self):
@@ -552,7 +556,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
        # delete the problems
        QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete()
        # delete from the vector store
        ListenerManagement.delete_embedding_by_document_signal.send(document_id)
        delete_embedding_by_document(document_id)
        return True

    @staticmethod
@@ -611,8 +615,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):

    @staticmethod
    def post_embedding(result, document_id, dataset_id):
        model = get_embedding_model_by_dataset_id(dataset_id)
        ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        embedding_by_document.delay(document_id, model_id)
        return result

    @staticmethod
@@ -660,29 +664,6 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
            data={'dataset_id': dataset_id, 'document_id': document_id}).one(
            with_valid=True), document_id, dataset_id

    @staticmethod
    def get_sync_handler(dataset_id):
        def handler(source_url: str, selector, response: Fork.Response):
            if response.status == 200:
                try:
                    paragraphs = get_split_model('web.md').parse(response.content)
                    # insert
                    DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                        {'name': source_url[0:128], 'paragraphs': paragraphs,
                         'meta': {'source_url': source_url, 'selector': selector},
                         'type': Type.web}, with_valid=True)
                except Exception as e:
                    logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
            else:
                Document(name=source_url[0:128],
                         dataset_id=dataset_id,
                         meta={'source_url': source_url, 'selector': selector},
                         type=Type.web,
                         char_length=0,
                         status=Status.error).save()

        return handler

    def save_web(self, instance: Dict, with_valid=True):
        if with_valid:
            DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
@@ -690,8 +671,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
        dataset_id = self.data.get('dataset_id')
        source_url_list = instance.get('source_url_list')
        selector = instance.get('selector')
        args = SyncWebDocumentArgs(source_url_list, selector, self.get_sync_handler(dataset_id))
        ListenerManagement.sync_web_document_signal.send(args)
        sync_web_document.delay(dataset_id, source_url_list, selector)

    @staticmethod
    def get_paragraph_model(document_model, paragraph_list: List):
@@ -818,8 +798,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
    @staticmethod
    def post_embedding(document_list, dataset_id):
        for document_dict in document_list:
            model = get_embedding_model_by_dataset_id(dataset_id)
            ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model)
            model_id = get_embedding_model_id_by_dataset_id(dataset_id)
            embedding_by_document(document_dict.get('id'), model_id)
        return document_list

    @post(post_function=post_embedding)
@@ -887,7 +867,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
        QuerySet(Paragraph).filter(document_id__in=document_id_list).delete()
        QuerySet(ProblemParagraphMapping).filter(document_id__in=document_id_list).delete()
        # delete from the vector store
        ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list)
        delete_embedding_by_document_list(document_id_list)
        return True

    def batch_edit_hit_handling(self, instance: Dict, with_valid=True):

@ -15,16 +15,18 @@ from drf_yasg import openapi
|
|||
from rest_framework import serializers
|
||||
|
||||
from common.db.search import page_search
|
||||
from common.event.listener_manage import ListenerManagement, UpdateEmbeddingDocumentIdArgs
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.common import post
|
||||
from common.util.field_message import ErrMessage
|
||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
|
||||
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
|
||||
ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset
|
||||
ProblemParagraphManage, get_embedding_model_id_by_dataset_id
|
||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
||||
from embedding.models import SourceType
|
||||
from embedding.task.embedding import embedding_by_problem as embedding_by_problem_task, embedding_by_problem, \
|
||||
delete_embedding_by_source, enable_embedding_by_paragraph, disable_embedding_by_paragraph, embedding_by_paragraph, \
|
||||
delete_embedding_by_paragraph, delete_embedding_by_paragraph_ids, update_embedding_document_id
|
||||
|
||||
|
||||
class ParagraphSerializer(serializers.ModelSerializer):
|
||||
|
|
@ -113,7 +115,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
QuerySet(Problem).filter(id__in=[row.problem_id for row in problem_paragraph_mapping])]
|
||||
|
||||
@transaction.atomic
|
||||
def save(self, instance: Dict, with_valid=True, with_embedding=True):
|
||||
def save(self, instance: Dict, with_valid=True, with_embedding=True, embedding_by_problem=None):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
||||
|
|
@@ -132,16 +134,16 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
        paragraph_id=self.data.get('paragraph_id'),
        dataset_id=self.data.get('dataset_id'))
problem_paragraph_mapping.save()
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
model_id = get_embedding_model_id_by_dataset_id(self.data.get('dataset_id'))
if with_embedding:
    ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
                                                         'is_active': True,
                                                         'source_type': SourceType.PROBLEM,
                                                         'source_id': problem_paragraph_mapping.id,
                                                         'document_id': self.data.get('document_id'),
                                                         'paragraph_id': self.data.get('paragraph_id'),
                                                         'dataset_id': self.data.get('dataset_id'),
                                                         }, embedding_model=model)
    embedding_by_problem_task({'text': problem.content,
                               'is_active': True,
                               'source_type': SourceType.PROBLEM,
                               'source_id': problem_paragraph_mapping.id,
                               'document_id': self.data.get('document_id'),
                               'paragraph_id': self.data.get('paragraph_id'),
                               'dataset_id': self.data.get('dataset_id'),
                               }, model_id)

return ProblemSerializers.Operate(
    data={'dataset_id': self.data.get('dataset_id'),

@@ -228,15 +230,15 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
        problem_id=problem.id)
problem_paragraph_mapping.save()
if with_embedding:
    model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
    ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
                                                         'is_active': True,
                                                         'source_type': SourceType.PROBLEM,
                                                         'source_id': problem_paragraph_mapping.id,
                                                         'document_id': self.data.get('document_id'),
                                                         'paragraph_id': self.data.get('paragraph_id'),
                                                         'dataset_id': self.data.get('dataset_id'),
                                                         }, embedding_model=model)
    model_id = get_embedding_model_id_by_dataset_id(self.data.get('dataset_id'))
    embedding_by_problem({'text': problem.content,
                          'is_active': True,
                          'source_type': SourceType.PROBLEM,
                          'source_id': problem_paragraph_mapping.id,
                          'document_id': self.data.get('document_id'),
                          'paragraph_id': self.data.get('paragraph_id'),
                          'dataset_id': self.data.get('dataset_id'),
                          }, model_id)

def un_association(self, with_valid=True):
    if with_valid:

@@ -248,7 +250,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
    'problem_id')).first()
problem_paragraph_mapping_id = problem_paragraph_mapping.id
problem_paragraph_mapping.delete()
ListenerManagement.delete_embedding_by_source_signal.send(problem_paragraph_mapping_id)
delete_embedding_by_source(problem_paragraph_mapping_id)
return True

@staticmethod

@@ -289,7 +291,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
QuerySet(ProblemParagraphMapping).filter(paragraph_id__in=paragraph_id_list).delete()
update_document_char_length(self.data.get('document_id'))
# delete the vectors
ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_id_list)
delete_embedding_by_paragraph_ids(paragraph_id_list)
return True

class Migrate(ApiMixin, serializers.Serializer):

@@ -338,11 +340,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# update the mappings
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
                                              ['document_id'])

# update the paragraph vectors
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
    [paragraph.id for paragraph in paragraph_list],
    target_document_id, target_dataset_id, target_embedding_model=None))
update_embedding_document_id([paragraph.id for paragraph in paragraph_list],
                             target_document_id, target_dataset_id, None)
# update the paragraph rows
paragraph_list.update(document_id=target_document_id)
# migration across datasets

@@ -371,16 +370,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
    ['problem_id', 'dataset_id', 'document_id'])
target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
embedding_model = None
embedding_model_id = None
if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
    embedding_model = get_embedding_model_by_dataset(target_dataset)
    embedding_model_id = str(target_dataset.embedding_mode_id)
pid_list = [paragraph.id for paragraph in paragraph_list]
# update the paragraph rows
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
# update the paragraph vectors
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
    pid_list,
    target_document_id, target_dataset_id, target_embedding_model=embedding_model))
update_embedding_document_id(pid_list, target_document_id, target_dataset_id, embedding_model_id)

update_document_char_length(document_id)
update_document_char_length(target_document_id)

@@ -466,12 +463,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
@staticmethod
def post_embedding(paragraph, instance, dataset_id):
    if 'is_active' in instance and instance.get('is_active') is not None:
        s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
            'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
        s.send(paragraph.get('id'))
        (enable_embedding_by_paragraph if instance.get(
            'is_active') else disable_embedding_by_paragraph)(paragraph.get('id'))

    else:
        model = get_embedding_model_by_dataset_id(dataset_id)
        ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model)
        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        embedding_by_paragraph(paragraph.get('id'), model_id)
    return paragraph

@post(post_embedding)

@@ -543,7 +540,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
QuerySet(Paragraph).filter(id=paragraph_id).delete()
QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
update_document_char_length(self.data.get('document_id'))
ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)
delete_embedding_by_paragraph(paragraph_id)

@staticmethod
def get_request_body_api():

@@ -593,8 +590,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# update the char length
update_document_char_length(document_id)
if with_embedding:
    model = get_embedding_model_by_dataset_id(dataset_id)
    ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model)
    model_id = get_embedding_model_id_by_dataset_id(dataset_id)
    embedding_by_paragraph(str(paragraph.id), model_id)
return ParagraphSerializers.Operate(
    data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
    with_valid=True)

@@ -16,12 +16,12 @@ from drf_yasg import openapi
from rest_framework import serializers

from common.db.search import native_search, native_page_search
from common.event import ListenerManagement, UpdateProblemArgs
from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding
from smartdoc.conf import PROJECT_DIR

@@ -111,7 +111,7 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
source_ids = [row.id for row in problem_paragraph_mapping_list]
problem_paragraph_mapping_list.delete()
QuerySet(Problem).filter(id__in=problem_id_list).delete()
ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids)
delete_embedding_by_source_ids(source_ids)
return True

class Operate(serializers.Serializer):

@@ -146,7 +146,7 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
source_ids = [row.id for row in problem_paragraph_mapping_list]
problem_paragraph_mapping_list.delete()
QuerySet(Problem).filter(id=self.data.get('problem_id')).delete()
ListenerManagement.delete_embedding_by_source_ids_signal.send(source_ids)
delete_embedding_by_source_ids(source_ids)
return True

@transaction.atomic

@@ -161,5 +161,5 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
QuerySet(DataSet).filter(id=dataset_id)
problem.content = content
problem.save()
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model))
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
update_problem_embedding(problem_id, content, model_id)

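Taken together, the serializer hunks above apply one mechanical change: each in-process blinker signal is replaced by a direct call into embedding.task, and the live embedding model object is replaced by a model id that the Celery worker resolves itself. A minimal before/after sketch of the call shape, distilled from the hunks above:

# before: load the model in the web process, then fire an in-process signal
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)

# after: pass only ids; the worker loads (and caches) the model itself
model_id = get_embedding_model_id_by_dataset_id(dataset_id)
embedding_by_paragraph(paragraph_id, model_id)
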
@@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py
@date:2024/8/21 9:57
@desc:
"""
from .sync import *

@@ -0,0 +1,42 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: sync.py
@date:2024/8/20 21:37
@desc:
"""

import logging
import traceback
from typing import List

from celery_once import QueueOnce

from common.util.fork import ForkManage, Fork
from dataset.task.tools import get_save_handler, get_sync_web_document_handler

from ops import celery_app

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")


@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:sync_web_dataset')
def sync_web_dataset(dataset_id: str, url: str, selector: str):
    try:
        max_kb.info(f"开始--->开始同步web知识库:{dataset_id}")
        ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(),
                                                                                  get_save_handler(dataset_id,
                                                                                                   selector))
        max_kb.info(f"结束--->结束同步web知识库:{dataset_id}")
    except Exception as e:
        max_kb_error.error(f'同步web知识库:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')


@celery_app.task(name='celery:sync_web_document')
def sync_web_document(dataset_id, source_url_list: List[str], selector: str):
    handler = get_sync_web_document_handler(dataset_id)
    for source_url in source_url_list:
        result = Fork(base_fork_url=source_url, selector_list=selector.split(' ')).fork()
        handler(source_url, selector, result)

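A sketch of how a caller might enqueue the sync task above; the dataset id, URL, and selector values are illustrative, not taken from this diff:

from dataset.task.sync import sync_web_dataset

# base=QueueOnce with once={'keys': ['dataset_id']} deduplicates on the dataset id:
# a second .delay() for the same dataset is rejected while one is still pending
sync_web_dataset.delay(dataset_id, 'https://example.com/docs', 'body')
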
@@ -0,0 +1,62 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: tools.py
@date:2024/8/20 21:48
@desc:
"""

import logging
import traceback

from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models import Type, Document, Status

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")


def get_save_handler(dataset_id, selector):
    from dataset.serializers.document_serializers import DocumentSerializers

    def handler(child_link: ChildLink, response: Fork.Response):
        if response.status == 200:
            try:
                document_name = child_link.tag.text if child_link.tag is not None and len(
                    child_link.tag.text.strip()) > 0 else child_link.url
                paragraphs = get_split_model('web.md').parse(response.content)
                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                    {'name': document_name, 'paragraphs': paragraphs,
                     'meta': {'source_url': child_link.url, 'selector': selector},
                     'type': Type.web}, with_valid=True)
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

    return handler


def get_sync_web_document_handler(dataset_id):
    from dataset.serializers.document_serializers import DocumentSerializers

    def handler(source_url: str, selector, response: Fork.Response):
        if response.status == 200:
            try:
                paragraphs = get_split_model('web.md').parse(response.content)
                # insert the document
                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                    {'name': source_url[0:128], 'paragraphs': paragraphs,
                     'meta': {'source_url': source_url, 'selector': selector},
                     'type': Type.web}, with_valid=True)
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
        else:
            Document(name=source_url[0:128],
                     dataset_id=dataset_id,
                     meta={'source_url': source_url, 'selector': selector},
                     type=Type.web,
                     char_length=0,
                     status=Status.error).save()

    return handler

@@ -0,0 +1 @@
from .embedding import *

@@ -0,0 +1,218 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: embedding.py
@date:2024/8/19 14:13
@desc:
"""

import logging
import traceback
from typing import List

from celery_once import QueueOnce
from django.db.models import QuerySet

from common.config.embedding_config import ModelManage
from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \
    UpdateEmbeddingDocumentIdArgs
from dataset.models import Document
from ops import celery_app
from setting.models import Model
from setting.models_provider import get_model

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")


def get_embedding_model(model_id):
    model = QuerySet(Model).filter(id=model_id).first()
    # note: the loader lambda closes over the Model row fetched above; its _id argument is unused
    embedding_model = ModelManage.get_model(model_id,
                                            lambda _id: get_model(model))
    return embedding_model


@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id']}, name='celery:embedding_by_paragraph')
def embedding_by_paragraph(paragraph_id, model_id):
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_paragraph(paragraph_id, embedding_model)


@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_data_list')
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, model_id):
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model)


@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_list')
def embedding_by_paragraph_list(paragraph_id_list, model_id):
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_paragraph_list(paragraph_id_list, embedding_model)


@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
def embedding_by_document(document_id, model_id):
    """
    Embed a document.
    @param document_id: document id
    @param model_id: embedding model id
    :return: None
    """
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_document(document_id, embedding_model)


@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:embedding_by_dataset')
def embedding_by_dataset(dataset_id, model_id):
    """
    Embed a dataset.
    @param dataset_id: dataset id
    @param model_id: embedding model id
    :return: None
    """
    max_kb.info(f"开始--->向量化数据集:{dataset_id}")
    try:
        ListenerManagement.delete_embedding_by_dataset(dataset_id)
        document_list = QuerySet(Document).filter(dataset_id=dataset_id)
        max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
        for document in document_list:
            try:
                embedding_by_document.delay(document.id, model_id)
            except Exception as e:
                # ignore a failed enqueue for a single document and continue with the rest
                pass
    except Exception as e:
        max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
    finally:
        max_kb.info(f"结束--->向量化数据集:{dataset_id}")


def embedding_by_problem(args, model_id):
    """
    Embed a problem.
    @param args: problem payload
    @param model_id: model id
    @return:
    """
    embedding_model = get_embedding_model(model_id)
    ListenerManagement.embedding_by_problem(args, embedding_model)


def delete_embedding_by_document(document_id):
    """
    Delete the vectors of a document.
    @param document_id: document id
    @return: None
    """
    ListenerManagement.delete_embedding_by_document(document_id)


def delete_embedding_by_document_list(document_id_list: List[str]):
    """
    Delete the vectors of a list of documents.
    @param document_id_list: document id list
    @return: None
    """
    ListenerManagement.delete_embedding_by_document_list(document_id_list)


def delete_embedding_by_dataset(dataset_id):
    """
    Delete the vectors of a dataset.
    @param dataset_id: dataset id
    @return: None
    """
    ListenerManagement.delete_embedding_by_dataset(dataset_id)


def delete_embedding_by_paragraph(paragraph_id):
    """
    Delete the vectors of a paragraph.
    @param paragraph_id: paragraph id
    @return: None
    """
    ListenerManagement.delete_embedding_by_paragraph(paragraph_id)


def delete_embedding_by_source(source_id):
    """
    Delete the vectors of a source.
    @param source_id: source id
    @return: None
    """
    ListenerManagement.delete_embedding_by_source(source_id)


def disable_embedding_by_paragraph(paragraph_id):
    """
    Disable the vectors of a paragraph.
    @param paragraph_id: paragraph id
    @return: None
    """
    ListenerManagement.disable_embedding_by_paragraph(paragraph_id)


def enable_embedding_by_paragraph(paragraph_id):
    """
    Enable the vectors of a paragraph.
    @param paragraph_id: paragraph id
    @return: None
    """
    ListenerManagement.enable_embedding_by_paragraph(paragraph_id)


def delete_embedding_by_source_ids(source_ids: List[str]):
    """
    Delete vectors by source id list.
    @param source_ids:
    @return:
    """
    ListenerManagement.delete_embedding_by_source_ids(source_ids)


def update_problem_embedding(problem_id: str, problem_content: str, model_id):
    """
    Re-embed an updated problem.
    @param problem_id:
    @param problem_content:
    @param model_id:
    @return:
    """
    model = get_embedding_model(model_id)
    ListenerManagement.update_problem(UpdateProblemArgs(problem_id, problem_content, model))


def update_embedding_dataset_id(paragraph_id_list, target_dataset_id, target_embedding_model_id=None):
    """
    Move vector data to the target dataset.
    @param paragraph_id_list: paragraphs whose vectors should move
    @param target_dataset_id: target dataset id
    @param target_embedding_model_id: target embedding model id
    @return:
    """
    target_embedding_model = get_embedding_model(
        target_embedding_model_id) if target_embedding_model_id is not None else None
    ListenerManagement.update_embedding_dataset_id(
        UpdateEmbeddingDatasetIdArgs(paragraph_id_list, target_dataset_id, target_embedding_model))


def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
    """
    Delete the vectors of a list of paragraphs.
    @param paragraph_ids: paragraph id list
    @return: None
    """
    ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_ids)


def update_embedding_document_id(paragraph_id_list, target_document_id, target_dataset_id,
                                 target_embedding_model_id=None):
    target_embedding_model = get_embedding_model(
        target_embedding_model_id) if target_embedding_model_id is not None else None
    ListenerManagement.update_embedding_document_id(
        UpdateEmbeddingDocumentIdArgs(paragraph_id_list, target_document_id, target_dataset_id, target_embedding_model))


def delete_embedding_by_dataset_id_list(dataset_id_list):
    ListenerManagement.delete_embedding_by_dataset_id_list(dataset_id_list)

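Note that embedding_by_dataset does not embed documents inline: it clears the dataset's vectors, then fans out one embedding_by_document subtask per document, so a large dataset is vectorized document-by-document on the worker pool. A sketch of dispatching it from application code (variable names illustrative):

from embedding.task import embedding_by_dataset

# QueueOnce keyed on 'dataset_id' prevents two concurrent vectorizations
# of the same dataset; each document then becomes its own worker task
embedding_by_dataset.delay(dataset_id, model_id)
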
@@ -127,7 +127,7 @@ class BaseVectorStore(ABC):
    return []
embedding_query = embedding.embed_query(query_text)
result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list,
                    is_active, 1, 0.65)
                    is_active, 1, 3, 0.65)
return result[0]

@abstractmethod

@@ -169,7 +169,7 @@ class BaseVectorStore(ABC):
    pass

@abstractmethod
def delete_bu_document_id_list(self, document_id_list: List[str]):
def delete_by_document_id_list(self, document_id_list: List[str]):
    pass

@abstractmethod

@@ -27,6 +27,8 @@ from smartdoc.conf import PROJECT_DIR
class PGVector(BaseVectorStore):

    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        if len(source_ids) == 0:
            return
        QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete()

    def update_by_source_ids(self, source_ids: List[str], instance: Dict):

@@ -67,7 +69,7 @@ class PGVector(BaseVectorStore):
        source_type=text_list[index].get('source_type'),
        embedding=embeddings[index],
        search_vector=to_ts_vector(text_list[index]['text'])) for index in
    range(0, len(text_list))]
    range(0, len(texts))]
if is_save_function():
    QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True

@@ -124,7 +126,9 @@ class PGVector(BaseVectorStore):
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_bu_document_id_list(self, document_id_list: List[str]):
    def delete_by_document_id_list(self, document_id_list: List[str]):
        if len(document_id_list) == 0:
            return True
        return QuerySet(Embedding).filter(document_id__in=document_id_list).delete()

    def delete_by_source_id(self, source_id: str, source_type: str):

@@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: __init__.py.py
@date:2024/8/16 14:47
@desc:
"""
from .celery import app as celery_app

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-

import os

from celery import Celery
from celery.schedules import crontab
from kombu import Exchange, Queue
from smartdoc import settings
from .heatbeat import *

# set the default Django settings module for the 'celery' program.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings')

app = Celery('MaxKB')

configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')}
configs['worker_concurrency'] = 5
# Using a string here means the worker will not have to
# pickle the object when using Windows.
# app.config_from_object('django.conf:settings', namespace='CELERY')

configs["task_queues"] = [
    Queue("celery", Exchange("celery"), routing_key="celery"),
    Queue("model", Exchange("model"), routing_key="model")
]
app.namespace = 'CELERY'
app.conf.update(
    {key.replace('CELERY_', '') if key.replace('CELERY_', '').lower() == key.replace('CELERY_',
                                                                                      '') else key: configs.get(
        key) for
        key
        in configs.keys()})
app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])

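The dict comprehension in app.conf.update strips the CELERY_ prefix only when the remainder is already lower-case; upper-case settings keep the prefix and are matched through app.namespace = 'CELERY'. A small standalone illustration of that rule:

def normalize(key):
    rest = key.replace('CELERY_', '')
    # CELERY_result_backend -> result_backend, but CELERY_RESULT_EXPIRES is kept verbatim
    return rest if rest.lower() == rest else key

print(normalize('CELERY_result_backend'))   # result_backend
print(normalize('CELERY_RESULT_EXPIRES'))   # CELERY_RESULT_EXPIRES
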
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
#

CELERY_LOG_MAGIC_MARK = b'\x00\x00\x00\x00\x00'

@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
#
from functools import wraps

_need_registered_period_tasks = []
_after_app_ready_start_tasks = []
_after_app_shutdown_clean_periodic_tasks = []


def add_register_period_task(task):
    _need_registered_period_tasks.append(task)


def get_register_period_tasks():
    return _need_registered_period_tasks


def add_after_app_shutdown_clean_task(name):
    _after_app_shutdown_clean_periodic_tasks.append(name)


def get_after_app_shutdown_clean_tasks():
    return _after_app_shutdown_clean_periodic_tasks


def add_after_app_ready_task(name):
    _after_app_ready_start_tasks.append(name)


def get_after_app_ready_tasks():
    return _after_app_ready_start_tasks


def register_as_period_task(
        crontab=None, interval=None, name=None,
        args=(), kwargs=None,
        description=''):
    """
    Warning: the task must not take any args or kwargs.
    :param crontab: "* * * * *"
    :param interval: 60*60*60
    :param args: ()
    :param kwargs: {}
    :param description: ""
    :param name: ""
    :return:
    """
    if crontab is None and interval is None:
        raise SyntaxError("Must set crontab or interval one")

    def decorate(func):
        if crontab is None and interval is None:
            raise SyntaxError("Interval and crontab must set one")

        # When this decorator runs, the task has not been created yet,
        # so we can't use func.name
        task = '{func.__module__}.{func.__name__}'.format(func=func)
        _name = name if name else task
        add_register_period_task({
            _name: {
                'task': task,
                'interval': interval,
                'crontab': crontab,
                'args': args,
                'kwargs': kwargs if kwargs else {},
                'description': description
            }
        })

        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper

    return decorate


def after_app_ready_start(func):
    # When this decorator runs, the task has not been created yet,
    # so we can't use func.name
    name = '{func.__module__}.{func.__name__}'.format(func=func)
    if name not in _after_app_ready_start_tasks:
        add_after_app_ready_task(name)

    @wraps(func)
    def decorate(*args, **kwargs):
        return func(*args, **kwargs)

    return decorate


def after_app_shutdown_clean_periodic(func):
    # When this decorator runs, the task has not been created yet,
    # so we can't use func.name
    name = '{func.__module__}.{func.__name__}'.format(func=func)
    if name not in _after_app_shutdown_clean_periodic_tasks:
        add_after_app_shutdown_clean_task(name)

    @wraps(func)
    def decorate(*args, **kwargs):
        return func(*args, **kwargs)

    return decorate

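A sketch of how a task might opt into periodic scheduling with this decorator; the task name, interval, and body are illustrative assumptions, not taken from this commit:

from ops import celery_app
from ops.celery.decorator import register_as_period_task

@celery_app.task(name='celery:cleanup')
@register_as_period_task(interval=3600, name='celery:cleanup', description='hourly cleanup')
def cleanup():
    # per the docstring above, a registered periodic task must take no args or kwargs
    pass
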
@@ -0,0 +1,25 @@
from pathlib import Path

from celery.signals import heartbeat_sent, worker_ready, worker_shutdown


@heartbeat_sent.connect
def heartbeat(sender, **kwargs):
    worker_name = sender.eventer.hostname.split('@')[0]
    heartbeat_path = Path('/tmp/worker_heartbeat_{}'.format(worker_name))
    heartbeat_path.touch()


@worker_ready.connect
def worker_ready(sender, **kwargs):
    worker_name = sender.hostname.split('@')[0]
    ready_path = Path('/tmp/worker_ready_{}'.format(worker_name))
    ready_path.touch()


@worker_shutdown.connect
def worker_shutdown(sender, **kwargs):
    worker_name = sender.hostname.split('@')[0]
    for signal in ['ready', 'heartbeat']:
        path = Path('/tmp/worker_{}_{}'.format(signal, worker_name))
        path.unlink(missing_ok=True)

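These signal handlers leave file markers under /tmp that an external process can poll. A sketch of a liveness probe built on them; the function and the 60-second staleness threshold are assumptions, not part of the commit:

import time
from pathlib import Path

def worker_is_alive(worker_name: str, max_age: int = 60) -> bool:
    # alive if the heartbeat marker was touched within the last max_age seconds
    path = Path('/tmp/worker_heartbeat_{}'.format(worker_name))
    return path.exists() and time.time() - path.stat().st_mtime < max_age
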
@@ -0,0 +1,223 @@
from logging import StreamHandler
from threading import get_ident

from celery import current_task
from celery.signals import task_prerun, task_postrun
from django.conf import settings
from kombu import Connection, Exchange, Queue, Producer
from kombu.mixins import ConsumerMixin

from .utils import get_celery_task_log_path
from .const import CELERY_LOG_MAGIC_MARK

routing_key = 'celery_log'
celery_log_exchange = Exchange('celery_log_exchange', type='direct')
celery_log_queue = [Queue('celery_log', celery_log_exchange, routing_key=routing_key)]


class CeleryLoggerConsumer(ConsumerMixin):
    def __init__(self):
        self.connection = Connection(settings.CELERY_LOG_BROKER_URL)

    def get_consumers(self, Consumer, channel):
        return [Consumer(queues=celery_log_queue,
                         accept=['pickle', 'json'],
                         callbacks=[self.process_task])
                ]

    def handle_task_start(self, task_id, message):
        pass

    def handle_task_end(self, task_id, message):
        pass

    def handle_task_log(self, task_id, msg, message):
        pass

    def process_task(self, body, message):
        action = body.get('action')
        task_id = body.get('task_id')
        msg = body.get('msg')
        if action == CeleryLoggerProducer.ACTION_TASK_LOG:
            self.handle_task_log(task_id, msg, message)
        elif action == CeleryLoggerProducer.ACTION_TASK_START:
            self.handle_task_start(task_id, message)
        elif action == CeleryLoggerProducer.ACTION_TASK_END:
            self.handle_task_end(task_id, message)


class CeleryLoggerProducer:
    ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3)

    def __init__(self):
        self.connection = Connection(settings.CELERY_LOG_BROKER_URL)

    @property
    def producer(self):
        return Producer(self.connection)

    def publish(self, payload):
        self.producer.publish(
            payload, serializer='json', exchange=celery_log_exchange,
            declare=[celery_log_exchange], routing_key=routing_key
        )

    def log(self, task_id, msg):
        payload = {'task_id': task_id, 'msg': msg, 'action': self.ACTION_TASK_LOG}
        return self.publish(payload)

    def read(self):
        pass

    def flush(self):
        pass

    def task_end(self, task_id):
        payload = {'task_id': task_id, 'action': self.ACTION_TASK_END}
        return self.publish(payload)

    def task_start(self, task_id):
        payload = {'task_id': task_id, 'action': self.ACTION_TASK_START}
        return self.publish(payload)


class CeleryTaskLoggerHandler(StreamHandler):
    terminator = '\r\n'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        task_prerun.connect(self.on_task_start)
        task_postrun.connect(self.on_start_end)

    @staticmethod
    def get_current_task_id():
        if not current_task:
            return
        task_id = current_task.request.root_id
        return task_id

    def on_task_start(self, sender, task_id, **kwargs):
        return self.handle_task_start(task_id)

    def on_start_end(self, sender, task_id, **kwargs):
        return self.handle_task_end(task_id)

    def after_task_publish(self, sender, body, **kwargs):
        pass

    def emit(self, record):
        task_id = self.get_current_task_id()
        if not task_id:
            return
        try:
            self.write_task_log(task_id, record)
            self.flush()
        except Exception:
            self.handleError(record)

    def write_task_log(self, task_id, msg):
        pass

    def handle_task_start(self, task_id):
        pass

    def handle_task_end(self, task_id):
        pass


class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler):
    @staticmethod
    def get_current_thread_id():
        return str(get_ident())

    def emit(self, record):
        thread_id = self.get_current_thread_id()
        try:
            self.write_thread_task_log(thread_id, record)
            self.flush()
        except ValueError:
            self.handleError(record)

    def write_thread_task_log(self, thread_id, msg):
        pass

    def handle_task_start(self, task_id):
        pass

    def handle_task_end(self, task_id):
        pass

    def handleError(self, record) -> None:
        pass


class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler):
    def __init__(self):
        self.producer = CeleryLoggerProducer()
        super().__init__(stream=None)

    def write_task_log(self, task_id, record):
        msg = self.format(record)
        self.producer.log(task_id, msg)

    def flush(self):
        self.producer.flush()


class CeleryTaskFileHandler(CeleryTaskLoggerHandler):
    def __init__(self, *args, **kwargs):
        self.f = None
        super().__init__(*args, **kwargs)

    def emit(self, record):
        msg = self.format(record)
        if not self.f or self.f.closed:
            return
        self.f.write(msg)
        self.f.write(self.terminator)
        self.flush()

    def flush(self):
        self.f and self.f.flush()

    def handle_task_start(self, task_id):
        log_path = get_celery_task_log_path(task_id)
        self.f = open(log_path, 'a')

    def handle_task_end(self, task_id):
        self.f and self.f.close()


class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
    def __init__(self, *args, **kwargs):
        self.thread_id_fd_mapper = {}
        self.task_id_thread_id_mapper = {}
        super().__init__(*args, **kwargs)

    def write_thread_task_log(self, thread_id, record):
        f = self.thread_id_fd_mapper.get(thread_id, None)
        if not f:
            raise ValueError('Not found thread task file')
        msg = self.format(record)
        f.write(msg.encode())
        f.write(self.terminator.encode())
        f.flush()

    def flush(self):
        for f in self.thread_id_fd_mapper.values():
            f.flush()

    def handle_task_start(self, task_id):
        log_path = get_celery_task_log_path(task_id)
        thread_id = self.get_current_thread_id()
        self.task_id_thread_id_mapper[task_id] = thread_id
        f = open(log_path, 'ab')
        self.thread_id_fd_mapper[thread_id] = f

    def handle_task_end(self, task_id):
        ident_id = self.task_id_thread_id_mapper.get(task_id, '')
        f = self.thread_id_fd_mapper.pop(ident_id, None)
        if f and not f.closed:
            f.write(CELERY_LOG_MAGIC_MARK)
            f.close()
        self.task_id_thread_id_mapper.pop(task_id, None)

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
#
import logging
import os

from celery import subtask
from celery.signals import (
    worker_ready, worker_shutdown, after_setup_logger
)
from django.core.cache import cache
from django_celery_beat.models import PeriodicTask

from .decorator import get_after_app_ready_tasks, get_after_app_shutdown_clean_tasks
from .logger import CeleryThreadTaskFileHandler

logger = logging.getLogger(__file__)
safe_str = lambda x: x


@worker_ready.connect
def on_app_ready(sender=None, headers=None, **kwargs):
    if cache.get("CELERY_APP_READY", 0) == 1:
        return
    cache.set("CELERY_APP_READY", 1, 10)
    tasks = get_after_app_ready_tasks()
    logger.debug("Work ready signal recv")
    logger.debug("Start need start task: [{}]".format(", ".join(tasks)))
    for task in tasks:
        periodic_task = PeriodicTask.objects.filter(task=task).first()
        if periodic_task and not periodic_task.enabled:
            logger.debug("Periodic task [{}] is disabled!".format(task))
            continue
        subtask(task).delay()


def delete_files(directory):
    if os.path.isdir(directory):
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)


@worker_shutdown.connect
def after_app_shutdown_periodic_tasks(sender=None, **kwargs):
    if cache.get("CELERY_APP_SHUTDOWN", 0) == 1:
        return
    cache.set("CELERY_APP_SHUTDOWN", 1, 10)
    tasks = get_after_app_shutdown_clean_tasks()
    logger.debug("Worker shutdown signal recv")
    logger.debug("Clean period tasks: [{}]".format(', '.join(tasks)))
    PeriodicTask.objects.filter(name__in=tasks).delete()


@after_setup_logger.connect
def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=None, **kwargs):
    if not logger:
        return
    task_handler = CeleryThreadTaskFileHandler()
    task_handler.setLevel(loglevel)
    formatter = logging.Formatter(format)
    task_handler.setFormatter(formatter)
    logger.addHandler(task_handler)

@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
#
import logging
import os
import uuid

from django.conf import settings
from django_celery_beat.models import (
    PeriodicTasks
)

from smartdoc.const import PROJECT_DIR

logger = logging.getLogger(__file__)


def disable_celery_periodic_task(task_name):
    from django_celery_beat.models import PeriodicTask
    PeriodicTask.objects.filter(name=task_name).update(enabled=False)
    PeriodicTasks.update_changed()


def delete_celery_periodic_task(task_name):
    from django_celery_beat.models import PeriodicTask
    PeriodicTask.objects.filter(name=task_name).delete()
    PeriodicTasks.update_changed()


def get_celery_periodic_task(task_name):
    from django_celery_beat.models import PeriodicTask
    task = PeriodicTask.objects.filter(name=task_name).first()
    return task


def make_dirs(name, mode=0o755, exist_ok=False):
    """ The default mode is 0o755. """
    return os.makedirs(name, mode=mode, exist_ok=exist_ok)


def get_task_log_path(base_path, task_id, level=2):
    task_id = str(task_id)
    try:
        uuid.UUID(task_id)
    except:
        return os.path.join(PROJECT_DIR, 'data', 'caution.txt')

    rel_path = os.path.join(*task_id[:level], task_id + '.log')
    path = os.path.join(base_path, rel_path)
    make_dirs(os.path.dirname(path), exist_ok=True)
    return path


def get_celery_task_log_path(task_id):
    return get_task_log_path(settings.CELERY_LOG_DIR, task_id)


def get_celery_status():
    from . import app
    i = app.control.inspect()
    ping_data = i.ping() or {}
    active_nodes = [k for k, v in ping_data.items() if v.get('ok') == 'pong']
    active_queue_worker = set([n.split('@')[0] for n in active_nodes if n])
    # expected number of Celery workers: 2
    if len(active_queue_worker) < 2:
        print("Not all celery worker worked")
        return False
    else:
        return True

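get_task_log_path shards logs by the first characters of the task id, so no single directory accumulates every log file. A worked example with an illustrative id:

import os

task_id = 'ab12cd34-0000-0000-0000-000000000000'  # illustrative
print(os.path.join(*task_id[:2], task_id + '.log'))
# a/b/ab12cd34-0000-0000-0000-000000000000.log
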
@@ -13,18 +13,22 @@ from common.util.rsa_util import rsa_long_decrypt
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants


def get_model_(provider, model_type, model_name, credential, **kwargs):
def get_model_(provider, model_type, model_name, credential, model_id, use_local=False, **kwargs):
    """
    Build a model instance.
    @param provider: provider
    @param model_type: model type
    @param model_name: model name
    @param credential: credentials
    @param model_id: model id
    @param use_local: whether to call the local model; only applies to the local provider
    @return: model instance
    """
    model = get_provider(provider).get_model(model_type, model_name,
                                             json.loads(
                                                 rsa_long_decrypt(credential)),
                                             model_id=model_id,
                                             use_local=use_local,
                                             streaming=True, **kwargs)
    return model


@@ -35,7 +39,7 @@ def get_model(model, **kwargs):
    @param model: the database Model row
    @return: model instance
    """
    return get_model_(model.provider, model.model_type, model.model_name, model.credential, **kwargs)
    return get_model_(model.provider, model.model_type, model.model_name, model.credential, str(model.id), **kwargs)


def get_provider(provider):

@@ -6,17 +6,52 @@
@date:2024/7/11 14:06
@desc:
"""
from typing import Dict
from typing import Dict, List

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
                              model_kwargs={'device': model_credential.get('device')},
                              encode_kwargs={'normalize_embeddings': True},
                              )
        pass

    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    def embed_query(self, text: str) -> List[float]:
        res = requests.post(f'http://127.0.0.1:5432/api/model/{self.model_id}/embed_query', {'text': text})
        result = res.json()
        if result.get('code', 500) == 200:
            return result.get('data')
        raise Exception(result.get('msg'))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        res = requests.post(f'http://127.0.0.1:5432/api/model/{self.model_id}/embed_documents', {'texts': texts})
        result = res.json()
        if result.get('code', 500) == 200:
            return result.get('data')
        raise Exception(result.get('msg'))


class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        if model_kwargs.get('use_local', True):
            return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
                                  model_kwargs={'device': model_credential.get('device')},
                                  encode_kwargs={'normalize_embeddings': True}
                                  )
        return WebLocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
                                 model_kwargs={'device': model_credential.get('device')},
                                 encode_kwargs={'normalize_embeddings': True},
                                 **model_kwargs)

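With this split, a process created without use_local gets a WebLocalEmbedding that proxies every call over HTTP to the local_model server on 127.0.0.1:5432, while the local_model process itself keeps loading the HuggingFace weights. A sketch of the proxy in use (the model id is a hypothetical placeholder):

emb = WebLocalEmbedding(model_id='<model-uuid>')
vec = emb.embed_query('hello')            # POST /api/model/<id>/embed_query
vecs = emb.embed_documents(['a', 'b'])    # POST /api/model/<id>/embed_documents
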
@@ -0,0 +1,53 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: model_apply_serializers.py
@date:2024/8/20 20:39
@desc:
"""
from django.db.models import QuerySet
from rest_framework import serializers

from common.config.embedding_config import ModelManage
from common.util.field_message import ErrMessage
from setting.models import Model
from setting.models_provider import get_model


def get_embedding_model(model_id):
    model = QuerySet(Model).filter(id=model_id).first()
    embedding_model = ModelManage.get_model(model_id,
                                            lambda _id: get_model(model, use_local=True))
    return embedding_model


class EmbedDocuments(serializers.Serializer):
    texts = serializers.ListField(required=True, child=serializers.CharField(required=True,
                                                                             error_messages=ErrMessage.char(
                                                                                 "向量文本")),
                                  error_messages=ErrMessage.list("向量文本列表"))


class EmbedQuery(serializers.Serializer):
    text = serializers.CharField(required=True, error_messages=ErrMessage.char("向量文本"))


class ModelApplySerializers(serializers.Serializer):
    model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))

    def embed_documents(self, instance, with_valid=True):
        if with_valid:
            self.is_valid(raise_exception=True)
            EmbedDocuments(data=instance).is_valid(raise_exception=True)

        model = get_embedding_model(self.data.get('model_id'))
        return model.embed_documents(instance.getlist('texts'))

    def embed_query(self, instance, with_valid=True):
        if with_valid:
            self.is_valid(raise_exception=True)
            EmbedQuery(data=instance).is_valid(raise_exception=True)

        model = get_embedding_model(self.data.get('model_id'))
        return model.embed_query(instance.get('text'))

@@ -1,3 +1,5 @@
import os

from django.urls import path

from . import views

@@ -22,3 +24,10 @@ urlpatterns = [
    path('valid/<str:valid_type>/<int:valid_count>', views.Valid.as_view())

]
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
    urlpatterns += [
        path('model/<str:model_id>/embed_documents', views.ModelApply.EmbedDocuments.as_view(),
             name='model/embed_documents'),
        path('model/<str:model_id>/embed_query', views.ModelApply.EmbedQuery.as_view(),
             name='model/embed_query'),
    ]

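These routes are only mounted when the process starts with SERVER_NAME=local_model. A sketch of calling one directly (the model id is a hypothetical placeholder); per WebLocalEmbedding above, a success response carries 'code' == 200 and the vector under 'data':

import requests

res = requests.post('http://127.0.0.1:5432/api/model/<model-uuid>/embed_query',
                    {'text': 'hello'})
print(res.json().get('data'))
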
@@ -10,3 +10,4 @@ from .Team import *
from .model import *
from .system_setting import *
from .valid import *
from .model_apply import *

@@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: model_apply.py
@date:2024/8/20 20:38
@desc:
"""
from urllib.request import Request

from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.views import APIView

from common.response import result
from setting.serializers.model_apply_serializers import ModelApplySerializers


class ModelApply(APIView):
    class EmbedDocuments(APIView):
        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="向量化文档",
                             operation_id="向量化文档",
                             responses=result.get_default_response(),
                             tags=["模型"])
        def post(self, request: Request, model_id):
            return result.success(
                ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data))

    class EmbedQuery(APIView):
        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="向量化文档",
                             operation_id="向量化文档",
                             responses=result.get_default_response(),
                             tags=["模型"])
        def post(self, request: Request, model_id):
            return result.success(
                ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data))

@@ -9,3 +9,4 @@
from .base import *
from .logging import *
from .auth import *
from .lib import *

@@ -42,7 +42,8 @@ INSTALLED_APPS = [
    'django_filters',  # conditional filtering
    'django_apscheduler',
    'common',
    'function_lib'
    'function_lib',
    'django_celery_beat'

]

@@ -60,6 +61,7 @@ JWT_AUTH = {
    'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=60 * 60 * 2)  # <-- token lifetime
}

APPS_DIR = os.path.join(PROJECT_DIR, 'apps')
ROOT_URLCONF = 'smartdoc.urls'
# FORCE_SCRIPT_NAME
TEMPLATES = [

@@ -0,0 +1,40 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: lib.py
@date:2024/8/16 17:12
@desc:
"""
import os

from smartdoc.const import CONFIG, PROJECT_DIR

# celery-related settings
celery_data_dir = os.path.join(PROJECT_DIR, 'data', 'celery_task')
if not os.path.exists(celery_data_dir) or not os.path.isdir(celery_data_dir):
    os.makedirs(celery_data_dir)
broker_path = os.path.join(celery_data_dir, "celery_db.sqlite3")
backend_path = os.path.join(celery_data_dir, "celery_results.sqlite3")
# use sqlite as both the broker and the result backend
CELERY_BROKER_URL = f'sqla+sqlite:///{broker_path}'
CELERY_result_backend = f'db+sqlite:///{backend_path}'
CELERY_timezone = CONFIG.TIME_ZONE
CELERY_ENABLE_UTC = False
CELERY_task_serializer = 'pickle'
CELERY_result_serializer = 'pickle'
CELERY_accept_content = ['json', 'pickle']
CELERY_RESULT_EXPIRES = 600
CELERY_WORKER_TASK_LOG_FORMAT = '%(asctime).19s %(message)s'
CELERY_WORKER_LOG_FORMAT = '%(asctime).19s %(message)s'
CELERY_TASK_EAGER_PROPAGATES = True
CELERY_WORKER_REDIRECT_STDOUTS = True
CELERY_WORKER_REDIRECT_STDOUTS_LEVEL = "INFO"
CELERY_TASK_SOFT_TIME_LIMIT = 3600
CELERY_WORKER_CANCEL_LONG_RUNNING_TASKS_ON_CONNECTION_LOSS = True
CELERY_ONCE = {
    'backend': 'celery_once.backends.File',
    'settings': {'location': os.path.join(celery_data_dir, "celery_once")}
}
CELERY_BROKER_CONNECTION_RETRY_ON_STARTUP = True
CELERY_LOG_DIR = os.path.join(PROJECT_DIR, 'logs', 'celery')

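With these settings every queueing primitive lives on local disk: the broker and result backend are SQLite files under data/celery_task, and celery_once takes file locks in the same directory, so no Redis or RabbitMQ is required. For example, with a PROJECT_DIR of /opt/maxkb (path illustrative) the broker URL expands to:

CELERY_BROKER_URL = 'sqla+sqlite:////opt/maxkb/data/celery_task/celery_db.sqlite3'
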
@@ -114,6 +114,14 @@ LOGGING = {
        'level': LOG_LEVEL,
        'propagate': False,
    },
    'common.event': {
        'handlers': ['console', 'file'],
        'level': "DEBUG",
        'propagate': False,
    },
    'sqlalchemy': {'handlers': ['console', 'file'],
                   'level': "ERROR",
                   'propagate': False, }
    }
}

@@ -21,7 +21,6 @@ def post_handler():
    from common import job
    from common.models.db_model_manage import DBModelManage
    event.run()
    event.ListenerManagement.init_embedding_model_signal.send()
    job.run()
    DBModelManage.init()

@@ -5,3 +5,5 @@ class UsersConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'users'

    def ready(self):
        from ops.celery import signal_handler

@@ -26,7 +26,6 @@ from common.constants.authentication_type import AuthenticationType
from common.constants.exception_code_constants import ExceptionCodeConstants
from common.constants.permission_constants import RoleConstants, get_permission_list_by_role
from common.db.search import page_search
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.response.result import get_api_response

@@ -34,6 +33,7 @@ from common.util.common import valid_license
from common.util.field_message import ErrMessage
from common.util.lock import lock
from dataset.models import DataSet, Document, Paragraph, Problem, ProblemParagraphMapping
from embedding.task import delete_embedding_by_dataset_id_list
from setting.models import Team, SystemSetting, SettingType, Model, TeamMember, TeamMemberPermission
from smartdoc.conf import PROJECT_DIR
from users.models.user import User, password_encrypt, get_user_dynamics_permission

@@ -497,7 +497,8 @@ class UserSerializer(ApiMixin, serializers.ModelSerializer):
class UserInstanceSerializer(ApiMixin, serializers.ModelSerializer):
    class Meta:
        model = User
        fields = ['id', 'username', 'email', 'phone', 'is_active', 'role', 'nick_name', 'create_time', 'update_time', 'source']
        fields = ['id', 'username', 'email', 'phone', 'is_active', 'role', 'nick_name', 'create_time', 'update_time',
                  'source']

    @staticmethod
    def get_response_body_api():

@@ -643,7 +644,8 @@ class UserManageSerializer(serializers.Serializer):

    def is_valid(self, *, user_id=None, raise_exception=False):
        super().is_valid(raise_exception=True)
        if self.data.get('email') is not None and QuerySet(User).filter(email=self.data.get('email')).exclude(id=user_id).exists():
        if self.data.get('email') is not None and QuerySet(User).filter(email=self.data.get('email')).exclude(
                id=user_id).exists():
            raise AppApiException(1004, "邮箱已经被使用")

    @staticmethod

@@ -738,7 +740,7 @@ class UserManageSerializer(serializers.Serializer):
QuerySet(Paragraph).filter(dataset_id__in=dataset_id_list).delete()
QuerySet(ProblemParagraphMapping).filter(dataset_id__in=dataset_id_list).delete()
QuerySet(Problem).filter(dataset_id__in=dataset_id_list).delete()
ListenerManagement.delete_embedding_by_dataset_id_list_signal.send(dataset_id_list)
delete_embedding_by_dataset_id_list(dataset_id_list)
dataset_list.delete()
# delete the teams
QuerySet(Team).filter(user_id=self.data.get('id')).delete()

main.py

@@ -2,6 +2,7 @@ import argparse
import logging
import os
import sys
import time

import django
from django.core import management

@@ -43,11 +44,38 @@ def perform_db_migrate():


def start_services():
    management.call_command('gunicorn')
    services = args.services if isinstance(args.services, list) else [args.services]
    start_args = []
    if args.daemon:
        start_args.append('--daemon')
    if args.force:
        start_args.append('--force')
    if args.worker:
        start_args.extend(['--worker', str(args.worker)])
    else:
        worker = os.environ.get('CORE_WORKER')
        if isinstance(worker, str) and worker.isdigit():
            start_args.extend(['--worker', worker])

    try:
        management.call_command(action, *services, *start_args)
    except KeyboardInterrupt:
        logging.info('Cancel ...')
        time.sleep(2)
    except Exception as exc:
        logging.error("Start service error {}: {}".format(services, exc))
        time.sleep(2)


def runserver():
    management.call_command('runserver', "0.0.0.0:8080")
def dev():
    services = args.services if isinstance(args.services, list) else args.services
    if services.__contains__('web'):
        management.call_command('runserver', "0.0.0.0:8080")
    elif services.__contains__('celery'):
        management.call_command('celery', 'celery')
    elif services.__contains__('local_model'):
        os.environ.setdefault('SERVER_NAME', 'local_model')
        management.call_command('runserver', "127.0.0.1:5432")


if __name__ == '__main__':

@@ -66,8 +94,17 @@ if __name__ == '__main__':
    choices=("start", "dev", "upgrade_db", "collect_static"),
    help="Action to run"
)
args = parser.parse_args()
args, e = parser.parse_known_args()
parser.add_argument(
    "services", type=str, default='all' if args.action == 'start' else 'web', nargs="*",
    choices=("all", "web", "task") if args.action == 'start' else ("web", "celery", 'local_model'),
    help="The service to start",
)

parser.add_argument('-d', '--daemon', nargs="?", const=True)
parser.add_argument('-w', '--worker', type=int, nargs="?")
parser.add_argument('-f', '--force', nargs="?", const=True)
args = parser.parse_args()
action = args.action
if action == "upgrade_db":
    perform_db_migrate()

@@ -76,7 +113,7 @@ if __name__ == '__main__':
elif action == 'dev':
    collect_static()
    perform_db_migrate()
    runserver()
    dev()
else:
    collect_static()
    perform_db_migrate()

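With the argparse changes above, each role can be launched separately in development; for example (the 'start' path assumes the corresponding management command shipped elsewhere in the tree):

python main.py dev web          # Django dev server on 0.0.0.0:8080
python main.py dev celery       # Celery worker via the 'celery' management command
python main.py dev local_model  # local embedding server on 127.0.0.1:5432
python main.py start all -d     # daemonized start of all services
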
@@ -21,7 +21,6 @@ pillow = "^10.2.0"
filetype = "^1.2.0"
torch = "2.2.1"
sentence-transformers = "^2.2.2"
blinker = "^1.6.3"
openai = "^1.13.3"
tiktoken = "^0.7.0"
qianfan = "^0.3.6.1"

@@ -55,6 +54,11 @@ boto3 = "^1.34.151"
langchain-aws = "^0.1.13"
tencentcloud-sdk-python = "^3.0.1205"
xinference-client = "^0.14.0.post1"
psutil = "^6.0.0"
celery = { extras = ["sqlalchemy"], version = "^5.4.0" }
eventlet = "^0.36.1"
django-celery-beat = "^2.6.0"
celery-once = "^3.0.1"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"