feat: add function to retrieve default parameters for embedding models

--bug=1063177 --user=刘瑞斌 【知识库】-知识库使用的模型更换维度参数值并重新向量化后,命中测试、检索报错 https://www.tapd.cn/62980211/s/1792117
This commit is contained in:
CaptainB 2025-10-30 17:00:38 +08:00
parent 2de6bd2018
commit ed19db07d1
2 changed files with 35 additions and 20 deletions

View File

@ -112,6 +112,21 @@ class ProblemParagraphManage:
], problem_paragraph_mapping_list
return result
def get_embedding_model_default_params(model):
def convert_to_int(value):
if isinstance(value, str):
try:
return int(value)
except ValueError:
return value
return value
return {
p.get('field'): convert_to_int(p.get('default_value'))
for p in model.model_params_form
if p.get('default_value') is not None
}
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
@ -119,17 +134,29 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
raise Exception(_('The knowledge base is inconsistent with the vector model'))
if len(knowledge_list) == 0:
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
return ModelManage.get_model(str(knowledge_list[0].embedding_model_id),
lambda _id: get_model(knowledge_list[0].embedding_model))
default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model)
return ModelManage.get_model(
str(knowledge_list[0].embedding_model_id),
lambda _id: get_model(knowledge_list[0].embedding_model, **{**default_params})
)
def get_embedding_model_by_knowledge_id(knowledge_id: str):
knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))
default_params = get_embedding_model_default_params(knowledge.embedding_model)
return ModelManage.get_model(str(knowledge.embedding_model_id),
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
def get_embedding_model_by_knowledge(knowledge):
return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))
default_params = get_embedding_model_default_params(knowledge.embedding_model)
return ModelManage.get_model(str(knowledge.embedding_model_id),
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
def get_embedding_model_id_by_knowledge_id(knowledge_id):
@ -241,7 +268,7 @@ def create_knowledge_index(knowledge_id=None, document_id=None):
result = sql_execute(sql, [])
if len(result) == 0:
return
dims = result[0]['dims']
dims = result[0]['dims']
sql = f"""CREATE INDEX "embedding_hnsw_idx_{k_id}" ON embedding USING hnsw ((embedding::vector({dims})) vector_cosine_ops) WHERE knowledge_id = '{k_id}'"""
update_execute(sql, [])
maxkb_logger.info(f'Created index for knowledge ID: {k_id}')

View File

@ -12,7 +12,7 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK
UpdateEmbeddingDocumentIdArgs
from common.utils.logger import maxkb_logger
from knowledge.models import Document, TaskType, State
from knowledge.serializers.common import drop_knowledge_index
from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params
from models_provider.models import Model
from models_provider.tools import get_model
from ops import celery_app
@ -26,21 +26,9 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
try:
model = QuerySet(Model).filter(id=model_id).first()
def convert_to_int(value):
if isinstance(value, str):
try:
return int(value)
except ValueError:
return value
return value
default_params = get_embedding_model_default_params(model)
s = {
p.get('field'): convert_to_int(p.get('default_value'))
for p in model.model_params_form
if p.get('default_value') is not None
}
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s}))
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
except Exception as e:
exception_handler(e)
raise e