From a8d0729e652223173411ebcf0fb6b5ea8966cff2 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Wed, 5 Nov 2025 19:05:26 +0800 Subject: [PATCH] perf: Memory optimization (#4318) --- .../loop_node/impl/base_loop_node.py | 5 +- .../handle/impl/text/zip_split_handle.py | 2 - .../model => local_model}/__init__.py | 0 apps/local_model/admin.py | 3 + apps/local_model/apps.py | 6 + apps/local_model/migrations/__init__.py | 0 apps/local_model/models/__init__.py | 10 + apps/local_model/models/model_management.py | 49 +++++ apps/local_model/models/system_setting.py | 34 ++++ apps/local_model/models/user.py | 38 ++++ apps/local_model/serializers/__init__.py | 1 + .../serializers/model_apply_serializers.py | 140 ++++++++++++++ apps/local_model/serializers/rsa_util.py | 139 ++++++++++++++ apps/local_model/tests.py | 3 + apps/local_model/urls.py | 13 ++ apps/local_model/views/__init__.py | 2 + apps/local_model/views/model_apply.py | 34 ++++ apps/maxkb/settings/auth/__init__.py | 14 ++ apps/maxkb/settings/auth/model.py | 11 ++ apps/maxkb/settings/{auth.py => auth/web.py} | 0 apps/maxkb/settings/base/__init__.py | 14 ++ apps/maxkb/settings/base/model.py | 179 ++++++++++++++++++ apps/maxkb/settings/{base.py => base/web.py} | 20 +- apps/maxkb/urls/__init__.py | 14 ++ apps/maxkb/urls/model.py | 28 +++ apps/maxkb/{urls.py => urls/web.py} | 0 apps/maxkb/wsgi/__init__.py | 14 ++ apps/maxkb/wsgi/model.py | 15 ++ apps/maxkb/{wsgi.py => wsgi/web.py} | 15 +- .../credential/reranker.py | 4 +- .../local_model_provider.py | 5 +- .../local_model_provider/model/embedding.py | 64 ------- .../model/embedding/__init__.py | 14 ++ .../model/embedding/model.py | 26 +++ .../model/embedding/web.py | 54 ++++++ .../local_model_provider/model/reranker.py | 102 ---------- .../model/reranker/__init__.py | 14 ++ .../model/reranker/model.py | 58 ++++++ .../model/reranker/web.py | 52 +++++ .../ollama_model_provider/model/reranker.py | 11 +- main.py | 9 +- 41 files changed, 1013 insertions(+), 203 deletions(-) rename apps/{models_provider/impl/local_model_provider/model => local_model}/__init__.py (100%) create mode 100644 apps/local_model/admin.py create mode 100644 apps/local_model/apps.py create mode 100644 apps/local_model/migrations/__init__.py create mode 100644 apps/local_model/models/__init__.py create mode 100644 apps/local_model/models/model_management.py create mode 100644 apps/local_model/models/system_setting.py create mode 100644 apps/local_model/models/user.py create mode 100644 apps/local_model/serializers/__init__.py create mode 100644 apps/local_model/serializers/model_apply_serializers.py create mode 100644 apps/local_model/serializers/rsa_util.py create mode 100644 apps/local_model/tests.py create mode 100644 apps/local_model/urls.py create mode 100644 apps/local_model/views/__init__.py create mode 100644 apps/local_model/views/model_apply.py create mode 100644 apps/maxkb/settings/auth/__init__.py create mode 100644 apps/maxkb/settings/auth/model.py rename apps/maxkb/settings/{auth.py => auth/web.py} (100%) create mode 100644 apps/maxkb/settings/base/__init__.py create mode 100644 apps/maxkb/settings/base/model.py rename apps/maxkb/settings/{base.py => base/web.py} (93%) create mode 100644 apps/maxkb/urls/__init__.py create mode 100644 apps/maxkb/urls/model.py rename apps/maxkb/{urls.py => urls/web.py} (100%) create mode 100644 apps/maxkb/wsgi/__init__.py create mode 100644 apps/maxkb/wsgi/model.py rename apps/maxkb/{wsgi.py => wsgi/web.py} (71%) delete mode 100644 apps/models_provider/impl/local_model_provider/model/embedding.py create mode 100644 apps/models_provider/impl/local_model_provider/model/embedding/__init__.py create mode 100644 apps/models_provider/impl/local_model_provider/model/embedding/model.py create mode 100644 apps/models_provider/impl/local_model_provider/model/embedding/web.py delete mode 100644 apps/models_provider/impl/local_model_provider/model/reranker.py create mode 100644 apps/models_provider/impl/local_model_provider/model/reranker/__init__.py create mode 100644 apps/models_provider/impl/local_model_provider/model/reranker/model.py create mode 100644 apps/models_provider/impl/local_model_provider/model/reranker/web.py diff --git a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py index 233551e7a..9ec14cb5d 100644 --- a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py +++ b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py @@ -224,11 +224,12 @@ class LoopWorkFlowPostHandler(WorkFlowPostHandler): class BaseLoopNode(ILoopNode): def save_context(self, details, workflow_manage): - self.context['result'] = details.get('result') + self.context['loop_context_data'] = details.get('loop_context_data') + self.context['loop_answer_data'] = details.get('loop_answer_data') for key, value in details['context'].items(): if key not in self.context: self.context[key] = value - self.answer_text = str(details.get('result')) + self.answer_text = "" def get_answer_list(self) -> List[Answer] | None: result = [] diff --git a/apps/common/handle/impl/text/zip_split_handle.py b/apps/common/handle/impl/text/zip_split_handle.py index 5752fe0d7..6609a981c 100644 --- a/apps/common/handle/impl/text/zip_split_handle.py +++ b/apps/common/handle/impl/text/zip_split_handle.py @@ -15,7 +15,6 @@ from urllib.parse import urljoin import uuid_utils.compat as uuid from charset_normalizer import detect -from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ from common.handle.base_split_handle import BaseSplitHandle @@ -39,7 +38,6 @@ class FileBufferHandle: return self.buffer - default_split_handle = TextSplitHandle() split_handles = [ HTMLSplitHandle(), diff --git a/apps/models_provider/impl/local_model_provider/model/__init__.py b/apps/local_model/__init__.py similarity index 100% rename from apps/models_provider/impl/local_model_provider/model/__init__.py rename to apps/local_model/__init__.py diff --git a/apps/local_model/admin.py b/apps/local_model/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/apps/local_model/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/apps/local_model/apps.py b/apps/local_model/apps.py new file mode 100644 index 000000000..285ca7278 --- /dev/null +++ b/apps/local_model/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class LocalModelConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'local_model' diff --git a/apps/local_model/migrations/__init__.py b/apps/local_model/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/local_model/models/__init__.py b/apps/local_model/models/__init__.py new file mode 100644 index 000000000..8d63a0921 --- /dev/null +++ b/apps/local_model/models/__init__.py @@ -0,0 +1,10 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/25 15:04 + @desc: +""" + +from .model_management import * diff --git a/apps/local_model/models/model_management.py b/apps/local_model/models/model_management.py new file mode 100644 index 000000000..ff3c0bf4e --- /dev/null +++ b/apps/local_model/models/model_management.py @@ -0,0 +1,49 @@ +# coding=utf-8 +import uuid_utils.compat as uuid + +from django.db import models + +from common.mixins.app_model_mixin import AppModelMixin +from local_model.models.user import User + + +class Status(models.TextChoices): + """系统设置类型""" + SUCCESS = "SUCCESS", '成功' + + ERROR = "ERROR", "失败" + + DOWNLOAD = "DOWNLOAD", '下载中' + + PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载' + + +class Model(AppModelMixin): + """ + 模型数据 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") + + name = models.CharField(max_length=128, verbose_name="名称", db_index=True) + + status = models.CharField(max_length=20, verbose_name='设置类型', choices=Status.choices, + default=Status.SUCCESS, db_index=True) + + model_type = models.CharField(max_length=128, verbose_name="模型类型", db_index=True) + + model_name = models.CharField(max_length=128, verbose_name="模型名称", db_index=True) + + user = models.ForeignKey(User, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True) + + provider = models.CharField(max_length=128, verbose_name='供应商', db_index=True) + + credential = models.CharField(max_length=102400, verbose_name="模型认证信息") + + meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) + + model_params_form = models.JSONField(verbose_name="模型参数配置", default=list) + workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True) + + class Meta: + db_table = "model" + unique_together = ['name', 'workspace_id'] diff --git a/apps/local_model/models/system_setting.py b/apps/local_model/models/system_setting.py new file mode 100644 index 000000000..4b62d47a7 --- /dev/null +++ b/apps/local_model/models/system_setting.py @@ -0,0 +1,34 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: system_management.py + @date:2024/3/19 13:47 + @desc: 邮箱管理 +""" + +from django.db import models + +from common.mixins.app_model_mixin import AppModelMixin + + +class SettingType(models.IntegerChoices): + """系统设置类型""" + EMAIL = 0, '邮箱' + + RSA = 1, "私钥秘钥" + + LOG = 2, "日志清理时间" + + +class SystemSetting(AppModelMixin): + """ + 系统设置 + """ + type = models.IntegerField(primary_key=True, verbose_name='设置类型', choices=SettingType.choices, + default=SettingType.EMAIL) + + meta = models.JSONField(verbose_name="配置数据", default=dict) + + class Meta: + db_table = "system_setting" diff --git a/apps/local_model/models/user.py b/apps/local_model/models/user.py new file mode 100644 index 000000000..0f480d89f --- /dev/null +++ b/apps/local_model/models/user.py @@ -0,0 +1,38 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: user.py + @date:2025/4/14 10:20 + @desc: +""" +import uuid_utils.compat as uuid + +from django.db import models + +from common.utils.common import password_encrypt + + +class User(models.Model): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") + email = models.EmailField(unique=True, null=True, blank=True, verbose_name="邮箱", db_index=True) + phone = models.CharField(max_length=20, verbose_name="电话", default="", db_index=True) + nick_name = models.CharField(max_length=150, verbose_name="昵称", unique=True, db_index=True) + username = models.CharField(max_length=150, unique=True, verbose_name="用户名", db_index=True) + password = models.CharField(max_length=150, verbose_name="密码") + role = models.CharField(max_length=150, verbose_name="角色") + source = models.CharField(max_length=10, verbose_name="来源", default="LOCAL", db_index=True) + is_active = models.BooleanField(default=True, db_index=True) + language = models.CharField(max_length=10, verbose_name="语言", null=True, default=None) + create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True, null=True, db_index=True) + update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True, null=True, db_index=True) + + USERNAME_FIELD = 'username' + REQUIRED_FIELDS = [] + + class Meta: + db_table = "user" + + def set_password(self, row_password): + self.password = password_encrypt(row_password) + self._password = row_password diff --git a/apps/local_model/serializers/__init__.py b/apps/local_model/serializers/__init__.py new file mode 100644 index 000000000..9bad5790a --- /dev/null +++ b/apps/local_model/serializers/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/apps/local_model/serializers/model_apply_serializers.py b/apps/local_model/serializers/model_apply_serializers.py new file mode 100644 index 000000000..cf57c2fef --- /dev/null +++ b/apps/local_model/serializers/model_apply_serializers.py @@ -0,0 +1,140 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: model_apply_serializers.py + @date:2024/8/20 20:39 + @desc: +""" +import json +import threading +import time + +from django.db import connection +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ +from langchain_core.documents import Document +from rest_framework import serializers + +from local_model.models import Model +from local_model.serializers.rsa_util import rsa_long_decrypt +from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider + +from common.cache.mem_cache import MemCache + +_lock = threading.Lock() +locks = {} + + +class ModelManage: + cache = MemCache('model', {}) + up_clear_time = time.time() + + @staticmethod + def _get_lock(_id): + lock = locks.get(_id) + if lock is None: + with _lock: + lock = locks.get(_id) + if lock is None: + lock = threading.Lock() + locks[_id] = lock + + return lock + + @staticmethod + def get_model(_id, get_model): + model_instance = ModelManage.cache.get(_id) + if model_instance is None: + lock = ModelManage._get_lock(_id) + with lock: + model_instance = ModelManage.cache.get(_id) + if model_instance is None: + model_instance = get_model(_id) + ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8) + else: + if model_instance.is_cache_model(): + ModelManage.cache.touch(_id, timeout=60 * 60 * 8) + else: + model_instance = get_model(_id) + ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8) + ModelManage.clear_timeout_cache() + return model_instance + + @staticmethod + def clear_timeout_cache(): + if time.time() - ModelManage.up_clear_time > 60 * 60: + threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data()).start() + ModelManage.up_clear_time = time.time() + + @staticmethod + def delete_key(_id): + if ModelManage.cache.has_key(_id): + ModelManage.cache.delete(_id) + + +def get_local_model(model, **kwargs): + # system_setting = QuerySet(SystemSetting).filter(type=1).first() + return LocalModelProvider().get_model(model.model_type, model.model_name, + json.loads( + rsa_long_decrypt(model.credential)), + model_id=model.id, + streaming=True, **kwargs) + + +def get_embedding_model(model_id): + model = QuerySet(Model).filter(id=model_id).first() + # 手动关闭数据库连接 + connection.close() + embedding_model = ModelManage.get_model(model_id, + lambda _id: get_local_model(model, use_local=True)) + return embedding_model + + +class EmbedDocuments(serializers.Serializer): + texts = serializers.ListField(required=True, child=serializers.CharField(required=True, + label=_('vector text')), + label=_('vector text list')), + + +class EmbedQuery(serializers.Serializer): + text = serializers.CharField(required=True, label=_('vector text')) + + +class CompressDocument(serializers.Serializer): + page_content = serializers.CharField(required=True, label=_('text')) + metadata = serializers.DictField(required=False, label=_('metadata')) + + +class CompressDocuments(serializers.Serializer): + documents = CompressDocument(required=True, many=True) + query = serializers.CharField(required=True, label=_('query')) + + +class ModelApplySerializers(serializers.Serializer): + model_id = serializers.UUIDField(required=True, label=_('model id')) + + def embed_documents(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + EmbedDocuments(data=instance).is_valid(raise_exception=True) + + model = get_embedding_model(self.data.get('model_id')) + return model.embed_documents(instance.getlist('texts')) + + def embed_query(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + EmbedQuery(data=instance).is_valid(raise_exception=True) + + model = get_embedding_model(self.data.get('model_id')) + return model.embed_query(instance.get('text')) + + def compress_documents(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + CompressDocuments(data=instance).is_valid(raise_exception=True) + model = get_embedding_model(self.data.get('model_id')) + return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents( + [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in + instance.get('documents')], instance.get('query'))] diff --git a/apps/local_model/serializers/rsa_util.py b/apps/local_model/serializers/rsa_util.py new file mode 100644 index 000000000..df2cedba7 --- /dev/null +++ b/apps/local_model/serializers/rsa_util.py @@ -0,0 +1,139 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: rsa_util.py + @date:2023/11/3 11:13 + @desc: +""" +import base64 +import threading + +from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher +from Crypto.PublicKey import RSA +from django.core import cache +from django.db.models import QuerySet + +from common.constants.cache_version import Cache_Version +from local_model.models.system_setting import SystemSetting, SettingType + +lock = threading.Lock() +rsa_cache = cache.cache +cache_key = "rsa_key" +# 对密钥加密的密码 +secret_code = "mac_kb_password" + + +def generate(): + """ + 生成 私钥秘钥对 + :return:{key:'公钥',value:'私钥'} + """ + # 生成一个 2048 位的密钥 + key = RSA.generate(2048) + + # 获取私钥 + encrypted_key = key.export_key(passphrase=secret_code, pkcs=8, + protection="scryptAndAES128-CBC") + return {'key': key.publickey().export_key(), 'value': encrypted_key} + + +def get_key_pair(): + rsa_value = rsa_cache.get(cache_key) + if rsa_value is None: + with lock: + rsa_value = rsa_cache.get(cache_key) + if rsa_value is not None: + return rsa_value + rsa_value = get_key_pair_by_sql() + version, get_key = Cache_Version.SYSTEM.value + rsa_cache.set(get_key(key='rsa_key'), rsa_value, timeout=None, version=version) + return rsa_value + + +def get_key_pair_by_sql(): + system_setting = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first() + if system_setting is None: + kv = generate() + system_setting = SystemSetting(type=SettingType.RSA.value, + meta={'key': kv.get('key').decode(), 'value': kv.get('value').decode()}) + system_setting.save() + return system_setting.meta + + +def encrypt(msg, public_key: str | None = None): + """ + 加密 + :param msg: 加密数据 + :param public_key: 公钥 + :return: 加密后的数据 + """ + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(public_key)) + encrypt_msg = cipher.encrypt(msg.encode("utf-8")) + return base64.b64encode(encrypt_msg).decode() + + +def decrypt(msg, pri_key: str | None = None): + """ + 解密 + :param msg: 需要解密的数据 + :param pri_key: 私钥 + :return: 解密后数据 + """ + if pri_key is None: + pri_key = get_key_pair().get('value') + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) + return decrypt_data.decode("utf-8") + + +def rsa_long_encrypt(message, public_key: str | None = None, length=200): + """ + 超长文本加密 + + :param message: 需要加密的字符串 + :param public_key 公钥 + :param length: 1024bit的证书用100, 2048bit的证书用 200 + :return: 加密后的数据 + """ + # 读取公钥 + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, + passphrase=secret_code)) + # 处理:Plaintext is too long. 分段加密 + if len(message) <= length: + # 对编码的数据进行加密,并通过base64进行编码 + result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) + else: + rsa_text = [] + # 对编码后的数据进行切片,原因:加密长度不能过长 + for i in range(0, len(message), length): + cont = message[i:i + length] + # 对切片后的数据进行加密,并新增到text后面 + rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) + # 加密完进行拼接 + cipher_text = b''.join(rsa_text) + # base64进行编码 + result = base64.b64encode(cipher_text) + return result.decode() + + +def rsa_long_decrypt(message, pri_key: str | None = None, length=256): + """ + 超长文本解密,默认不加密 + :param message: 需要解密的数据 + :param pri_key: 秘钥 + :param length : 1024bit的证书用128,2048bit证书用256位 + :return: 解密后的数据 + """ + if pri_key is None: + pri_key = get_key_pair().get('value') + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + base64_de = base64.b64decode(message) + res = [] + for i in range(0, len(base64_de), length): + res.append(cipher.decrypt(base64_de[i:i + length], 0)) + return b"".join(res).decode() diff --git a/apps/local_model/tests.py b/apps/local_model/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/apps/local_model/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/apps/local_model/urls.py b/apps/local_model/urls.py new file mode 100644 index 000000000..cc47946e0 --- /dev/null +++ b/apps/local_model/urls.py @@ -0,0 +1,13 @@ +import os + +from django.urls import path + +from . import views + +app_name = "local_model" +# @formatter:off +urlpatterns = [ + path('model//embed_documents', views.LocalModelApply.EmbedDocuments.as_view()), + path('model//embed_query', views.LocalModelApply.EmbedQuery.as_view()), + path('model//compress_documents', views.LocalModelApply.CompressDocuments.as_view()), + ] diff --git a/apps/local_model/views/__init__.py b/apps/local_model/views/__init__.py new file mode 100644 index 000000000..b9dd8b0a4 --- /dev/null +++ b/apps/local_model/views/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +from .model_apply import * diff --git a/apps/local_model/views/model_apply.py b/apps/local_model/views/model_apply.py new file mode 100644 index 000000000..218d4f091 --- /dev/null +++ b/apps/local_model/views/model_apply.py @@ -0,0 +1,34 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: model_apply.py + @date:2024/8/20 20:38 + @desc: +""" +from urllib.request import Request + +from rest_framework.views import APIView + +from common.result import result +from local_model.serializers.model_apply_serializers import ModelApplySerializers + + +class LocalModelApply(APIView): + class EmbedDocuments(APIView): + + def post(self, request: Request, model_id): + return result.success( + ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data)) + + class EmbedQuery(APIView): + + def post(self, request: Request, model_id): + return result.success( + ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data)) + + class CompressDocuments(APIView): + + def post(self, request: Request, model_id): + return result.success( + ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)) diff --git a/apps/maxkb/settings/auth/__init__.py b/apps/maxkb/settings/auth/__init__.py new file mode 100644 index 000000000..b7996e1f5 --- /dev/null +++ b/apps/maxkb/settings/auth/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py + @date:2025/11/5 14:50 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/maxkb/settings/auth/model.py b/apps/maxkb/settings/auth/model.py new file mode 100644 index 000000000..a21013025 --- /dev/null +++ b/apps/maxkb/settings/auth/model.py @@ -0,0 +1,11 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: auth.py + @date:2024/7/9 18:47 + @desc: +""" + +AUTH_HANDLES = [ +] diff --git a/apps/maxkb/settings/auth.py b/apps/maxkb/settings/auth/web.py similarity index 100% rename from apps/maxkb/settings/auth.py rename to apps/maxkb/settings/auth/web.py diff --git a/apps/maxkb/settings/base/__init__.py b/apps/maxkb/settings/base/__init__.py new file mode 100644 index 000000000..65d1845bb --- /dev/null +++ b/apps/maxkb/settings/base/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/5 14:53 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/maxkb/settings/base/model.py b/apps/maxkb/settings/base/model.py new file mode 100644 index 000000000..a47364237 --- /dev/null +++ b/apps/maxkb/settings/base/model.py @@ -0,0 +1,179 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: model.py + @date:2025/11/5 14:53 + @desc: +""" + +from pathlib import Path +from ...const import CONFIG, PROJECT_DIR +import os +from django.utils.translation import gettext_lazy as _ + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent.parent + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = CONFIG.get("SECRET_KEY") or 'django-insecure-zm^1_^i5)3gp^&0io6zg72&z!a*d=9kf9o2%uft+27l)+t(#3e' + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = CONFIG.get_debug() + +ALLOWED_HOSTS = ['*'] + +# Application definition + +INSTALLED_APPS = [ + 'django.contrib.contenttypes', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'rest_framework', + 'local_model', +] + +MIDDLEWARE = [ + 'django.middleware.locale.LocaleMiddleware', + 'django.middleware.security.SecurityMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + +] + +REST_FRAMEWORK = { + 'EXCEPTION_HANDLER': 'common.exception.handle_exception.handle_exception', + 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', + 'DEFAULT_AUTHENTICATION_CLASSES': ['common.auth.authenticate.AnonymousAuthentication'] +} +STATICFILES_DIRS = [(os.path.join(PROJECT_DIR, 'ui', 'dist'))] +STATIC_ROOT = os.path.join(BASE_DIR.parent, 'static') +ROOT_URLCONF = 'maxkb.urls' +APPS_DIR = os.path.join(PROJECT_DIR, 'apps') + +TEMPLATES = [ + { + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': ["apps/static/admin"], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, + {"NAME": "CHAT", + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': ["apps/static/chat"], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, + {"NAME": "DOC", + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': ["apps/static/drf_spectacular_sidecar"], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, +] +SPECTACULAR_SETTINGS = { + 'TITLE': 'MaxKB API', + 'DESCRIPTION': _('Intelligent customer service platform'), + 'VERSION': 'v2', + 'SERVE_INCLUDE_SCHEMA': False, + # OTHER SETTINGS + 'SWAGGER_UI_DIST': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist', # shorthand to use the sidecar instead + 'SWAGGER_UI_FAVICON_HREF': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist/favicon-32x32.png', + 'REDOC_DIST': f'{CONFIG.get_admin_path()}/api-doc/redoc', + 'SECURITY_DEFINITIONS': { + 'Bearer': { + 'type': 'apiKey', + 'name': 'AUTHORIZATION', + 'in': 'header', + } + } +} +WSGI_APPLICATION = 'maxkb.wsgi.application' + +# Database +# https://docs.djangoproject.com/en/4.2/ref/settings/#databases + +DATABASES = {'default': CONFIG.get_db_setting()} + +CACHES = CONFIG.get_cache_setting() + +# Password validation +# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + }, +] + +# Internationalization +# https://docs.djangoproject.com/en/4.2/topics/i18n/ + +LANGUAGE_CODE = CONFIG.get("LANGUAGE_CODE") + +TIME_ZONE = CONFIG.get_time_zone() + +USE_I18N = True + +USE_TZ = True + +# 文件上传配置 +DATA_UPLOAD_MAX_NUMBER_FILES = 1000 + +# 支持的语言 +LANGUAGES = [ + ('en', 'English'), + ('zh', '中文简体'), + ('zh-hant', '中文繁体') +] +# 翻译文件路径 +LOCALE_PATHS = [ + os.path.join(BASE_DIR.parent, 'locales') +] + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/4.2/howto/static-files/ + +STATIC_URL = 'static/' + +# Default primary key field type +# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field + +DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' + +edition = 'CE' + +if os.environ.get('MAXKB_REDIS_SENTINEL_SENTINELS') is not None: + DJANGO_REDIS_CONNECTION_FACTORY = "django_redis.pool.SentinelConnectionFactory" diff --git a/apps/maxkb/settings/base.py b/apps/maxkb/settings/base/web.py similarity index 93% rename from apps/maxkb/settings/base.py rename to apps/maxkb/settings/base/web.py index 780f1f32e..04f6ae5a3 100644 --- a/apps/maxkb/settings/base.py +++ b/apps/maxkb/settings/base/web.py @@ -1,22 +1,18 @@ +# coding=utf-8 """ -Django settings for maxkb project. - -Generated by 'django-admin startproject' using Django 4.2.4. - -For more information on this file, see -https://docs.djangoproject.com/en/4.2/topics/settings/ - -For the full list of settings and their values, see -https://docs.djangoproject.com/en/4.2/ref/settings/ + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/5 14:53 + @desc: """ - from pathlib import Path -from ..const import CONFIG, PROJECT_DIR +from ...const import CONFIG, PROJECT_DIR import os from django.utils.translation import gettext_lazy as _ # Build paths inside the project like this: BASE_DIR / 'subdir'. -BASE_DIR = Path(__file__).resolve().parent.parent +BASE_DIR = Path(__file__).resolve().parent.parent.parent # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/ diff --git a/apps/maxkb/urls/__init__.py b/apps/maxkb/urls/__init__.py new file mode 100644 index 000000000..f788f029e --- /dev/null +++ b/apps/maxkb/urls/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB-xpack + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/5 14:45 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/maxkb/urls/model.py b/apps/maxkb/urls/model.py new file mode 100644 index 000000000..6e3a11dd0 --- /dev/null +++ b/apps/maxkb/urls/model.py @@ -0,0 +1,28 @@ +""" +URL configuration for maxkb project. + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/4.2/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: path('', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.urls import include, path + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) +""" + +from django.urls import path, include + +from maxkb.const import CONFIG + +admin_api_prefix = CONFIG.get_admin_path()[1:] + '/api/' +admin_ui_prefix = CONFIG.get_admin_path() +chat_api_prefix = CONFIG.get_chat_path()[1:] + '/api/' +chat_ui_prefix = CONFIG.get_chat_path() +urlpatterns = [ + path(admin_api_prefix, include("local_model.urls")), +] diff --git a/apps/maxkb/urls.py b/apps/maxkb/urls/web.py similarity index 100% rename from apps/maxkb/urls.py rename to apps/maxkb/urls/web.py diff --git a/apps/maxkb/wsgi/__init__.py b/apps/maxkb/wsgi/__init__.py new file mode 100644 index 000000000..58649348f --- /dev/null +++ b/apps/maxkb/wsgi/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/5 15:14 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/maxkb/wsgi/model.py b/apps/maxkb/wsgi/model.py new file mode 100644 index 000000000..1d70874be --- /dev/null +++ b/apps/maxkb/wsgi/model.py @@ -0,0 +1,15 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: model.py + @date:2025/11/5 15:14 + @desc: +""" +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings') + +application = get_wsgi_application() \ No newline at end of file diff --git a/apps/maxkb/wsgi.py b/apps/maxkb/wsgi/web.py similarity index 71% rename from apps/maxkb/wsgi.py rename to apps/maxkb/wsgi/web.py index fc271a7c3..e95125e42 100644 --- a/apps/maxkb/wsgi.py +++ b/apps/maxkb/wsgi/web.py @@ -1,12 +1,11 @@ +# coding=utf-8 """ -WSGI config for maxkb project. - -It exposes the WSGI callable as a module-level variable named ``application``. - -For more information on this file, see -https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/ + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/5 15:14 + @desc: """ - import os from django.core.wsgi import get_wsgi_application @@ -35,4 +34,4 @@ post_handler() # 仅在scheduler中启动定时任务,dev local_model celery 不需要 if os.environ.get('ENABLE_SCHEDULER') == '1': - post_scheduler_handler() + post_scheduler_handler() \ No newline at end of file diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker.py b/apps/models_provider/impl/local_model_provider/credential/reranker.py index 94341d52f..46f8ebca2 100644 --- a/apps/models_provider/impl/local_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/local_model_provider/credential/reranker.py @@ -15,7 +15,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode -from models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker +from models_provider.impl.local_model_provider.model.reranker import LocalReranker from django.utils.translation import gettext_lazy as _, gettext @@ -33,7 +33,7 @@ class LocalRerankerCredential(BaseForm, BaseModelCredential): else: return False try: - model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential) + model: LocalReranker = provider.get_model(model_type, model_name, model_credential) model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello')) except Exception as e: traceback.print_exc() diff --git a/apps/models_provider/impl/local_model_provider/local_model_provider.py b/apps/models_provider/impl/local_model_provider/local_model_provider.py index 14603962a..342f585f4 100644 --- a/apps/models_provider/impl/local_model_provider/local_model_provider.py +++ b/apps/models_provider/impl/local_model_provider/local_model_provider.py @@ -8,15 +8,16 @@ """ import os +from django.utils.translation import gettext as _ + from common.utils.common import get_file_content +from maxkb.conf import PROJECT_DIR from models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage from models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential from models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential from models_provider.impl.local_model_provider.model.embedding import LocalEmbedding from models_provider.impl.local_model_provider.model.reranker import LocalReranker -from maxkb.conf import PROJECT_DIR -from django.utils.translation import gettext as _ embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING, LocalEmbeddingCredential(), LocalEmbedding) diff --git a/apps/models_provider/impl/local_model_provider/model/embedding.py b/apps/models_provider/impl/local_model_provider/model/embedding.py deleted file mode 100644 index 94ea652af..000000000 --- a/apps/models_provider/impl/local_model_provider/model/embedding.py +++ /dev/null @@ -1,64 +0,0 @@ -# coding=utf-8 -""" - @project: MaxKB - @Author:虎 - @file: embedding.py - @date:2024/7/11 14:06 - @desc: -""" -from typing import Dict, List - -import requests -from langchain_core.embeddings import Embeddings -from pydantic import BaseModel -from langchain_huggingface import HuggingFaceEmbeddings - -from models_provider.base_model_provider import MaxKBBaseModel -from maxkb.const import CONFIG - - -class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - pass - - model_id: str = None - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.model_id = kwargs.get('model_id', None) - - def embed_query(self, text: str) -> List[float]: - bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' - prefix = CONFIG.get_admin_path() - res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_query', - {'text': text}) - result = res.json() - if result.get('code', 500) == 200: - return result.get('data') - raise Exception(result.get('message')) - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' - prefix = CONFIG.get_admin_path() - res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/{prefix}/api/model/{self.model_id}/embed_documents', - {'texts': texts}) - result = res.json() - if result.get('code', 500) == 200: - return result.get('data') - raise Exception(result.get('message')) - - -class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings): - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - if model_kwargs.get('use_local', True): - return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), - model_kwargs={'device': model_credential.get('device')}, - encode_kwargs={'normalize_embeddings': True} - ) - return WebLocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), - model_kwargs={'device': model_credential.get('device')}, - encode_kwargs={'normalize_embeddings': True}, - **model_kwargs) diff --git a/apps/models_provider/impl/local_model_provider/model/embedding/__init__.py b/apps/models_provider/impl/local_model_provider/model/embedding/__init__.py new file mode 100644 index 000000000..840afa5af --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/embedding/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py + @date:2025/11/5 15:24 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/models_provider/impl/local_model_provider/model/embedding/model.py b/apps/models_provider/impl/local_model_provider/model/embedding/model.py new file mode 100644 index 000000000..7ebc41cb1 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/embedding/model.py @@ -0,0 +1,26 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: model.py + @date:2025/11/5 15:26 + @desc: +""" +from typing import Dict + +from langchain_huggingface import HuggingFaceEmbeddings + +from models_provider.base_model_provider import MaxKBBaseModel + + +class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings): + @staticmethod + def is_cache_model(): + return True + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True} + ) diff --git a/apps/models_provider/impl/local_model_provider/model/embedding/web.py b/apps/models_provider/impl/local_model_provider/model/embedding/web.py new file mode 100644 index 000000000..bfc22bc9b --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/embedding/web.py @@ -0,0 +1,54 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/5 15:24 + @desc: +""" + +from typing import Dict, List + +import requests +from anthropic import BaseModel +from langchain_core.embeddings import Embeddings + +from maxkb.const import CONFIG +from models_provider.base_model_provider import MaxKBBaseModel + + +class LocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model_id = kwargs.get('model_id', None) + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True}, + **model_kwargs) + + model_id: str = None + + def embed_query(self, text: str) -> List[float]: + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_query', + {'text': text}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/{prefix}/api/model/{self.model_id}/embed_documents', + {'texts': texts}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) diff --git a/apps/models_provider/impl/local_model_provider/model/reranker.py b/apps/models_provider/impl/local_model_provider/model/reranker.py deleted file mode 100644 index 5ac2a4ca8..000000000 --- a/apps/models_provider/impl/local_model_provider/model/reranker.py +++ /dev/null @@ -1,102 +0,0 @@ -# coding=utf-8 -""" - @project: MaxKB - @Author:虎 - @file: reranker.py.py - @date:2024/9/2 16:42 - @desc: -""" -from typing import Sequence, Optional, Dict, Any, ClassVar - -import requests -import torch -from langchain_core.callbacks import Callbacks -from langchain_core.documents import BaseDocumentCompressor, Document -from transformers import AutoModelForSequenceClassification, AutoTokenizer - -from models_provider.base_model_provider import MaxKBBaseModel -from maxkb.const import CONFIG - - -class LocalReranker(MaxKBBaseModel): - def __init__(self, model_name, top_n=3, cache_dir=None): - super().__init__() - self.model_name = model_name - self.cache_dir = cache_dir - self.top_n = top_n - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - if model_kwargs.get('use_local', True): - return LocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'), - model_kwargs={'device': model_credential.get('device', 'cpu')} - - ) - return WebLocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'), - model_kwargs={'device': model_credential.get('device')}, - **model_kwargs) - - -class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor): - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - pass - - model_id: str = None - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.model_id = kwargs.get('model_id', None) - - def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ - Sequence[Document]: - if documents is None or len(documents) == 0: - return [] - prefix = CONFIG.get_admin_path() - bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' - res = requests.post( - f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents', - json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in - documents], 'query': query}, headers={'Content-Type': 'application/json'}) - result = res.json() - if result.get('code', 500) == 200: - return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document - in result.get('data')] - raise Exception(result.get('message')) - - -class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor): - client: Any = None - tokenizer: Any = None - model: Optional[str] = None - cache_dir: Optional[str] = None - model_kwargs: Any = {} - - def __init__(self, model_name, cache_dir=None, **model_kwargs): - super().__init__() - self.model = model_name - self.cache_dir = cache_dir - self.model_kwargs = model_kwargs - self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir) - self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir) - self.client = self.client.to(self.model_kwargs.get('device', 'cpu')) - self.client.eval() - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs) - - def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ - Sequence[Document]: - if documents is None or len(documents) == 0: - return [] - with torch.no_grad(): - inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True, - truncation=True, return_tensors='pt', max_length=512) - scores = [torch.sigmoid(s).float().item() for s in - self.client(**inputs, return_dict=True).logits.view(-1, ).float()] - result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]}) - for index - in range(len(documents))] - result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True) - return result diff --git a/apps/models_provider/impl/local_model_provider/model/reranker/__init__.py b/apps/models_provider/impl/local_model_provider/model/reranker/__init__.py new file mode 100644 index 000000000..10d5b1bb6 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/reranker/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/5 15:30 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/models_provider/impl/local_model_provider/model/reranker/model.py b/apps/models_provider/impl/local_model_provider/model/reranker/model.py new file mode 100644 index 000000000..c5f1c67e0 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/reranker/model.py @@ -0,0 +1,58 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: model.py + @date:2025/11/5 15:30 + @desc: +""" + +from typing import Sequence, Optional, Dict, Any + +from langchain_core.callbacks import Callbacks +from langchain_core.documents import Document, BaseDocumentCompressor +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from models_provider.base_model_provider import MaxKBBaseModel + + +class LocalReranker(MaxKBBaseModel, BaseDocumentCompressor): + client: Any = None + tokenizer: Any = None + model: Optional[str] = None + cache_dir: Optional[str] = None + model_kwargs: Any = {} + + def __init__(self, model_name, cache_dir=None, **model_kwargs): + super().__init__() + self.model = model_name + self.cache_dir = cache_dir + self.model_kwargs = model_kwargs + self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir) + self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir) + self.client = self.client.to(self.model_kwargs.get('device', 'cpu')) + self.client.eval() + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalReranker(model_name, cache_dir=model_credential.get('cache_dir')) + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + if documents is None or len(documents) == 0: + return [] + import torch + with torch.no_grad(): + inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True, + truncation=True, return_tensors='pt', max_length=512) + scores = [torch.sigmoid(s).float().item() for s in + self.client(**inputs, return_dict=True).logits.view(-1, ).float()] + result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]}) + for index + in range(len(documents))] + result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True) + return result diff --git a/apps/models_provider/impl/local_model_provider/model/reranker/web.py b/apps/models_provider/impl/local_model_provider/model/reranker/web.py new file mode 100644 index 000000000..45ab6978a --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/reranker/web.py @@ -0,0 +1,52 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/5 15:30 + @desc: +""" +from typing import Sequence, Optional, Dict + +import requests +from anthropic import BaseModel +from langchain_core.callbacks import Callbacks +from langchain_core.documents import Document, BaseDocumentCompressor + +from maxkb.const import CONFIG +from models_provider.base_model_provider import MaxKBBaseModel + + +class LocalReranker(MaxKBBaseModel, BaseModel, BaseDocumentCompressor): + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalReranker(model_type=model_type, model_name=model_name, model_credential=model_credential, + **model_kwargs) + + model_id: str = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + print('ssss', kwargs.get('model_id', None)) + self.model_id = kwargs.get('model_id', None) + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + if documents is None or len(documents) == 0: + return [] + prefix = CONFIG.get_admin_path() + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents', + json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in + documents], 'query': query}, headers={'Content-Type': 'application/json'}) + result = res.json() + if result.get('code', 500) == 200: + return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document + in result.get('data')] + raise Exception(result.get('message')) diff --git a/apps/models_provider/impl/ollama_model_provider/model/reranker.py b/apps/models_provider/impl/ollama_model_provider/model/reranker.py index 22d6e63d2..57435961b 100644 --- a/apps/models_provider/impl/ollama_model_provider/model/reranker.py +++ b/apps/models_provider/impl/ollama_model_provider/model/reranker.py @@ -1,12 +1,12 @@ -from typing import Sequence, Optional, Any, Dict +from typing import Sequence, Optional, Dict from langchain_community.embeddings import OllamaEmbeddings from langchain_core.callbacks import Callbacks from langchain_core.documents import Document -from models_provider.base_model_provider import MaxKBBaseModel -from sklearn.metrics.pairwise import cosine_similarity from pydantic import BaseModel, Field +from models_provider.base_model_provider import MaxKBBaseModel + class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel): top_n: Optional[int] = Field(3, description="Number of top documents to return") @@ -22,6 +22,7 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel): def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ Sequence[Document]: + from sklearn.metrics.pairwise import cosine_similarity """Rank documents based on their similarity to the query. Args: @@ -37,7 +38,7 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel): document_embeddings = self.embed_documents(documents) # 计算相似度 similarities = cosine_similarity([query_embedding], document_embeddings)[0] - ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n] + ranked_docs = [(doc, _) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n] return [ Document( page_content=doc, # 第一个值是文档内容 @@ -45,5 +46,3 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel): ) for doc, score in ranked_docs ] - - diff --git a/main.py b/main.py index 0a300e9cb..5e4387c69 100644 --- a/main.py +++ b/main.py @@ -13,7 +13,6 @@ APP_DIR = os.path.join(BASE_DIR, 'apps') os.chdir(BASE_DIR) sys.path.insert(0, APP_DIR) os.environ.setdefault("DJANGO_SETTINGS_MODULE", "maxkb.settings") -django.setup() def collect_static(): @@ -74,7 +73,6 @@ def dev(): elif services.__contains__('celery'): management.call_command('celery', 'celery') elif services.__contains__('local_model'): - os.environ.setdefault('SERVER_NAME', 'local_model') from maxkb.const import CONFIG bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' management.call_command('runserver', bind) @@ -108,6 +106,12 @@ if __name__ == '__main__': parser.add_argument('-f', '--force', nargs="?", const=True) args = parser.parse_args() action = args.action + services = args.services if isinstance(args.services, list) else args.services + if services.__contains__('web'): + os.environ.setdefault('SERVER_NAME', 'web') + elif services.__contains__('local_model'): + os.environ.setdefault('SERVER_NAME', 'local_model') + django.setup() if action == "upgrade_db": perform_db_migrate() elif action == "collect_static": @@ -120,4 +124,3 @@ if __name__ == '__main__': collect_static() perform_db_migrate() start_services() -