perf: Memory optimization (#4318)

This commit is contained in:
shaohuzhang1 2025-11-05 19:05:26 +08:00 committed by GitHub
parent 1f4d6d1123
commit a8d0729e65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 1013 additions and 203 deletions

View File

@ -224,11 +224,12 @@ class LoopWorkFlowPostHandler(WorkFlowPostHandler):
class BaseLoopNode(ILoopNode):
def save_context(self, details, workflow_manage):
self.context['result'] = details.get('result')
self.context['loop_context_data'] = details.get('loop_context_data')
self.context['loop_answer_data'] = details.get('loop_answer_data')
for key, value in details['context'].items():
if key not in self.context:
self.context[key] = value
self.answer_text = str(details.get('result'))
self.answer_text = ""
def get_answer_list(self) -> List[Answer] | None:
result = []

View File

@ -15,7 +15,6 @@ from urllib.parse import urljoin
import uuid_utils.compat as uuid
from charset_normalizer import detect
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from common.handle.base_split_handle import BaseSplitHandle
@ -39,7 +38,6 @@ class FileBufferHandle:
return self.buffer
default_split_handle = TextSplitHandle()
split_handles = [
HTMLSplitHandle(),

View File

@ -0,0 +1,3 @@
from django.contrib import admin
# Register your models here.

6
apps/local_model/apps.py Normal file
View File

@ -0,0 +1,6 @@
from django.apps import AppConfig
class LocalModelConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'local_model'

View File

View File

@ -0,0 +1,10 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py
@date2023/9/25 15:04
@desc:
"""
from .model_management import *

View File

@ -0,0 +1,49 @@
# coding=utf-8
import uuid_utils.compat as uuid
from django.db import models
from common.mixins.app_model_mixin import AppModelMixin
from local_model.models.user import User
class Status(models.TextChoices):
"""系统设置类型"""
SUCCESS = "SUCCESS", '成功'
ERROR = "ERROR", "失败"
DOWNLOAD = "DOWNLOAD", '下载中'
PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载'
class Model(AppModelMixin):
"""
模型数据
"""
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
name = models.CharField(max_length=128, verbose_name="名称", db_index=True)
status = models.CharField(max_length=20, verbose_name='设置类型', choices=Status.choices,
default=Status.SUCCESS, db_index=True)
model_type = models.CharField(max_length=128, verbose_name="模型类型", db_index=True)
model_name = models.CharField(max_length=128, verbose_name="模型名称", db_index=True)
user = models.ForeignKey(User, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
provider = models.CharField(max_length=128, verbose_name='供应商', db_index=True)
credential = models.CharField(max_length=102400, verbose_name="模型认证信息")
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
model_params_form = models.JSONField(verbose_name="模型参数配置", default=list)
workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)
class Meta:
db_table = "model"
unique_together = ['name', 'workspace_id']

View File

@ -0,0 +1,34 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file system_management.py
@date2024/3/19 13:47
@desc: 邮箱管理
"""
from django.db import models
from common.mixins.app_model_mixin import AppModelMixin
class SettingType(models.IntegerChoices):
"""系统设置类型"""
EMAIL = 0, '邮箱'
RSA = 1, "私钥秘钥"
LOG = 2, "日志清理时间"
class SystemSetting(AppModelMixin):
"""
系统设置
"""
type = models.IntegerField(primary_key=True, verbose_name='设置类型', choices=SettingType.choices,
default=SettingType.EMAIL)
meta = models.JSONField(verbose_name="配置数据", default=dict)
class Meta:
db_table = "system_setting"

View File

@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file user.py
@date2025/4/14 10:20
@desc:
"""
import uuid_utils.compat as uuid
from django.db import models
from common.utils.common import password_encrypt
class User(models.Model):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
email = models.EmailField(unique=True, null=True, blank=True, verbose_name="邮箱", db_index=True)
phone = models.CharField(max_length=20, verbose_name="电话", default="", db_index=True)
nick_name = models.CharField(max_length=150, verbose_name="昵称", unique=True, db_index=True)
username = models.CharField(max_length=150, unique=True, verbose_name="用户名", db_index=True)
password = models.CharField(max_length=150, verbose_name="密码")
role = models.CharField(max_length=150, verbose_name="角色")
source = models.CharField(max_length=10, verbose_name="来源", default="LOCAL", db_index=True)
is_active = models.BooleanField(default=True, db_index=True)
language = models.CharField(max_length=10, verbose_name="语言", null=True, default=None)
create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True, null=True, db_index=True)
update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True, null=True, db_index=True)
USERNAME_FIELD = 'username'
REQUIRED_FIELDS = []
class Meta:
db_table = "user"
def set_password(self, row_password):
self.password = password_encrypt(row_password)
self._password = row_password

View File

@ -0,0 +1 @@
# coding=utf-8

View File

@ -0,0 +1,140 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file model_apply_serializers.py
@date2024/8/20 20:39
@desc:
"""
import json
import threading
import time
from django.db import connection
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from langchain_core.documents import Document
from rest_framework import serializers
from local_model.models import Model
from local_model.serializers.rsa_util import rsa_long_decrypt
from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
from common.cache.mem_cache import MemCache
_lock = threading.Lock()
locks = {}
class ModelManage:
cache = MemCache('model', {})
up_clear_time = time.time()
@staticmethod
def _get_lock(_id):
lock = locks.get(_id)
if lock is None:
with _lock:
lock = locks.get(_id)
if lock is None:
lock = threading.Lock()
locks[_id] = lock
return lock
@staticmethod
def get_model(_id, get_model):
model_instance = ModelManage.cache.get(_id)
if model_instance is None:
lock = ModelManage._get_lock(_id)
with lock:
model_instance = ModelManage.cache.get(_id)
if model_instance is None:
model_instance = get_model(_id)
ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
else:
if model_instance.is_cache_model():
ModelManage.cache.touch(_id, timeout=60 * 60 * 8)
else:
model_instance = get_model(_id)
ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
ModelManage.clear_timeout_cache()
return model_instance
@staticmethod
def clear_timeout_cache():
if time.time() - ModelManage.up_clear_time > 60 * 60:
threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data()).start()
ModelManage.up_clear_time = time.time()
@staticmethod
def delete_key(_id):
if ModelManage.cache.has_key(_id):
ModelManage.cache.delete(_id)
def get_local_model(model, **kwargs):
# system_setting = QuerySet(SystemSetting).filter(type=1).first()
return LocalModelProvider().get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
model_id=model.id,
streaming=True, **kwargs)
def get_embedding_model(model_id):
model = QuerySet(Model).filter(id=model_id).first()
# 手动关闭数据库连接
connection.close()
embedding_model = ModelManage.get_model(model_id,
lambda _id: get_local_model(model, use_local=True))
return embedding_model
class EmbedDocuments(serializers.Serializer):
texts = serializers.ListField(required=True, child=serializers.CharField(required=True,
label=_('vector text')),
label=_('vector text list')),
class EmbedQuery(serializers.Serializer):
text = serializers.CharField(required=True, label=_('vector text'))
class CompressDocument(serializers.Serializer):
page_content = serializers.CharField(required=True, label=_('text'))
metadata = serializers.DictField(required=False, label=_('metadata'))
class CompressDocuments(serializers.Serializer):
documents = CompressDocument(required=True, many=True)
query = serializers.CharField(required=True, label=_('query'))
class ModelApplySerializers(serializers.Serializer):
model_id = serializers.UUIDField(required=True, label=_('model id'))
def embed_documents(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
EmbedDocuments(data=instance).is_valid(raise_exception=True)
model = get_embedding_model(self.data.get('model_id'))
return model.embed_documents(instance.getlist('texts'))
def embed_query(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
EmbedQuery(data=instance).is_valid(raise_exception=True)
model = get_embedding_model(self.data.get('model_id'))
return model.embed_query(instance.get('text'))
def compress_documents(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
CompressDocuments(data=instance).is_valid(raise_exception=True)
model = get_embedding_model(self.data.get('model_id'))
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
instance.get('documents')], instance.get('query'))]

