mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
chore: refactor embedding model parameter handling
This commit is contained in:
parent
18543a0703
commit
25d4e3d502
|
|
@ -1,6 +1,5 @@
|
|||
# coding=utf-8
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
|
|
@ -14,12 +13,11 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK
|
|||
from common.utils.logger import maxkb_logger
|
||||
from knowledge.models import Document, TaskType, State
|
||||
from knowledge.serializers.common import drop_knowledge_index
|
||||
from models_provider.tools import get_model
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model
|
||||
from ops import celery_app
|
||||
|
||||
|
||||
|
||||
def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error(
|
||||
_('Failed to obtain vector model: {error} {traceback}').format(
|
||||
error=str(e),
|
||||
|
|
@ -28,7 +26,20 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
|
|||
try:
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
|
||||
s = {p.get('field'): p.get('default_value') for p in model.model_params_form if p.get('default_value') is not None}
|
||||
def convert_to_int(value):
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
return value
|
||||
|
||||
s = {
|
||||
p.get('field'): convert_to_int(p.get('default_value'))
|
||||
for p in model.model_params_form
|
||||
if p.get('default_value') is not None
|
||||
}
|
||||
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s}))
|
||||
except Exception as e:
|
||||
exception_handler(e)
|
||||
|
|
|
|||
Loading…
Reference in New Issue