mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
perf: Memory optimization (#4318)
This commit is contained in:
parent
3445e05aea
commit
fa1aee6c3d
|
|
@ -224,11 +224,12 @@ class LoopWorkFlowPostHandler(WorkFlowPostHandler):
|
|||
|
||||
class BaseLoopNode(ILoopNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['result'] = details.get('result')
|
||||
self.context['loop_context_data'] = details.get('loop_context_data')
|
||||
self.context['loop_answer_data'] = details.get('loop_answer_data')
|
||||
for key, value in details['context'].items():
|
||||
if key not in self.context:
|
||||
self.context[key] = value
|
||||
self.answer_text = str(details.get('result'))
|
||||
self.answer_text = ""
|
||||
|
||||
def get_answer_list(self) -> List[Answer] | None:
|
||||
result = []
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from urllib.parse import urljoin
|
|||
|
||||
import uuid_utils.compat as uuid
|
||||
from charset_normalizer import detect
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common.handle.base_split_handle import BaseSplitHandle
|
||||
|
|
@ -39,7 +38,6 @@ class FileBufferHandle:
|
|||
return self.buffer
|
||||
|
||||
|
||||
|
||||
default_split_handle = TextSplitHandle()
|
||||
split_handles = [
|
||||
HTMLSplitHandle(),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from django.contrib import admin
|
||||
|
||||
# Register your models here.
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class LocalModelConfig(AppConfig):
    """Django application configuration for the ``local_model`` app."""

    # Dotted path of the application this config belongs to.
    name = 'local_model'
    # Primary-key field type used for models that do not declare one.
    default_auto_field = 'django.db.models.BigAutoField'
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2023/9/25 15:04
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from .model_management import *
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
# coding=utf-8
|
||||
import uuid_utils.compat as uuid
|
||||
|
||||
from django.db import models
|
||||
|
||||
from common.mixins.app_model_mixin import AppModelMixin
|
||||
from local_model.models.user import User
|
||||
|
||||
|
||||
class Status(models.TextChoices):
    """Choices offered for ``Model.status``.

    Each member is a ``(stored value, display label)`` pair; the labels are
    kept in Chinese because they are runtime strings.
    """

    # Label: "success".
    SUCCESS = "SUCCESS", '成功'

    # Label: "failure".
    ERROR = "ERROR", "失败"

    # Label: "downloading".
    DOWNLOAD = "DOWNLOAD", '下载中'

    # Label: "download paused".
    PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载'
|
||||
|
||||
|
||||
class Model(AppModelMixin):
    """
    Model record stored in the ``model`` table.

    Field declaration order is kept as-is because Django derives column order
    from it.
    """

    # Primary key, generated with uuid7.
    id = models.UUIDField(
        primary_key=True,
        max_length=128,
        default=uuid.uuid7,
        editable=False,
        verbose_name="主键id",
    )

    # Display name; unique per workspace (see Meta.unique_together).
    name = models.CharField(
        max_length=128,
        verbose_name="名称",
        db_index=True,
    )

    # One of the Status choices; defaults to SUCCESS.
    status = models.CharField(
        max_length=20,
        verbose_name='设置类型',
        choices=Status.choices,
        default=Status.SUCCESS,
        db_index=True,
    )

    # Model type identifier (verbose name: "model type").
    model_type = models.CharField(
        max_length=128,
        verbose_name="模型类型",
        db_index=True,
    )

    # Model name identifier (verbose name: "model name").
    model_name = models.CharField(
        max_length=128,
        verbose_name="模型名称",
        db_index=True,
    )

    # Optional owner; no database-level FK constraint is created.
    user = models.ForeignKey(
        User,
        on_delete=models.SET_NULL,
        db_constraint=False,
        blank=True,
        null=True,
    )

    # Provider identifier (verbose name: "provider").
    provider = models.CharField(
        max_length=128,
        verbose_name='供应商',
        db_index=True,
    )

    # Credential payload (verbose name: "model credential information").
    credential = models.CharField(
        max_length=102400,
        verbose_name="模型认证信息",
    )

    # Free-form metadata (verbose name: "model metadata, used to store
    # download or error information").
    meta = models.JSONField(
        verbose_name="模型元数据,用于存储下载,或者错误信息",
        default=dict,
    )

    # Parameter form definition (verbose name: "model parameter configuration").
    model_params_form = models.JSONField(
        verbose_name="模型参数配置",
        default=list,
    )

    # Owning workspace; defaults to "default".
    workspace_id = models.CharField(
        max_length=64,
        verbose_name="工作空间id",
        default="default",
        db_index=True,
    )

    class Meta:
        db_table = "model"
        unique_together = ['name', 'workspace_id']
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: system_setting.py
|
||||
@date:2024/3/19 13:47
|
||||
@desc: 邮箱管理
|
||||
"""
|
||||
|
||||
from django.db import models
|
||||
|
||||
from common.mixins.app_model_mixin import AppModelMixin
|
||||
|
||||
|
||||
class SettingType(models.IntegerChoices):
    """Kinds of system setting; used as the ``SystemSetting.type`` choices.

    Each member is an ``(integer value, display label)`` pair.
    """

    # Label: "email".
    EMAIL = 0, '邮箱'

    # Label: "private/secret key".
    RSA = 1, "私钥秘钥"

    # Label: "log cleanup time".
    LOG = 2, "日志清理时间"
|
||||
|
||||
|
||||
class SystemSetting(AppModelMixin):
    """
    System setting row stored in the ``system_setting`` table.

    The setting type is the primary key, so there is at most one row per
    ``SettingType`` member.
    """

    # Setting kind; doubles as the primary key.
    type = models.IntegerField(
        primary_key=True,
        verbose_name='设置类型',
        choices=SettingType.choices,
        default=SettingType.EMAIL,
    )

    # Setting payload (verbose name: "configuration data").
    meta = models.JSONField(
        verbose_name="配置数据",
        default=dict,
    )

    class Meta:
        db_table = "system_setting"
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: user.py
|
||||
@date:2025/4/14 10:20
|
||||
@desc:
|
||||
"""
|
||||
import uuid_utils.compat as uuid
|
||||
|
||||
from django.db import models
|
||||
|
||||
from common.utils.common import password_encrypt
|
||||
|
||||
|
||||
class User(models.Model):
    """User row stored in the ``user`` table."""

    # Primary key, generated with uuid7.
    id = models.UUIDField(
        primary_key=True,
        max_length=128,
        default=uuid.uuid7,
        editable=False,
        verbose_name="主键id",
    )
    # Optional, unique e-mail address.
    email = models.EmailField(
        unique=True,
        null=True,
        blank=True,
        verbose_name="邮箱",
        db_index=True,
    )
    # Phone number; empty string by default.
    phone = models.CharField(
        max_length=20,
        verbose_name="电话",
        default="",
        db_index=True,
    )
    # Unique nickname.
    nick_name = models.CharField(
        max_length=150,
        verbose_name="昵称",
        unique=True,
        db_index=True,
    )
    # Unique login name.
    username = models.CharField(
        max_length=150,
        unique=True,
        verbose_name="用户名",
        db_index=True,
    )
    # Stores the output of password_encrypt (see set_password).
    password = models.CharField(
        max_length=150,
        verbose_name="密码",
    )
    # Role (verbose name: "role").
    role = models.CharField(
        max_length=150,
        verbose_name="角色",
    )
    # Account source; "LOCAL" by default.
    source = models.CharField(
        max_length=10,
        verbose_name="来源",
        default="LOCAL",
        db_index=True,
    )
    is_active = models.BooleanField(
        default=True,
        db_index=True,
    )
    # Language (verbose name: "language"); null when unset.
    language = models.CharField(
        max_length=10,
        verbose_name="语言",
        null=True,
        default=None,
    )
    # Set once when the row is created.
    create_time = models.DateTimeField(
        verbose_name="创建时间",
        auto_now_add=True,
        null=True,
        db_index=True,
    )
    # Refreshed on every save.
    update_time = models.DateTimeField(
        verbose_name="修改时间",
        auto_now=True,
        null=True,
        db_index=True,
    )

    USERNAME_FIELD = 'username'
    REQUIRED_FIELDS = []

    class Meta:
        db_table = "user"

    def set_password(self, row_password):
        """Store the encrypted form of ``row_password`` on this instance.

        The plain value is also kept on ``self._password``; nothing is saved
        to the database here.

        :param row_password: the plain-text password
        """
        encrypted = password_encrypt(row_password)
        self.password = encrypted
        self._password = row_password
|
||||
|
|
@ -0,0 +1 @@
|
|||
# coding=utf-8
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: model_apply_serializers.py
|
||||
@date:2024/8/20 20:39
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
|
||||
from django.db import connection
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from langchain_core.documents import Document
|
||||
from rest_framework import serializers
|
||||
|
||||
from local_model.models import Model
|
||||
from local_model.serializers.rsa_util import rsa_long_decrypt
|
||||
from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
|
||||
|
||||
from common.cache.mem_cache import MemCache
|
||||
|
||||
_lock = threading.Lock()
|
||||
locks = {}
|
||||
|
||||
|
||||
class ModelManage:
    """Cache of loaded model instances keyed by model id.

    Instances are kept in ``cache`` with an 8-hour timeout. Loading an
    uncached model is serialised per id through the module-level ``locks``
    table.
    """

    cache = MemCache('model', {})
    # Timestamp of the last time a timeout sweep was started.
    up_clear_time = time.time()

    @staticmethod
    def _get_lock(_id):
        """Return the lock dedicated to ``_id``, creating it on first use.

        :param _id: model id
        :return: a ``threading.Lock`` shared by all callers using this id
        """
        per_id_lock = locks.get(_id)
        if per_id_lock is not None:
            return per_id_lock
        # Re-check under the global lock so that two threads racing on the
        # same id end up with the same lock object.
        with _lock:
            per_id_lock = locks.get(_id)
            if per_id_lock is None:
                per_id_lock = threading.Lock()
                locks[_id] = per_id_lock
        return per_id_lock

    @staticmethod
    def get_model(_id, get_model):
        """Return the model instance for ``_id``, loading it when needed.

        :param _id: model id, also used as the cache key
        :param get_model: callable taking ``_id`` and returning a new instance
        :return: the cached or newly loaded model instance
        """
        ttl = 60 * 60 * 8
        instance = ModelManage.cache.get(_id)
        if instance is not None:
            if instance.is_cache_model():
                # Cache hit on a cacheable instance: extend its lifetime.
                ModelManage.cache.touch(_id, timeout=ttl)
            else:
                # The instance reports it should not be reused: rebuild it.
                instance = get_model(_id)
                ModelManage.cache.set(_id, instance, timeout=ttl)
        else:
            with ModelManage._get_lock(_id):
                # Another thread may have loaded it while we waited.
                instance = ModelManage.cache.get(_id)
                if instance is None:
                    instance = get_model(_id)
                    ModelManage.cache.set(_id, instance, timeout=ttl)
        ModelManage.clear_timeout_cache()
        return instance

    @staticmethod
    def clear_timeout_cache():
        """Start a background sweep of timed-out entries at most once an hour."""
        if not (time.time() - ModelManage.up_clear_time > 60 * 60):
            return
        threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data()).start()
        ModelManage.up_clear_time = time.time()

    @staticmethod
    def delete_key(_id):
        """Remove the cached instance for ``_id`` if one is present.

        :param _id: model id
        """
        if not ModelManage.cache.has_key(_id):
            return
        ModelManage.cache.delete(_id)
|
||||
|
||||
|
||||
def get_local_model(model, **kwargs):
    """Build a local model instance from a ``Model`` row.

    The row's ``credential`` is decrypted with ``rsa_long_decrypt`` and parsed
    as JSON before being handed to the provider.

    :param model: row providing ``model_type``, ``model_name``, ``credential`` and ``id``
    :param kwargs: extra keyword arguments forwarded to the provider
    :return: whatever ``LocalModelProvider().get_model`` returns
    """
    provider = LocalModelProvider()
    credential = json.loads(rsa_long_decrypt(model.credential))
    return provider.get_model(
        model.model_type,
        model.model_name,
        credential,
        model_id=model.id,
        streaming=True,
        **kwargs,
    )
|
||||
|
||||
|
||||
def get_embedding_model(model_id):
    """Return the (cached) local model instance for ``model_id``.

    The database row is only fetched when ``ModelManage`` actually needs to
    build an instance, so requests served from the cache no longer issue a
    query and a connection close each time.

    :param model_id: primary key of the ``Model`` row
    :return: the model instance returned by ``ModelManage.get_model``
    """

    def load(_id):
        model = QuerySet(Model).filter(id=model_id).first()
        # Close the database connection manually once the row has been read.
        connection.close()
        # NOTE(review): when no row matches, ``model`` is None and
        # get_local_model fails on attribute access — same as before.
        return get_local_model(model, use_local=True)

    return ModelManage.get_model(model_id, load)
|
||||
|
||||
|
||||
class EmbedDocuments(serializers.Serializer):
    """Request payload for embedding a list of texts."""

    # The assignment must not end with a comma after the closing parenthesis:
    # that turns the attribute into a one-element tuple, which the serializer
    # does not register as a field, so ``texts`` would never be validated.
    texts = serializers.ListField(
        required=True,
        child=serializers.CharField(required=True, label=_('vector text')),
        label=_('vector text list'),
    )
|
||||
|
||||
|
||||
class EmbedQuery(serializers.Serializer):
    """Request payload for embedding a single query text."""

    text = serializers.CharField(
        required=True,
        label=_('vector text'),
    )
|
||||
|
||||
|
||||
class CompressDocument(serializers.Serializer):
    """One document of a compress request: required text, optional metadata."""

    page_content = serializers.CharField(
        required=True,
        label=_('text'),
    )
    metadata = serializers.DictField(
        required=False,
        label=_('metadata'),
    )
|
||||
|
||||
|
||||
class CompressDocuments(serializers.Serializer):
    """Request payload for a compress call: a list of documents plus a query."""

    documents = CompressDocument(
        required=True,
        many=True,
    )
    query = serializers.CharField(
        required=True,
        label=_('query'),
    )
|
||||
|
||||
|
||||
class ModelApplySerializers(serializers.Serializer):
    """Applies a local model (looked up by ``model_id``) to request payloads."""

    model_id = serializers.UUIDField(required=True, label=_('model id'))

    def _load_model(self, payload_serializer, instance, with_valid):
        """Validate the request (when asked) and return the model instance.

        :param payload_serializer: serializer class used to validate ``instance``
        :param instance: request payload
        :param with_valid: when true, validate ``model_id`` and the payload first
        :return: the model returned by ``get_embedding_model``
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            payload_serializer(data=instance).is_valid(raise_exception=True)
        return get_embedding_model(self.data.get('model_id'))

    def embed_documents(self, instance, with_valid=True):
        """Embed every entry of ``texts``.

        :param instance: payload exposing ``getlist('texts')``
        :param with_valid: validate before running
        :return: result of the model's ``embed_documents``
        """
        model = self._load_model(EmbedDocuments, instance, with_valid)
        return model.embed_documents(instance.getlist('texts'))

    def embed_query(self, instance, with_valid=True):
        """Embed the single ``text`` value.

        :param instance: payload exposing ``get('text')``
        :param with_valid: validate before running
        :return: result of the model's ``embed_query``
        """
        model = self._load_model(EmbedQuery, instance, with_valid)
        return model.embed_query(instance.get('text'))

    def compress_documents(self, instance, with_valid=True):
        """Run the model's ``compress_documents`` over the given documents.

        :param instance: payload with ``documents`` (each having
            ``page_content`` and optionally ``metadata``) and ``query``
        :param with_valid: validate before running
        :return: list of ``{'page_content': ..., 'metadata': ...}`` dicts
        """
        model = self._load_model(CompressDocuments, instance, with_valid)
        documents = [
            Document(
                page_content=document.get('page_content'),
                # ``metadata`` is optional in CompressDocument; fall back to an
                # empty dict instead of handing ``None`` to Document.
                metadata=document.get('metadata') or {},
            )
            for document in instance.get('documents')
        ]
        compressed = model.compress_documents(documents, instance.get('query'))
        return [{'page_content': d.page_content, 'metadata': d.metadata} for d in compressed]
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: rsa_util.py
|
||||
@date:2023/11/3 11:13
|
||||
@desc:
|
||||
"""
|
||||
import base64
|
||||
import threading
|
||||
|
||||
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
|
||||
from Crypto.PublicKey import RSA
|
||||
from django.core import cache
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.constants.cache_version import Cache_Version
|
||||
from local_model.models.system_setting import SystemSetting, SettingType
|
||||
|
||||
lock = threading.Lock()
|
||||
rsa_cache = cache.cache
|
||||
cache_key = "rsa_key"
|
||||
# 对密钥加密的密码
|
||||
secret_code = "mac_kb_password"
|
||||
|
||||
|
||||
def generate():
    """
    Generate a new RSA key pair.

    :return: ``{'key': <public key bytes>, 'value': <private key bytes>}``;
             the private key is exported encrypted with ``secret_code``
    """
    # 2048-bit key.
    rsa_key = RSA.generate(2048)

    # Private key: PKCS#8, protected by the module passphrase.
    private_pem = rsa_key.export_key(
        passphrase=secret_code,
        pkcs=8,
        protection="scryptAndAES128-CBC",
    )
    public_pem = rsa_key.publickey().export_key()
    return {'key': public_pem, 'value': private_pem}
|
||||
|
||||
|
||||
def get_key_pair():
    """
    Return the RSA key pair, reading it from the cache when possible.

    The cache is read with the same key and version it is written with.
    Reading with the bare key and default version (as before) can miss the
    stored entry and fall through to the lock and the database on every call.

    :return: dict with the public key under ``'key'`` and the encrypted
             private key under ``'value'``
    """
    version, get_key = Cache_Version.SYSTEM.value
    versioned_key = get_key(key=cache_key)
    rsa_value = rsa_cache.get(versioned_key, version=version)
    if rsa_value is None:
        with lock:
            # Another thread may have filled the cache while we waited.
            rsa_value = rsa_cache.get(versioned_key, version=version)
            if rsa_value is None:
                rsa_value = get_key_pair_by_sql()
                rsa_cache.set(versioned_key, rsa_value, timeout=None, version=version)
    return rsa_value
|
||||
|
||||
|
||||
def get_key_pair_by_sql():
    """
    Load the RSA key pair from the ``SystemSetting`` row of type RSA.

    When no such row exists, a new pair is generated and saved first.

    :return: the row's ``meta`` dict (``'key'`` / ``'value'`` as text)
    """
    stored = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first()
    if stored is not None:
        return stored.meta

    pair = generate()
    stored = SystemSetting(
        type=SettingType.RSA.value,
        meta={
            'key': pair.get('key').decode(),
            'value': pair.get('value').decode(),
        },
    )
    stored.save()
    return stored.meta
|
||||
|
||||
|
||||
def encrypt(msg, public_key: str | None = None):
    """
    Encrypt a string with the RSA public key (PKCS#1 v1.5).

    :param msg: text to encrypt
    :param public_key: public key; the stored key pair is used when omitted
    :return: base64 text of the ciphertext
    """
    key_text = get_key_pair().get('key') if public_key is None else public_key
    rsa_cipher = PKCS1_cipher.new(RSA.importKey(key_text))
    ciphertext = rsa_cipher.encrypt(msg.encode("utf-8"))
    encoded = base64.b64encode(ciphertext)
    return encoded.decode()
|
||||
|
||||
|
||||
def decrypt(msg, pri_key: str | None = None):
    """
    Decrypt base64 text produced by ``encrypt``.

    :param msg: base64-encoded ciphertext
    :param pri_key: private key; the stored key pair is used when omitted
    :return: the decrypted text
    """
    key_text = get_key_pair().get('value') if pri_key is None else pri_key
    rsa_cipher = PKCS1_cipher.new(RSA.importKey(key_text, passphrase=secret_code))
    # The second argument (0) is the sentinel handed to ``cipher.decrypt``.
    plain_bytes = rsa_cipher.decrypt(base64.b64decode(msg), 0)
    return plain_bytes.decode("utf-8")
|
||||
|
||||
|
||||
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
    """
    Encrypt text of any length by splitting it into RSA-sized pieces.

    The message is UTF-8 encoded first and then split into pieces of at most
    ``length`` bytes. Splitting by characters (as before) lets multi-byte text
    exceed the cipher's per-block limit and fail with "Plaintext is too long".
    ``rsa_long_decrypt`` joins the decrypted bytes before decoding, so a
    character split across two pieces is restored correctly.

    :param message: text to encrypt
    :param public_key: public key; the stored key pair is used when omitted
    :param length: maximum plaintext bytes per piece
                   (use 100 for a 1024-bit key, 200 for a 2048-bit key)
    :return: base64 text of the concatenated ciphertext pieces
    """
    # Load the public key.
    if public_key is None:
        public_key = get_key_pair().get('key')
    cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
                                            passphrase=secret_code))
    data = message.encode('utf-8')
    # ``max(..., 1)`` keeps the previous behaviour for an empty message:
    # exactly one (empty) piece is encrypted.
    pieces = [cipher.encrypt(data[start:start + length])
              for start in range(0, max(len(data), 1), length)]
    return base64.b64encode(b''.join(pieces)).decode()
