From f6ebaa7cac607188da2b81c2294df0bc42fbe23c Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Fri, 7 Nov 2025 14:43:40 +0800 Subject: [PATCH] feat: Local model validation using local_madel (#4330) --- .../serializers/model_apply_serializers.py | 22 ++++++++++- apps/local_model/urls.py | 2 + apps/local_model/views/model_apply.py | 11 +++++- .../credential/embedding/__init__.py | 14 +++++++ .../{embedding.py => embedding/model.py} | 6 +-- .../credential/embedding/web.py | 37 +++++++++++++++++++ .../credential/reranker/__init__.py | 14 +++++++ .../{reranker.py => reranker/model.py} | 6 +-- .../credential/reranker/web.py | 37 +++++++++++++++++++ 9 files changed, 141 insertions(+), 8 deletions(-) create mode 100644 apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py rename apps/models_provider/impl/local_model_provider/credential/{embedding.py => embedding/model.py} (96%) create mode 100644 apps/models_provider/impl/local_model_provider/credential/embedding/web.py create mode 100644 apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py rename apps/models_provider/impl/local_model_provider/credential/{reranker.py => reranker/model.py} (96%) create mode 100644 apps/models_provider/impl/local_model_provider/credential/reranker/web.py diff --git a/apps/local_model/serializers/model_apply_serializers.py b/apps/local_model/serializers/model_apply_serializers.py index cf57c2fef..5ecb2260c 100644 --- a/apps/local_model/serializers/model_apply_serializers.py +++ b/apps/local_model/serializers/model_apply_serializers.py @@ -74,7 +74,6 @@ class ModelManage: def get_local_model(model, **kwargs): - # system_setting = QuerySet(SystemSetting).filter(type=1).first() return LocalModelProvider().get_model(model.model_type, model.model_name, json.loads( rsa_long_decrypt(model.credential)), @@ -111,6 +110,21 @@ class CompressDocuments(serializers.Serializer): query = serializers.CharField(required=True, label=_('query')) +class ValidateModelSerializers(serializers.Serializer): + model_name = serializers.CharField(required=True, label=_('model_name')) + + model_type = serializers.CharField(required=True, label=_('model_type')) + + model_credential = serializers.DictField(required=True, label="credential") + + def validate_model(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'), + self.data.get('model_credential'), model_params={}, + raise_exception=True) + + class ModelApplySerializers(serializers.Serializer): model_id = serializers.UUIDField(required=True, label=_('model id')) @@ -138,3 +152,9 @@ class ModelApplySerializers(serializers.Serializer): return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents( [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in instance.get('documents')], instance.get('query'))] + + def unload(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + ModelManage.delete_key(self.data.get('model_id')) + return True diff --git a/apps/local_model/urls.py b/apps/local_model/urls.py index cc47946e0..a9c060254 100644 --- a/apps/local_model/urls.py +++ b/apps/local_model/urls.py @@ -7,7 +7,9 @@ from . import views app_name = "local_model" # @formatter:off urlpatterns = [ + path('model/validate', views.LocalModelApply.Validate.as_view()), path('model//embed_documents', views.LocalModelApply.EmbedDocuments.as_view()), path('model//embed_query', views.LocalModelApply.EmbedQuery.as_view()), path('model//compress_documents', views.LocalModelApply.CompressDocuments.as_view()), + path('model//unload', views.LocalModelApply.Unload.as_view()), ] diff --git a/apps/local_model/views/model_apply.py b/apps/local_model/views/model_apply.py index 218d4f091..98c07dd74 100644 --- a/apps/local_model/views/model_apply.py +++ b/apps/local_model/views/model_apply.py @@ -11,7 +11,7 @@ from urllib.request import Request from rest_framework.views import APIView from common.result import result -from local_model.serializers.model_apply_serializers import ModelApplySerializers +from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers class LocalModelApply(APIView): @@ -32,3 +32,12 @@ class LocalModelApply(APIView): def post(self, request: Request, model_id): return result.success( ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)) + + class Unload(APIView): + def post(self, request: Request, model_id): + return result.success( + ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)) + + class Validate(APIView): + def post(self, request: Request): + return result.success(ValidateModelSerializers(data=request.data).validate_model()) diff --git a/apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py b/apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py new file mode 100644 index 000000000..29828bb74 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/7 14:02 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/models_provider/impl/local_model_provider/credential/embedding.py b/apps/models_provider/impl/local_model_provider/credential/embedding/model.py similarity index 96% rename from apps/models_provider/impl/local_model_provider/credential/embedding.py rename to apps/models_provider/impl/local_model_provider/credential/embedding/model.py index 9d656ad98..d9ec4c3da 100644 --- a/apps/models_provider/impl/local_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/local_model_provider/credential/embedding/model.py @@ -1,9 +1,9 @@ # coding=utf-8 """ @project: MaxKB - @Author:虎 - @file: embedding.py - @date:2024/7/11 11:06 + @Author:虎虎 + @file: model.py.py + @date:2025/11/7 14:02 @desc: """ import traceback diff --git a/apps/models_provider/impl/local_model_provider/credential/embedding/web.py b/apps/models_provider/impl/local_model_provider/credential/embedding/web.py new file mode 100644 index 000000000..4695d141c --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/embedding/web.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/7 14:03 + @desc: +""" +from typing import Dict + +import requests +from django.utils.translation import gettext_lazy as _ + +from common import forms +from common.forms import BaseForm +from maxkb.const import CONFIG +from models_provider.base_model_provider import BaseModelCredential + + +class LocalEmbeddingCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate', + json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_folder = forms.TextInputField(_('Model catalog'), required=True) diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py b/apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py new file mode 100644 index 000000000..f9ec12bc5 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/7 14:22 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker.py b/apps/models_provider/impl/local_model_provider/credential/reranker/model.py similarity index 96% rename from apps/models_provider/impl/local_model_provider/credential/reranker.py rename to apps/models_provider/impl/local_model_provider/credential/reranker/model.py index 46f8ebca2..85b2abce9 100644 --- a/apps/models_provider/impl/local_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/model.py @@ -1,9 +1,9 @@ # coding=utf-8 """ @project: MaxKB - @Author:虎 - @file: reranker.py - @date:2024/9/3 14:33 + @Author:虎虎 + @file: model.py + @date:2025/11/7 14:23 @desc: """ import traceback diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker/web.py b/apps/models_provider/impl/local_model_provider/credential/reranker/web.py new file mode 100644 index 000000000..bc86982bf --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/web.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/7 14:23 + @desc: +""" +from typing import Dict + +import requests +from django.utils.translation import gettext_lazy as _ + +from common import forms +from common.forms import BaseForm +from maxkb.const import CONFIG +from models_provider.base_model_provider import BaseModelCredential + + +class LocalRerankerCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate', + json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_folder = forms.TextInputField(_('Model catalog'), required=True)