View File

@ -0,0 +1,139 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file rsa_util.py
@date2023/11/3 11:13
@desc:
"""
import base64
import threading
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
from Crypto.PublicKey import RSA
from django.core import cache
from django.db.models import QuerySet
from common.constants.cache_version import Cache_Version
from local_model.models.system_setting import SystemSetting, SettingType
lock = threading.Lock()
rsa_cache = cache.cache
cache_key = "rsa_key"
# 对密钥加密的密码
secret_code = "mac_kb_password"
def generate():
"""
生成 私钥秘钥对
:return:{key:'公钥',value:'私钥'}
"""
# 生成一个 2048 位的密钥
key = RSA.generate(2048)
# 获取私钥
encrypted_key = key.export_key(passphrase=secret_code, pkcs=8,
protection="scryptAndAES128-CBC")
return {'key': key.publickey().export_key(), 'value': encrypted_key}
def get_key_pair():
rsa_value = rsa_cache.get(cache_key)
if rsa_value is None:
with lock:
rsa_value = rsa_cache.get(cache_key)
if rsa_value is not None:
return rsa_value
rsa_value = get_key_pair_by_sql()
version, get_key = Cache_Version.SYSTEM.value
rsa_cache.set(get_key(key='rsa_key'), rsa_value, timeout=None, version=version)
return rsa_value
def get_key_pair_by_sql():
system_setting = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first()
if system_setting is None:
kv = generate()
system_setting = SystemSetting(type=SettingType.RSA.value,
meta={'key': kv.get('key').decode(), 'value': kv.get('value').decode()})
system_setting.save()
return system_setting.meta
def encrypt(msg, public_key: str | None = None):
"""
加密
:param msg: 加密数据
:param public_key: 公钥
:return: 加密后的数据
"""
if public_key is None:
public_key = get_key_pair().get('key')
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
return base64.b64encode(encrypt_msg).decode()
def decrypt(msg, pri_key: str | None = None):
"""
解密
:param msg: 需要解密的数据
:param pri_key: 私钥
:return: 解密后数据
"""
if pri_key is None:
pri_key = get_key_pair().get('value')
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
return decrypt_data.decode("utf-8")
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
"""
超长文本加密
:param message: 需要加密的字符串
:param public_key 公钥
:param length: 1024bit的证书用100 2048bit的证书用 200
:return: 加密后的数据
"""
# 读取公钥
if public_key is None:
public_key = get_key_pair().get('key')
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
passphrase=secret_code))
# 处理Plaintext is too long. 分段加密
if len(message) <= length:
# 对编码的数据进行加密并通过base64进行编码
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
else:
rsa_text = []
# 对编码后的数据进行切片,原因:加密长度不能过长
for i in range(0, len(message), length):
cont = message[i:i + length]
# 对切片后的数据进行加密并新增到text后面
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
# 加密完进行拼接
cipher_text = b''.join(rsa_text)
# base64进行编码
result = base64.b64encode(cipher_text)
return result.decode()
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
"""
超长文本解密默认不加密
:param message: 需要解密的数据
:param pri_key: 秘钥
:param length : 1024bit的证书用1282048bit证书用256位
:return: 解密后的数据
"""
if pri_key is None:
pri_key = get_key_pair().get('value')
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
base64_de = base64.b64decode(message)
res = []
for i in range(0, len(base64_de), length):
res.append(cipher.decrypt(base64_de[i:i + length], 0))
return b"".join(res).decode()

View File

@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

13
apps/local_model/urls.py Normal file
View File

@ -0,0 +1,13 @@
import os
from django.urls import path
from . import views
app_name = "local_model"
# @formatter:off
urlpatterns = [
path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
]

View File

@ -0,0 +1,2 @@
# coding=utf-8
from .model_apply import *

View File

@ -0,0 +1,34 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file model_apply.py
@date2024/8/20 20:38
@desc:
"""
from urllib.request import Request
from rest_framework.views import APIView
from common.result import result
from local_model.serializers.model_apply_serializers import ModelApplySerializers
class LocalModelApply(APIView):
class EmbedDocuments(APIView):
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data))
class EmbedQuery(APIView):
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data))
class CompressDocuments(APIView):
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file __init__.py
@date2025/11/5 14:50
@desc:
"""
import os
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
from .model import *
else:
from .web import *

View File

@ -0,0 +1,11 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file auth.py
@date2024/7/9 18:47
@desc:
"""
AUTH_HANDLES = [
]

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file __init__.py.py
@date2025/11/5 14:53
@desc:
"""
import os
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
from .model import *
else:
from .web import *

View File

@ -0,0 +1,179 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file model.py
@date2025/11/5 14:53
@desc:
"""
from pathlib import Path
from ...const import CONFIG, PROJECT_DIR
import os
from django.utils.translation import gettext_lazy as _
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent.parent
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = CONFIG.get("SECRET_KEY") or 'django-insecure-zm^1_^i5)3gp^&0io6zg72&z!a*d=9kf9o2%uft+27l)+t(#3e'
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = CONFIG.get_debug()
ALLOWED_HOSTS = ['*']
# Application definition
INSTALLED_APPS = [
'django.contrib.contenttypes',
'django.contrib.messages',
'django.contrib.staticfiles',
'rest_framework',
'local_model',
]
MIDDLEWARE = [
'django.middleware.locale.LocaleMiddleware',
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
]
REST_FRAMEWORK = {
'EXCEPTION_HANDLER': 'common.exception.handle_exception.handle_exception',
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
'DEFAULT_AUTHENTICATION_CLASSES': ['common.auth.authenticate.AnonymousAuthentication']
}
STATICFILES_DIRS = [(os.path.join(PROJECT_DIR, 'ui', 'dist'))]
STATIC_ROOT = os.path.join(BASE_DIR.parent, 'static')
ROOT_URLCONF = 'maxkb.urls'
APPS_DIR = os.path.join(PROJECT_DIR, 'apps')
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': ["apps/static/admin"],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
{"NAME": "CHAT",
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': ["apps/static/chat"],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
{"NAME": "DOC",
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': ["apps/static/drf_spectacular_sidecar"],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
SPECTACULAR_SETTINGS = {
'TITLE': 'MaxKB API',
'DESCRIPTION': _('Intelligent customer service platform'),
'VERSION': 'v2',
'SERVE_INCLUDE_SCHEMA': False,
# OTHER SETTINGS
'SWAGGER_UI_DIST': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist', # shorthand to use the sidecar instead
'SWAGGER_UI_FAVICON_HREF': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist/favicon-32x32.png',
'REDOC_DIST': f'{CONFIG.get_admin_path()}/api-doc/redoc',
'SECURITY_DEFINITIONS': {
'Bearer': {
'type': 'apiKey',
'name': 'AUTHORIZATION',
'in': 'header',
}
}
}
WSGI_APPLICATION = 'maxkb.wsgi.application'
# Database
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases
DATABASES = {'default': CONFIG.get_db_setting()}
CACHES = CONFIG.get_cache_setting()
# Password validation
# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]
# Internationalization
# https://docs.djangoproject.com/en/4.2/topics/i18n/
LANGUAGE_CODE = CONFIG.get("LANGUAGE_CODE")
TIME_ZONE = CONFIG.get_time_zone()
USE_I18N = True
USE_TZ = True
# 文件上传配置
DATA_UPLOAD_MAX_NUMBER_FILES = 1000
# 支持的语言
LANGUAGES = [
('en', 'English'),
('zh', '中文简体'),
('zh-hant', '中文繁体')
]
# 翻译文件路径
LOCALE_PATHS = [
os.path.join(BASE_DIR.parent, 'locales')
]
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.2/howto/static-files/
STATIC_URL = 'static/'
# Default primary key field type
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
edition = 'CE'
if os.environ.get('MAXKB_REDIS_SENTINEL_SENTINELS') is not None:
DJANGO_REDIS_CONNECTION_FACTORY = "django_redis.pool.SentinelConnectionFactory"

View File

@ -1,22 +1,18 @@
# coding=utf-8
"""
Django settings for maxkb project.
Generated by 'django-admin startproject' using Django 4.2.4.
For more information on this file, see
https://docs.djangoproject.com/en/4.2/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/4.2/ref/settings/
@project: MaxKB
@Author虎虎
@file web.py
@date2025/11/5 14:53
@desc:
"""
from pathlib import Path
from ..const import CONFIG, PROJECT_DIR
from ...const import CONFIG, PROJECT_DIR
import os
from django.utils.translation import gettext_lazy as _
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
BASE_DIR = Path(__file__).resolve().parent.parent.parent
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB-xpack
@Author虎虎
@file __init__.py.py
@date2025/11/5 14:45
@desc:
"""
import os
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
from .model import *
else:
from .web import *

28
apps/maxkb/urls/model.py Normal file
View File

@ -0,0 +1,28 @@
"""
URL configuration for maxkb project.
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.urls import path, include
from maxkb.const import CONFIG
admin_api_prefix = CONFIG.get_admin_path()[1:] + '/api/'
admin_ui_prefix = CONFIG.get_admin_path()
chat_api_prefix = CONFIG.get_chat_path()[1:] + '/api/'
chat_ui_prefix = CONFIG.get_chat_path()
urlpatterns = [
path(admin_api_prefix, include("local_model.urls")),
]

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file __init__.py.py
@date2025/11/5 15:14
@desc:
"""
import os
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
from .model import *
else:
from .web import *

15
apps/maxkb/wsgi/model.py Normal file
View File

@ -0,0 +1,15 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file model.py
@date2025/11/5 15:14
@desc:
"""
import os
from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings')
application = get_wsgi_application()

View File

@ -1,12 +1,11 @@
# coding=utf-8
"""
WSGI config for maxkb project.
It exposes the WSGI callable as a module-level variable named ``application``.
For more information on this file, see
https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/
@project: MaxKB
@Author虎虎
@file web.py
@date2025/11/5 15:14
@desc:
"""
import os
from django.core.wsgi import get_wsgi_application
@ -35,4 +34,4 @@ post_handler()
# 仅在scheduler中启动定时任务dev local_model celery 不需要
if os.environ.get('ENABLE_SCHEDULER') == '1':
post_scheduler_handler()
post_scheduler_handler()

View File

@ -15,7 +15,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
from django.utils.translation import gettext_lazy as _, gettext
@ -33,7 +33,7 @@ class LocalRerankerCredential(BaseForm, BaseModelCredential):
else:
return False
try:
model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential)
model: LocalReranker = provider.get_model(model_type, model_name, model_credential)
model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
except Exception as e:
traceback.print_exc()

View File

@ -8,15 +8,16 @@
"""
import os
from django.utils.translation import gettext as _
from common.utils.common import get_file_content
from maxkb.conf import PROJECT_DIR
from models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
from models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential
from models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
LocalEmbeddingCredential(), LocalEmbedding)

View File

@ -1,64 +0,0 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/11 14:06
@desc:
"""
from typing import Dict, List
import requests
from langchain_core.embeddings import Embeddings
from pydantic import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings
from models_provider.base_model_provider import MaxKBBaseModel
from maxkb.const import CONFIG
class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
pass
model_id: str = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model_id = kwargs.get('model_id', None)
def embed_query(self, text: str) -> List[float]:
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
prefix = CONFIG.get_admin_path()
res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_query',
{'text': text})
result = res.json()
if result.get('code', 500) == 200:
return result.get('data')
raise Exception(result.get('message'))
def embed_documents(self, texts: List[str]) -> List[List[float]]:
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
prefix = CONFIG.get_admin_path()
res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/{prefix}/api/model/{self.model_id}/embed_documents',
{'texts': texts})
result = res.json()
if result.get('code', 500) == 200:
return result.get('data')
raise Exception(result.get('message'))
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
if model_kwargs.get('use_local', True):
return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True}
)
return WebLocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True},
**model_kwargs)

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file __init__.py
@date2025/11/5 15:24
@desc:
"""
import os
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
from .model import *
else:
from .web import *

View File

@ -0,0 +1,26 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file model.py
@date2025/11/5 15:26
@desc:
"""
from typing import Dict
from langchain_huggingface import HuggingFaceEmbeddings
from models_provider.base_model_provider import MaxKBBaseModel
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
@staticmethod
def is_cache_model():
return True
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True}
)