|
||||
|
||||
|
||||
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
    """
    Decrypt text produced by ``rsa_long_encrypt``.

    :param message: base64-encoded ciphertext
    :param pri_key: private key; the stored key pair is used when omitted
    :param length: ciphertext bytes per piece
                   (128 for a 1024-bit key, 256 for a 2048-bit key)
    :return: the decrypted text
    """
    key_text = get_key_pair().get('value') if pri_key is None else pri_key
    cipher = PKCS1_cipher.new(RSA.importKey(key_text, passphrase=secret_code))
    raw = base64.b64decode(message)
    # Decrypt piece by piece, then decode the joined bytes in one go.
    pieces = [cipher.decrypt(raw[start:start + length], 0)
              for start in range(0, len(raw), length)]
    return b"".join(pieces).decode()
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
import os
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from . import views
|
||||
|
||||
app_name = "local_model"
|
||||
# @formatter:off
|
||||
urlpatterns = [
|
||||
path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
|
||||
path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
|
||||
path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
|
||||
]
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
# coding=utf-8
|
||||
from .model_apply import *
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: model_apply.py
|
||||
@date:2024/8/20 20:38
|
||||
@desc:
|
||||
"""
|
||||
from urllib.request import Request
|
||||
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from common.result import result
|
||||
from local_model.serializers.model_apply_serializers import ModelApplySerializers
|
||||
|
||||
|
||||
class LocalModelApply(APIView):
    """Namespace for the views that apply a local model to request data."""

    class EmbedDocuments(APIView):
        """POST: embed the texts in the request body with the given model."""

        def post(self, request: Request, model_id):
            payload = ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data)
            return result.success(payload)

    class EmbedQuery(APIView):
        """POST: embed the query text in the request body with the given model."""

        def post(self, request: Request, model_id):
            payload = ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data)
            return result.success(payload)

    class CompressDocuments(APIView):
        """POST: run the given model's compress_documents on the request body."""

        def post(self, request: Request, model_id):
            payload = ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)
            return result.success(payload)
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py
|
||||
@date:2025/11/5 14:50
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||
from .model import *
|
||||
else:
|
||||
from .web import *
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: auth.py
|
||||
@date:2024/7/9 18:47
|
||||
@desc:
|
||||
"""
|
||||
|
||||
AUTH_HANDLES = [
|
||||
]
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py
|
||||
@date:2025/11/5 14:53
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||
from .model import *
|
||||
else:
|
||||
from .web import *
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: model.py
|
||||
@date:2025/11/5 14:53
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from ...const import CONFIG, PROJECT_DIR
|
||||
import os
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
# Build paths inside the project like this: BASE_DIR / 'subdir'.
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
# Quick-start development settings - unsuitable for production
|
||||
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
|
||||
|
||||
# SECURITY WARNING: keep the secret key used in production secret!
|
||||
SECRET_KEY = CONFIG.get("SECRET_KEY") or 'django-insecure-zm^1_^i5)3gp^&0io6zg72&z!a*d=9kf9o2%uft+27l)+t(#3e'
|
||||
|
||||
# SECURITY WARNING: don't run with debug turned on in production!
|
||||
DEBUG = CONFIG.get_debug()
|
||||
|
||||
ALLOWED_HOSTS = ['*']
|
||||
|
||||
# Application definition
|
||||
|
||||
INSTALLED_APPS = [
|
||||
'django.contrib.contenttypes',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.staticfiles',
|
||||
'rest_framework',
|
||||
'local_model',
|
||||
]
|
||||
|
||||
MIDDLEWARE = [
|
||||
'django.middleware.locale.LocaleMiddleware',
|
||||
'django.middleware.security.SecurityMiddleware',
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
|
||||
]
|
||||
|
||||
REST_FRAMEWORK = {
|
||||
'EXCEPTION_HANDLER': 'common.exception.handle_exception.handle_exception',
|
||||
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
|
||||
'DEFAULT_AUTHENTICATION_CLASSES': ['common.auth.authenticate.AnonymousAuthentication']
|
||||
}
|
||||
STATICFILES_DIRS = [(os.path.join(PROJECT_DIR, 'ui', 'dist'))]
|
||||
STATIC_ROOT = os.path.join(BASE_DIR.parent, 'static')
|
||||
ROOT_URLCONF = 'maxkb.urls'
|
||||
APPS_DIR = os.path.join(PROJECT_DIR, 'apps')
|
||||
|
||||
TEMPLATES = [
|
||||
{
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'DIRS': ["apps/static/admin"],
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {
|
||||
'context_processors': [
|
||||
'django.template.context_processors.debug',
|
||||
'django.template.context_processors.request',
|
||||
'django.contrib.auth.context_processors.auth',
|
||||
'django.contrib.messages.context_processors.messages',
|
||||
],
|
||||
},
|
||||
},
|
||||
{"NAME": "CHAT",
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'DIRS': ["apps/static/chat"],
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {
|
||||
'context_processors': [
|
||||
'django.template.context_processors.debug',
|
||||
'django.template.context_processors.request',
|
||||
'django.contrib.auth.context_processors.auth',
|
||||
'django.contrib.messages.context_processors.messages',
|
||||
],
|
||||
},
|
||||
},
|
||||
{"NAME": "DOC",
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'DIRS': ["apps/static/drf_spectacular_sidecar"],
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {
|
||||
'context_processors': [
|
||||
'django.template.context_processors.debug',
|
||||
'django.template.context_processors.request',
|
||||
'django.contrib.auth.context_processors.auth',
|
||||
'django.contrib.messages.context_processors.messages',
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
SPECTACULAR_SETTINGS = {
|
||||
'TITLE': 'MaxKB API',
|
||||
'DESCRIPTION': _('Intelligent customer service platform'),
|
||||
'VERSION': 'v2',
|
||||
'SERVE_INCLUDE_SCHEMA': False,
|
||||
# OTHER SETTINGS
|
||||
'SWAGGER_UI_DIST': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist', # shorthand to use the sidecar instead
|
||||
'SWAGGER_UI_FAVICON_HREF': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist/favicon-32x32.png',
|
||||
'REDOC_DIST': f'{CONFIG.get_admin_path()}/api-doc/redoc',
|
||||
'SECURITY_DEFINITIONS': {
|
||||
'Bearer': {
|
||||
'type': 'apiKey',
|
||||
'name': 'AUTHORIZATION',
|
||||
'in': 'header',
|
||||
}
|
||||
}
|
||||
}
|
||||
WSGI_APPLICATION = 'maxkb.wsgi.application'
|
||||
|
||||
# Database
|
||||
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases
|
||||
|
||||
DATABASES = {'default': CONFIG.get_db_setting()}
|
||||
|
||||
CACHES = CONFIG.get_cache_setting()
|
||||
|
||||
# Password validation
|
||||
# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators
|
||||
|
||||
AUTH_PASSWORD_VALIDATORS = [
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
|
||||
},
|
||||
]
|
||||
|
||||
# Internationalization
|
||||
# https://docs.djangoproject.com/en/4.2/topics/i18n/
|
||||
|
||||
LANGUAGE_CODE = CONFIG.get("LANGUAGE_CODE")
|
||||
|
||||
TIME_ZONE = CONFIG.get_time_zone()
|
||||
|
||||
USE_I18N = True
|
||||
|
||||
USE_TZ = True
|
||||
|
||||
# 文件上传配置
|
||||
DATA_UPLOAD_MAX_NUMBER_FILES = 1000
|
||||
|
||||
# 支持的语言
|
||||
LANGUAGES = [
|
||||
('en', 'English'),
|
||||
('zh', '中文简体'),
|
||||
('zh-hant', '中文繁体')
|
||||
]
|
||||
# 翻译文件路径
|
||||
LOCALE_PATHS = [
|
||||
os.path.join(BASE_DIR.parent, 'locales')
|
||||
]
|
||||
|
||||
# Static files (CSS, JavaScript, Images)
|
||||
# https://docs.djangoproject.com/en/4.2/howto/static-files/
|
||||
|
||||
STATIC_URL = 'static/'
|
||||
|
||||
# Default primary key field type
|
||||
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
|
||||
|
||||
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
|
||||
|
||||
edition = 'CE'
|
||||
|
||||
if os.environ.get('MAXKB_REDIS_SENTINEL_SENTINELS') is not None:
|
||||
DJANGO_REDIS_CONNECTION_FACTORY = "django_redis.pool.SentinelConnectionFactory"
|
||||
|
|
@ -1,22 +1,18 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
Django settings for maxkb project.
|
||||
|
||||
Generated by 'django-admin startproject' using Django 4.2.4.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/4.2/topics/settings/
|
||||
|
||||
For the full list of settings and their values, see
|
||||
https://docs.djangoproject.com/en/4.2/ref/settings/
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: web.py
|
||||
@date:2025/11/5 14:53
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from ..const import CONFIG, PROJECT_DIR
|
||||
from ...const import CONFIG, PROJECT_DIR
|
||||
import os
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
# Build paths inside the project like this: BASE_DIR / 'subdir'.
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
# Quick-start development settings - unsuitable for production
|
||||
# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB-xpack
|
||||
@Author:虎虎
|
||||
@file: __init__.py
|
||||
@date:2025/11/5 14:45
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||
from .model import *
|
||||
else:
|
||||
from .web import *
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
"""
|
||||
URL configuration for maxkb project.
|
||||
|
||||
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||
Examples:
|
||||
Function views
|
||||
1. Add an import: from my_app import views
|
||||
2. Add a URL to urlpatterns: path('', views.home, name='home')
|
||||
Class-based views
|
||||
1. Add an import: from other_app.views import Home
|
||||
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
|
||||
Including another URLconf
|
||||
1. Import the include() function: from django.urls import include, path
|
||||
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
|
||||
"""
|
||||
|
||||
from django.urls import path, include
|
||||
|
||||
from maxkb.const import CONFIG
|
||||
|
||||
admin_api_prefix = CONFIG.get_admin_path()[1:] + '/api/'
|
||||
admin_ui_prefix = CONFIG.get_admin_path()
|
||||
chat_api_prefix = CONFIG.get_chat_path()[1:] + '/api/'
|
||||
chat_ui_prefix = CONFIG.get_chat_path()
|
||||
urlpatterns = [
|
||||
path(admin_api_prefix, include("local_model.urls")),
|
||||
]
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py
|
||||
@date:2025/11/5 15:14
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||
from .model import *
|
||||
else:
|
||||
from .web import *
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: model.py
|
||||
@date:2025/11/5 15:14
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
from django.core.wsgi import get_wsgi_application
|
||||
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings')
|
||||
|
||||
application = get_wsgi_application()
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
WSGI config for maxkb project.
|
||||
|
||||
It exposes the WSGI callable as a module-level variable named ``application``.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: web.py
|
||||
@date:2025/11/5 15:14
|
||||
@desc:
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from django.core.wsgi import get_wsgi_application
|
||||
|
|
@ -35,4 +34,4 @@ post_handler()
|
|||
|
||||
# 仅在scheduler中启动定时任务,dev local_model celery 不需要
|
||||
if os.environ.get('ENABLE_SCHEDULER') == '1':
|
||||
post_scheduler_handler()
|
||||
post_scheduler_handler()
|
||||
|
|
@ -15,7 +15,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker
|
||||
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ class LocalRerankerCredential(BaseForm, BaseModelCredential):
|
|||
else:
|
||||
return False
|
||||
try:
|
||||
model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential)
|
||||
model: LocalReranker = provider.get_model(model_type, model_name, model_credential)
|
||||
model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
|
|
|||
|
|
@ -8,15 +8,16 @@
|
|||
"""
|
||||
import os
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from common.utils.common import get_file_content
|
||||
from maxkb.conf import PROJECT_DIR
|
||||
from models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
|
||||
from models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential
|
||||
from models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
|
||||
from maxkb.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
|
||||
LocalEmbeddingCredential(), LocalEmbedding)
|
||||
|
|
|
|||
|
|
@ -1,64 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/11 14:06
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict, List
|
||||
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from pydantic import BaseModel
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
from maxkb.const import CONFIG
|
||||
|
||||
|
||||
class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
    """Embeddings client that delegates to the local-model HTTP service.

    Each call is an HTTP POST to the address assembled from the
    ``LOCAL_MODEL_PROTOCOL`` / ``LOCAL_MODEL_HOST`` / ``LOCAL_MODEL_PORT``
    settings and the admin path exposed by ``CONFIG``.
    """

    # Identifier interpolated into the URL of every request.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Not implemented for this class; returns None.
        pass

    def _call_local_model(self, action: str, payload: dict):
        """POST ``payload`` as form data to ``.../api/model/<model_id>/<action>``.

        Returns the ``data`` field of the JSON reply when its ``code`` is 200.

        Raises:
            Exception: carrying the reply's ``message`` for any other code
                (a missing ``code`` is treated as 500).
        """
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        prefix = CONFIG.get_admin_path()
        # The prefix is appended directly after host:port for every endpoint;
        # presumably it already starts with '/' -- confirm against CONFIG.get_admin_path.
        res = requests.post(
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/{action}',
            payload)
        result = res.json()
        if result.get('code', 500) == 200:
            return result.get('data')
        raise Exception(result.get('message'))

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding the service computes for a single text."""
        return self._call_local_model('embed_query', {'text': text})

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return the embeddings the service computes for a list of texts."""
        return self._call_local_model('embed_documents', {'texts': texts})
