From 25d4e3d5022a0c9a995ff3fecbb9a752934fd92a Mon Sep 17 00:00:00 2001 From: CaptainB Date: Mon, 27 Oct 2025 12:50:54 +0800 Subject: [PATCH] chore: refactor embedding model parameter handling --- apps/knowledge/task/embedding.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/apps/knowledge/task/embedding.py b/apps/knowledge/task/embedding.py index 79ce9ed01..fd6e68a0e 100644 --- a/apps/knowledge/task/embedding.py +++ b/apps/knowledge/task/embedding.py @@ -1,6 +1,5 @@ # coding=utf-8 -import logging import traceback from typing import List @@ -14,12 +13,11 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK from common.utils.logger import maxkb_logger from knowledge.models import Document, TaskType, State from knowledge.serializers.common import drop_knowledge_index -from models_provider.tools import get_model from models_provider.models import Model +from models_provider.tools import get_model from ops import celery_app - def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error( _('Failed to obtain vector model: {error} {traceback}').format( error=str(e), @@ -28,7 +26,20 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error try: model = QuerySet(Model).filter(id=model_id).first() - s = {p.get('field'): p.get('default_value') for p in model.model_params_form if p.get('default_value') is not None} + def convert_to_int(value): + if isinstance(value, str): + try: + return int(value) + except ValueError: + return value + return value + + s = { + p.get('field'): convert_to_int(p.get('default_value')) + for p in model.model_params_form + if p.get('default_value') is not None + } + embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s})) except Exception as e: exception_handler(e)