View File

@ -0,0 +1,54 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file web.py
@date2025/11/5 15:24
@desc:
"""
from typing import Dict, List
import requests
from anthropic import BaseModel
from langchain_core.embeddings import Embeddings
from maxkb.const import CONFIG
from models_provider.base_model_provider import MaxKBBaseModel
class LocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model_id = kwargs.get('model_id', None)
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True},
**model_kwargs)
model_id: str = None
def embed_query(self, text: str) -> List[float]:
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
prefix = CONFIG.get_admin_path()
res = requests.post(
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_query',
{'text': text})
result = res.json()
if result.get('code', 500) == 200:
return result.get('data')
raise Exception(result.get('message'))
def embed_documents(self, texts: List[str]) -> List[List[float]]:
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
prefix = CONFIG.get_admin_path()
res = requests.post(
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/{prefix}/api/model/{self.model_id}/embed_documents',
{'texts': texts})
result = res.json()
if result.get('code', 500) == 200:
return result.get('data')
raise Exception(result.get('message'))

View File

@ -1,102 +0,0 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file reranker.py.py
@date2024/9/2 16:42
@desc:
"""
from typing import Sequence, Optional, Dict, Any, ClassVar
import requests
import torch
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from models_provider.base_model_provider import MaxKBBaseModel
from maxkb.const import CONFIG
class LocalReranker(MaxKBBaseModel):
def __init__(self, model_name, top_n=3, cache_dir=None):
super().__init__()
self.model_name = model_name
self.cache_dir = cache_dir
self.top_n = top_n
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
if model_kwargs.get('use_local', True):
return LocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'),
model_kwargs={'device': model_credential.get('device', 'cpu')}
)
return WebLocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'),
model_kwargs={'device': model_credential.get('device')},
**model_kwargs)
class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
pass
model_id: str = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model_id = kwargs.get('model_id', None)
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if documents is None or len(documents) == 0:
return []
prefix = CONFIG.get_admin_path()
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
res = requests.post(
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents',
json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in
documents], 'query': query}, headers={'Content-Type': 'application/json'})
result = res.json()
if result.get('code', 500) == 200:
return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document
in result.get('data')]
raise Exception(result.get('message'))
class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
client: Any = None
tokenizer: Any = None
model: Optional[str] = None
cache_dir: Optional[str] = None
model_kwargs: Any = {}
def __init__(self, model_name, cache_dir=None, **model_kwargs):
super().__init__()
self.model = model_name
self.cache_dir = cache_dir
self.model_kwargs = model_kwargs
self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
self.client = self.client.to(self.model_kwargs.get('device', 'cpu'))
self.client.eval()
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs)
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if documents is None or len(documents) == 0:
return []
with torch.no_grad():
inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
truncation=True, return_tensors='pt', max_length=512)
scores = [torch.sigmoid(s).float().item() for s in
self.client(**inputs, return_dict=True).logits.view(-1, ).float()]
result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]})
for index
in range(len(documents))]
result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
return result

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file __init__.py.py
@date2025/11/5 15:30
@desc:
"""
import os
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
from .model import *
else:
from .web import *

View File

@ -0,0 +1,58 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file model.py
@date2025/11/5 15:30
@desc:
"""
from typing import Sequence, Optional, Dict, Any
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document, BaseDocumentCompressor
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from models_provider.base_model_provider import MaxKBBaseModel
class LocalReranker(MaxKBBaseModel, BaseDocumentCompressor):
client: Any = None
tokenizer: Any = None
model: Optional[str] = None
cache_dir: Optional[str] = None
model_kwargs: Any = {}
def __init__(self, model_name, cache_dir=None, **model_kwargs):
super().__init__()
self.model = model_name
self.cache_dir = cache_dir
self.model_kwargs = model_kwargs
self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
self.client = self.client.to(self.model_kwargs.get('device', 'cpu'))
self.client.eval()
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return LocalReranker(model_name, cache_dir=model_credential.get('cache_dir'))
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if documents is None or len(documents) == 0:
return []
import torch
with torch.no_grad():
inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
truncation=True, return_tensors='pt', max_length=512)
scores = [torch.sigmoid(s).float().item() for s in
self.client(**inputs, return_dict=True).logits.view(-1, ).float()]
result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]})
for index
in range(len(documents))]
result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
return result