|
||||
|
||||
|
||||
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
    """HuggingFace embedding model, with a factory that can return an HTTP client instead."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build the embedding object selected by the ``use_local`` keyword.

        With ``use_local`` truthy (the default) a ``LocalEmbedding`` is returned;
        otherwise a ``WebLocalEmbedding`` is returned and additionally receives
        every entry of ``model_kwargs``.
        """
        use_local = model_kwargs.get('use_local', True)
        # Constructor arguments common to both implementations.
        shared = dict(
            model_name=model_name,
            cache_folder=model_credential.get('cache_folder'),
            model_kwargs={'device': model_credential.get('device')},
            encode_kwargs={'normalize_embeddings': True},
        )
        if not use_local:
            return WebLocalEmbedding(**shared, **model_kwargs)
        return LocalEmbedding(**shared)
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py
|
||||
@date:2025/11/5 15:24
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
# Choose which implementation this package re-exports based on the
# SERVER_NAME environment variable (defaults to 'web' when unset):
# 'local_model' -> .model, anything else -> .web.
if os.environ.get('SERVER_NAME', 'web') != 'local_model':
    from .web import *
else:
    from .model import *
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: model.py
|
||||
@date:2025/11/5 15:26
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
    """HuggingFace embedding model instantiated directly in this process."""

    @staticmethod
    def is_cache_model():
        """Always return True.

        NOTE(review): presumably consulted by the provider layer to decide
        whether instances are cached -- confirm against MaxKBBaseModel.
        """
        return True

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a ``LocalEmbedding`` from the model name and credential.

        ``cache_folder`` and ``device`` are read from ``model_credential``;
        embeddings are normalised. ``model_type`` and ``model_kwargs`` are
        accepted but unused.
        """
        init_args = {
            'model_name': model_name,
            'cache_folder': model_credential.get('cache_folder'),
            'model_kwargs': {'device': model_credential.get('device')},
            'encode_kwargs': {'normalize_embeddings': True},
        }
        return LocalEmbedding(**init_args)
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: web.py
|
||||
@date:2025/11/5 15:24
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import requests
|
||||
from anthropic import BaseModel
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from maxkb.const import CONFIG
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class LocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings):
    """Embeddings client that delegates to the local-model HTTP service.

    Each call is an HTTP POST to the address assembled from the
    ``LOCAL_MODEL_PROTOCOL`` / ``LOCAL_MODEL_HOST`` / ``LOCAL_MODEL_PORT``
    settings and the admin path exposed by ``CONFIG``.
    """

    # Identifier interpolated into the URL of every request.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create the client; every entry of ``model_kwargs`` is forwarded to the constructor."""
        return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
                              model_kwargs={'device': model_credential.get('device')},
                              encode_kwargs={'normalize_embeddings': True},
                              **model_kwargs)

    def _call_local_model(self, action: str, payload: dict):
        """POST ``payload`` as form data to ``.../api/model/<model_id>/<action>``.

        Returns the ``data`` field of the JSON reply when its ``code`` is 200.

        Raises:
            Exception: carrying the reply's ``message`` for any other code
                (a missing ``code`` is treated as 500).
        """
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        prefix = CONFIG.get_admin_path()
        # The prefix is appended directly after host:port for every endpoint;
        # presumably it already starts with '/' -- confirm against CONFIG.get_admin_path.
        res = requests.post(
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/{action}',
            payload)
        result = res.json()
        if result.get('code', 500) == 200:
            return result.get('data')
        raise Exception(result.get('message'))

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding the service computes for a single text."""
        return self._call_local_model('embed_query', {'text': text})

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return the embeddings the service computes for a list of texts."""
        return self._call_local_model('embed_documents', {'texts': texts})
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: reranker.py
|
||||
@date:2024/9/2 16:42
|
||||
@desc:
|
||||
"""
|
||||
from typing import Sequence, Optional, Dict, Any, ClassVar
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
from maxkb.const import CONFIG
|
||||
|
||||
|
||||
class LocalReranker(MaxKBBaseModel):
    """Entry point whose factory returns either the in-process or the HTTP reranker."""

    def __init__(self, model_name, top_n=3, cache_dir=None):
        super().__init__()
        self.model_name, self.cache_dir, self.top_n = model_name, cache_dir, top_n

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build the reranker selected by the ``use_local`` keyword.

        With ``use_local`` truthy (the default) a ``LocalBaseReranker`` is
        returned; otherwise a ``WebLocalBaseReranker`` is returned and
        additionally receives every entry of ``model_kwargs``.
        """
        use_local = model_kwargs.get('use_local', True)
        cache_dir = model_credential.get('cache_dir')
        if not use_local:
            return WebLocalBaseReranker(model_name=model_name, cache_dir=cache_dir,
                                        model_kwargs={'device': model_credential.get('device')},
                                        **model_kwargs)
        # The in-process variant falls back to 'cpu' when no device is configured.
        return LocalBaseReranker(model_name=model_name, cache_dir=cache_dir,
                                 model_kwargs={'device': model_credential.get('device', 'cpu')})
