refactor: replace get_embedding_model_default_params with get_model_default_params

This commit is contained in:
CaptainB 2025-10-30 17:12:30 +08:00
parent 1b08643b98
commit d0cd6d657a
4 changed files with 25 additions and 26 deletions

View File

@ -22,10 +22,9 @@ from common.db.search import native_search
from common.utils.common import get_file_content
from knowledge.models import Paragraph, Knowledge
from knowledge.models import SearchMode
from knowledge.serializers.common import get_embedding_model_default_params
from maxkb.conf import PROJECT_DIR
from models_provider.models import Model
from models_provider.tools import get_model, get_model_by_id
from models_provider.tools import get_model, get_model_by_id, get_model_default_params
def reset_meta(meta):
@ -65,7 +64,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
if model.model_type != "EMBEDDING":
raise Exception(_("Model does not exist"))
self.context['model_name'] = model.name
default_params = get_embedding_model_default_params(model)
default_params = get_model_default_params(model)
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()

View File

@ -26,7 +26,7 @@ from common.utils.logger import maxkb_logger
from knowledge.models import Document
from knowledge.models import Paragraph, Problem, ProblemParagraphMapping, Knowledge, File
from maxkb.conf import PROJECT_DIR
from models_provider.tools import get_model
from models_provider.tools import get_model, get_model_default_params
class MetaSerializer(serializers.Serializer):
@ -112,21 +112,6 @@ class ProblemParagraphManage:
], problem_paragraph_mapping_list
return result
def get_embedding_model_default_params(model):
def convert_to_int(value):
if isinstance(value, str):
try:
return int(value)
except ValueError:
return value
return value
return {
p.get('field'): convert_to_int(p.get('default_value'))
for p in model.model_params_form
if p.get('default_value') is not None
}
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
@ -135,7 +120,7 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
if len(knowledge_list) == 0:
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model)
default_params = get_model_default_params(knowledge_list[0].embedding_model)
return ModelManage.get_model(
str(knowledge_list[0].embedding_model_id),
@ -146,14 +131,14 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
def get_embedding_model_by_knowledge_id(knowledge_id: str):
knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
default_params = get_embedding_model_default_params(knowledge.embedding_model)
default_params = get_model_default_params(knowledge.embedding_model)
return ModelManage.get_model(str(knowledge.embedding_model_id),
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
def get_embedding_model_by_knowledge(knowledge):
default_params = get_embedding_model_default_params(knowledge.embedding_model)
default_params = get_model_default_params(knowledge.embedding_model)
return ModelManage.get_model(str(knowledge.embedding_model_id),
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))

View File

@ -12,9 +12,9 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK
UpdateEmbeddingDocumentIdArgs
from common.utils.logger import maxkb_logger
from knowledge.models import Document, TaskType, State
from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params
from knowledge.serializers.common import drop_knowledge_index
from models_provider.models import Model
from models_provider.tools import get_model
from models_provider.tools import get_model, get_model_default_params
from ops import celery_app
@ -26,7 +26,7 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
try:
model = QuerySet(Model).filter(id=model_id).first()
default_params = get_embedding_model_default_params(model)
default_params = get_model_default_params(model)
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
except Exception as e:

View File

@ -115,6 +115,21 @@ def get_model_by_id(_id, workspace_id):
raise Exception(_("Model does not exist"))
return model
def get_model_default_params(model):
def convert_to_int(value):
if isinstance(value, str):
try:
return int(value)
except ValueError:
return value
return value
return {
p.get('field'): convert_to_int(p.get('default_value'))
for p in model.model_params_form
if p.get('default_value') is not None
}
def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
"""
@ -124,5 +139,5 @@ def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
@return: 模型实例
"""
model = get_model_by_id(model_id, workspace_id)
s = {p.get('field'): p.get('default_value') for p in model.model_params_form if p.get('default_value') is not None}
s = get_model_default_params(model)
return ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s, **kwargs}))