View File

@ -0,0 +1,52 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file web.py
@date2025/11/5 15:30
@desc:
"""
from typing import Sequence, Optional, Dict
import requests
from anthropic import BaseModel
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document, BaseDocumentCompressor
from maxkb.const import CONFIG
from models_provider.base_model_provider import MaxKBBaseModel
class LocalReranker(MaxKBBaseModel, BaseModel, BaseDocumentCompressor):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return LocalReranker(model_type=model_type, model_name=model_name, model_credential=model_credential,
**model_kwargs)
model_id: str = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
print('ssss', kwargs.get('model_id', None))
self.model_id = kwargs.get('model_id', None)
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if documents is None or len(documents) == 0:
return []
prefix = CONFIG.get_admin_path()
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
res = requests.post(
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents',
json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in
documents], 'query': query}, headers={'Content-Type': 'application/json'})
result = res.json()
if result.get('code', 500) == 200:
return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document
in result.get('data')]
raise Exception(result.get('message'))

View File

@ -1,12 +1,12 @@
from typing import Sequence, Optional, Any, Dict
from typing import Sequence, Optional, Dict
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from models_provider.base_model_provider import MaxKBBaseModel
from sklearn.metrics.pairwise import cosine_similarity
from pydantic import BaseModel, Field
from models_provider.base_model_provider import MaxKBBaseModel
class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
top_n: Optional[int] = Field(3, description="Number of top documents to return")
@ -22,6 +22,7 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
from sklearn.metrics.pairwise import cosine_similarity
"""Rank documents based on their similarity to the query.
Args:
@ -37,7 +38,7 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
document_embeddings = self.embed_documents(documents)
# 计算相似度
similarities = cosine_similarity([query_embedding], document_embeddings)[0]
ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
ranked_docs = [(doc, _) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
return [
Document(
page_content=doc, # 第一个值是文档内容
@ -45,5 +46,3 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
)
for doc, score in ranked_docs
]

View File

@ -13,7 +13,6 @@ APP_DIR = os.path.join(BASE_DIR, 'apps')
os.chdir(BASE_DIR)
sys.path.insert(0, APP_DIR)
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "maxkb.settings")
django.setup()
def collect_static():
@ -74,7 +73,6 @@ def dev():
elif services.__contains__('celery'):
management.call_command('celery', 'celery')
elif services.__contains__('local_model'):
os.environ.setdefault('SERVER_NAME', 'local_model')
from maxkb.const import CONFIG
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
management.call_command('runserver', bind)
@ -108,6 +106,12 @@ if __name__ == '__main__':
parser.add_argument('-f', '--force', nargs="?", const=True)
args = parser.parse_args()
action = args.action
services = args.services if isinstance(args.services, list) else args.services
if services.__contains__('web'):
os.environ.setdefault('SERVER_NAME', 'web')
elif services.__contains__('local_model'):
os.environ.setdefault('SERVER_NAME', 'local_model')
django.setup()
if action == "upgrade_db":
perform_db_migrate()
elif action == "collect_static":
@ -120,4 +124,3 @@ if __name__ == '__main__':
collect_static()
perform_db_migrate()
start_services()