|
||||
|
||||
|
||||
class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Reranker client that delegates ``compress_documents`` to the local-model HTTP service."""

    # Identifier interpolated into the request URL.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Not implemented for this class; returns None.
        pass

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Send ``documents`` and ``query`` to the service and return the documents it replies with.

        Returns an empty list without any request when ``documents`` is None
        or empty. Raises ``Exception`` with the reply's ``message`` when the
        reply ``code`` is not 200 (a missing ``code`` counts as 500).
        """
        if documents is None or len(documents) == 0:
            return []
        prefix = CONFIG.get_admin_path()
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        url = f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents'
        payload = {
            'documents': [{'page_content': item.page_content, 'metadata': item.metadata} for item in documents],
            'query': query,
        }
        reply = requests.post(url, json=payload, headers={'Content-Type': 'application/json'}).json()
        if reply.get('code', 500) != 200:
            raise Exception(reply.get('message'))
        return [Document(page_content=row.get('page_content'), metadata=row.get('metadata'))
                for row in reply.get('data')]
|
||||
|
||||
|
||||
class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Reranker that scores (query, document) pairs with a sequence-classification model loaded in-process."""

    # Loaded classification model (moved to the configured device, in eval mode).
    client: Any = None
    # Tokenizer loaded from the same model name.
    tokenizer: Any = None
    # Model name passed to ``from_pretrained``.
    model: Optional[str] = None
    cache_dir: Optional[str] = None
    # Extra keyword arguments given to the constructor; ``device`` is read from here.
    model_kwargs: Any = {}

    def __init__(self, model_name, cache_dir=None, **model_kwargs):
        super().__init__()
        self.model = model_name
        self.cache_dir = cache_dir
        self.model_kwargs = model_kwargs
        classifier = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
        self.client = classifier
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
        # Defaults to 'cpu' when no ``device`` keyword was supplied.
        target_device = self.model_kwargs.get('device', 'cpu')
        self.client = self.client.to(target_device)
        self.client.eval()

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create the reranker; ``model_kwargs`` are forwarded to the constructor."""
        return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Return new documents carrying a ``relevance_score``, sorted from highest to lowest score.

        Only ``page_content`` is copied from each input document; the returned
        metadata contains just ``relevance_score``. An empty list is returned
        when ``documents`` is None or empty.
        """
        if documents is None or len(documents) == 0:
            return []
        pairs = [[query, document.page_content] for document in documents]
        with torch.no_grad():
            inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            logits = self.client(**inputs, return_dict=True).logits.view(-1, ).float()
            scores = [torch.sigmoid(logit).float().item() for logit in logits]
        ranked = [Document(page_content=document.page_content, metadata={'relevance_score': scores[position]})
                  for position, document in enumerate(documents)]
        ranked.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
        return ranked
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py
|
||||
@date:2025/11/5 15:30
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
# Choose which implementation this package re-exports based on the
# SERVER_NAME environment variable (defaults to 'web' when unset):
# 'local_model' -> .model, anything else -> .web.
if os.environ.get('SERVER_NAME', 'web') != 'local_model':
    from .web import *
else:
    from .model import *
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: model.py
|
||||
@date:2025/11/5 15:30
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from typing import Sequence, Optional, Dict, Any
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import Document, BaseDocumentCompressor
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class LocalReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Reranker that scores (query, document) pairs with a sequence-classification model loaded in-process."""

    # Loaded classification model (moved to the configured device, in eval mode).
    client: Any = None
    # Tokenizer loaded from the same model name.
    tokenizer: Any = None
    # Model name passed to ``from_pretrained``.
    model: Optional[str] = None
    cache_dir: Optional[str] = None
    # Extra keyword arguments given to the constructor; ``device`` is read from here.
    model_kwargs: Any = {}

    def __init__(self, model_name, cache_dir=None, **model_kwargs):
        super().__init__()
        self.model = model_name
        self.cache_dir = cache_dir
        self.model_kwargs = model_kwargs
        classifier = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
        self.client = classifier
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
        # Defaults to 'cpu' when no ``device`` keyword was supplied.
        target_device = self.model_kwargs.get('device', 'cpu')
        self.client = self.client.to(target_device)
        self.client.eval()

    @staticmethod
    def is_cache_model():
        """Always return False.

        NOTE(review): presumably consulted by the provider layer to decide
        whether instances are cached -- confirm against MaxKBBaseModel.
        """
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create the reranker from the model name and the credential's ``cache_dir``.

        ``model_type`` and ``model_kwargs`` are accepted but not forwarded.
        """
        return LocalReranker(model_name, cache_dir=model_credential.get('cache_dir'))

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Return new documents carrying a ``relevance_score``, sorted from highest to lowest score.

        Only ``page_content`` is copied from each input document; the returned
        metadata contains just ``relevance_score``. An empty list is returned
        when ``documents`` is None or empty.
        """
        if documents is None or len(documents) == 0:
            return []
        # torch is imported here rather than at module level, so it is only
        # loaded once documents actually need scoring.
        import torch
        pairs = [[query, document.page_content] for document in documents]
        with torch.no_grad():
            inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            logits = self.client(**inputs, return_dict=True).logits.view(-1, ).float()
            scores = [torch.sigmoid(logit).float().item() for logit in logits]
        ranked = [Document(page_content=document.page_content, metadata={'relevance_score': scores[position]})
                  for position, document in enumerate(documents)]
        ranked.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
        return ranked
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: web.py
|
||||
@date:2025/11/5 15:30
|
||||
@desc:
|
||||
"""
|
||||
from typing import Sequence, Optional, Dict
|
||||
|
||||
import requests
|
||||
from anthropic import BaseModel
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import Document, BaseDocumentCompressor
|
||||
|
||||
from maxkb.const import CONFIG
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class LocalReranker(MaxKBBaseModel, BaseModel, BaseDocumentCompressor):
    """Reranker client that delegates ``compress_documents`` to the local-model HTTP service."""

    # Identifier interpolated into the request URL.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    @staticmethod
    def is_cache_model():
        """Always return False.

        NOTE(review): presumably consulted by the provider layer to decide
        whether instances are cached -- confirm against MaxKBBaseModel.
        """
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create the client; all arguments are forwarded to the constructor as keywords."""
        return LocalReranker(model_type=model_type, model_name=model_name, model_credential=model_credential,
                             **model_kwargs)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Send ``documents`` and ``query`` to the service and return the documents it replies with.

        Returns an empty list without any request when ``documents`` is None
        or empty.

        Raises:
            Exception: carrying the reply's ``message`` when the reply
                ``code`` is not 200 (a missing ``code`` counts as 500).
        """
        if documents is None or len(documents) == 0:
            return []
        prefix = CONFIG.get_admin_path()
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        res = requests.post(
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents',
            json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in
                                documents], 'query': query}, headers={'Content-Type': 'application/json'})
        result = res.json()
        if result.get('code', 500) == 200:
            return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document
                    in result.get('data')]
        raise Exception(result.get('message'))
|
||||
|
|
@ -1,12 +1,12 @@
|
|||
from typing import Sequence, Optional, Any, Dict
|
||||
from typing import Sequence, Optional, Dict
|
||||
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import Document
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
||||
top_n: Optional[int] = Field(3, description="Number of top documents to return")
|
||||
|
|
@ -22,6 +22,7 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
|||
|
||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||
Sequence[Document]:
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
"""Rank documents based on their similarity to the query.
|
||||
|
||||
Args:
|
||||
|
|
@ -37,7 +38,7 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
|||
document_embeddings = self.embed_documents(documents)
|
||||
# 计算相似度
|
||||
similarities = cosine_similarity([query_embedding], document_embeddings)[0]
|
||||
ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
|
||||
ranked_docs = [(doc, _) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
|
||||
return [
|
||||
Document(
|
||||
page_content=doc, # 第一个值是文档内容
|
||||
|
|
@ -45,5 +46,3 @@ class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
|||
)
|
||||
for doc, score in ranked_docs
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
9
main.py
9
main.py
|
|
@ -13,7 +13,6 @@ APP_DIR = os.path.join(BASE_DIR, 'apps')
|
|||
os.chdir(BASE_DIR)
|
||||
sys.path.insert(0, APP_DIR)
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "maxkb.settings")
|
||||
django.setup()
|
||||
|
||||
|
||||
def collect_static():
|
||||
|
|
@ -74,7 +73,6 @@ def dev():
|
|||
elif services.__contains__('celery'):
|
||||
management.call_command('celery', 'celery')
|
||||
elif services.__contains__('local_model'):
|
||||
os.environ.setdefault('SERVER_NAME', 'local_model')
|
||||
from maxkb.const import CONFIG
|
||||
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
||||
management.call_command('runserver', bind)
|
||||
|
|
@ -108,6 +106,12 @@ if __name__ == '__main__':
|
|||
parser.add_argument('-f', '--force', nargs="?", const=True)
|
||||
args = parser.parse_args()
|
||||
action = args.action
|
||||
services = args.services if isinstance(args.services, list) else args.services
|
||||
if services.__contains__('web'):
|
||||
os.environ.setdefault('SERVER_NAME', 'web')
|
||||
elif services.__contains__('local_model'):
|
||||
os.environ.setdefault('SERVER_NAME', 'local_model')
|
||||
django.setup()
|
||||
if action == "upgrade_db":
|
||||
perform_db_migrate()
|
||||
elif action == "collect_static":
|
||||
|
|
@ -120,4 +124,3 @@ if __name__ == '__main__':
|
|||
collect_static()
|
||||
perform_db_migrate()
|
||||
start_services()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue