mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: Local model validation using local_madel (#4330)
This commit is contained in:
parent
3f6453eb3a
commit
f457588cd5
|
|
@ -74,7 +74,6 @@ class ModelManage:
|
|||
|
||||
|
||||
def get_local_model(model, **kwargs):
|
||||
# system_setting = QuerySet(SystemSetting).filter(type=1).first()
|
||||
return LocalModelProvider().get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(model.credential)),
|
||||
|
|
@ -111,6 +110,21 @@ class CompressDocuments(serializers.Serializer):
|
|||
query = serializers.CharField(required=True, label=_('query'))
|
||||
|
||||
|
||||
class ValidateModelSerializers(serializers.Serializer):
|
||||
model_name = serializers.CharField(required=True, label=_('model_name'))
|
||||
|
||||
model_type = serializers.CharField(required=True, label=_('model_type'))
|
||||
|
||||
model_credential = serializers.DictField(required=True, label="credential")
|
||||
|
||||
def validate_model(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'),
|
||||
self.data.get('model_credential'), model_params={},
|
||||
raise_exception=True)
|
||||
|
||||
|
||||
class ModelApplySerializers(serializers.Serializer):
|
||||
model_id = serializers.UUIDField(required=True, label=_('model id'))
|
||||
|
||||
|
|
@ -138,3 +152,9 @@ class ModelApplySerializers(serializers.Serializer):
|
|||
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
|
||||
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
|
||||
instance.get('documents')], instance.get('query'))]
|
||||
|
||||
def unload(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
ModelManage.delete_key(self.data.get('model_id'))
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ from . import views
|
|||
app_name = "local_model"
|
||||
# @formatter:off
|
||||
urlpatterns = [
|
||||
path('model/validate', views.LocalModelApply.Validate.as_view()),
|
||||
path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
|
||||
path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
|
||||
path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
|
||||
path('model/<str:model_id>/unload', views.LocalModelApply.Unload.as_view()),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from urllib.request import Request
|
|||
from rest_framework.views import APIView
|
||||
|
||||
from common.result import result
|
||||
from local_model.serializers.model_apply_serializers import ModelApplySerializers
|
||||
from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers
|
||||
|
||||
|
||||
class LocalModelApply(APIView):
|
||||
|
|
@ -32,3 +32,12 @@ class LocalModelApply(APIView):
|
|||
def post(self, request: Request, model_id):
|
||||
return result.success(
|
||||
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
|
||||
|
||||
class Unload(APIView):
|
||||
def post(self, request: Request, model_id):
|
||||
return result.success(
|
||||
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
|
||||
|
||||
class Validate(APIView):
|
||||
def post(self, request: Request):
|
||||
return result.success(ValidateModelSerializers(data=request.data).validate_model())
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py.py
|
||||
@date:2025/11/7 14:02
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||
from .model import *
|
||||
else:
|
||||
from .web import *
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/11 11:06
|
||||
@Author:虎虎
|
||||
@file: model.py.py
|
||||
@date:2025/11/7 14:02
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: web.py
|
||||
@date:2025/11/7 14:03
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common import forms
|
||||
from common.forms import BaseForm
|
||||
from maxkb.const import CONFIG
|
||||
from models_provider.base_model_provider import BaseModelCredential
|
||||
|
||||
|
||||
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
raise_exception=False):
|
||||
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
||||
prefix = CONFIG.get_admin_path()
|
||||
res = requests.post(
|
||||
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate',
|
||||
json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential})
|
||||
result = res.json()
|
||||
if result.get('code', 500) == 200:
|
||||
return result.get('data')
|
||||
raise Exception(result.get('message'))
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return model
|
||||
|
||||
cache_folder = forms.TextInputField(_('Model catalog'), required=True)
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py.py
|
||||
@date:2025/11/7 14:22
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||
from .model import *
|
||||
else:
|
||||
from .web import *
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: reranker.py
|
||||
@date:2024/9/3 14:33
|
||||
@Author:虎虎
|
||||
@file: model.py
|
||||
@date:2025/11/7 14:23
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: web.py
|
||||
@date:2025/11/7 14:23
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common import forms
|
||||
from common.forms import BaseForm
|
||||
from maxkb.const import CONFIG
|
||||
from models_provider.base_model_provider import BaseModelCredential
|
||||
|
||||
|
||||
class LocalRerankerCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
raise_exception=False):
|
||||
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
||||
prefix = CONFIG.get_admin_path()
|
||||
res = requests.post(
|
||||
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate',
|
||||
json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential})
|
||||
result = res.json()
|
||||
if result.get('code', 500) == 200:
|
||||
return result.get('data')
|
||||
raise Exception(result.get('message'))
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return model
|
||||
|
||||
cache_folder = forms.TextInputField(_('Model catalog'), required=True)
|
||||
Loading…
Reference in New Issue