feat: Local model validation using local_madel (#4330)

This commit is contained in:
shaohuzhang1 2025-11-07 14:43:40 +08:00 committed by CaptainB
parent bb58bbbf46
commit f6ebaa7cac
9 changed files with 141 additions and 8 deletions

View File

@ -74,7 +74,6 @@ class ModelManage:
def get_local_model(model, **kwargs):
# system_setting = QuerySet(SystemSetting).filter(type=1).first()
return LocalModelProvider().get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
@ -111,6 +110,21 @@ class CompressDocuments(serializers.Serializer):
query = serializers.CharField(required=True, label=_('query'))
class ValidateModelSerializers(serializers.Serializer):
model_name = serializers.CharField(required=True, label=_('model_name'))
model_type = serializers.CharField(required=True, label=_('model_type'))
model_credential = serializers.DictField(required=True, label="credential")
def validate_model(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'),
self.data.get('model_credential'), model_params={},
raise_exception=True)
class ModelApplySerializers(serializers.Serializer):
model_id = serializers.UUIDField(required=True, label=_('model id'))
@ -138,3 +152,9 @@ class ModelApplySerializers(serializers.Serializer):
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
instance.get('documents')], instance.get('query'))]
def unload(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ModelManage.delete_key(self.data.get('model_id'))
return True

View File

@ -7,7 +7,9 @@ from . import views
app_name = "local_model"
# @formatter:off
urlpatterns = [
path('model/validate', views.LocalModelApply.Validate.as_view()),
path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
path('model/<str:model_id>/unload', views.LocalModelApply.Unload.as_view()),
]

View File

@ -11,7 +11,7 @@ from urllib.request import Request
from rest_framework.views import APIView
from common.result import result
from local_model.serializers.model_apply_serializers import ModelApplySerializers
from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers
class LocalModelApply(APIView):
@ -32,3 +32,12 @@ class LocalModelApply(APIView):
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
class Unload(APIView):
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
class Validate(APIView):
def post(self, request: Request):
return result.success(ValidateModelSerializers(data=request.data).validate_model())

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file __init__.py.py
@date2025/11/7 14:02
@desc:
"""
import os
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
from .model import *
else:
from .web import *

View File

@ -1,9 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/11 11:06
@Author
@file model.py.py
@date2025/11/7 14:02
@desc:
"""
import traceback

View File

@ -0,0 +1,37 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file web.py
@date2025/11/7 14:03
@desc:
"""
from typing import Dict
import requests
from django.utils.translation import gettext_lazy as _
from common import forms
from common.forms import BaseForm
from maxkb.const import CONFIG
from models_provider.base_model_provider import BaseModelCredential
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
prefix = CONFIG.get_admin_path()
res = requests.post(
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate',
json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential})
result = res.json()
if result.get('code', 500) == 200:
return result.get('data')
raise Exception(result.get('message'))
def encryption_dict(self, model: Dict[str, object]):
return model
cache_folder = forms.TextInputField(_('Model catalog'), required=True)

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file __init__.py.py
@date2025/11/7 14:22
@desc:
"""
import os
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
from .model import *
else:
from .web import *

View File

@ -1,9 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file reranker.py
@date2024/9/3 14:33
@Author
@file model.py
@date2025/11/7 14:23
@desc:
"""
import traceback

View File

@ -0,0 +1,37 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file web.py
@date2025/11/7 14:23
@desc:
"""
from typing import Dict
import requests
from django.utils.translation import gettext_lazy as _
from common import forms
from common.forms import BaseForm
from maxkb.const import CONFIG
from models_provider.base_model_provider import BaseModelCredential
class LocalRerankerCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
prefix = CONFIG.get_admin_path()
res = requests.post(
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate',
json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential})
result = res.json()
if result.get('code', 500) == 200:
return result.get('data')
raise Exception(result.get('message'))
def encryption_dict(self, model: Dict[str, object]):
return model
cache_folder = forms.TextInputField(_('Model catalog'), required=True)