mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: add function to retrieve default parameters for embedding models
--bug=1063177 --user=刘瑞斌 【知识库】-知识库使用的模型更换维度参数值并重新向量化后,命中测试、检索报错 https://www.tapd.cn/62980211/s/1792117
This commit is contained in:
parent
2de6bd2018
commit
ed19db07d1
|
|
@ -112,6 +112,21 @@ class ProblemParagraphManage:
|
|||
], problem_paragraph_mapping_list
|
||||
return result
|
||||
|
||||
def get_embedding_model_default_params(model):
    """Build the default keyword arguments for an embedding model.

    Reads ``model.model_params_form`` (an iterable of dict-like entries) and
    maps each entry's ``'field'`` to its ``'default_value'``. Entries whose
    default is ``None`` are left out. String defaults that parse as integers
    are converted to ``int``; every other value is passed through unchanged.

    Args:
        model: Object exposing a ``model_params_form`` attribute.

    Returns:
        dict: ``{field: default_value}``; empty when the form is ``None`` or
        has no entries with a default.
    """

    def convert_to_int(value):
        # Only strings are coerced; non-numeric strings are returned as-is.
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError:
                return value
        return value

    # Treat a missing (None) form like an empty one instead of failing on
    # iteration with a TypeError.
    params_form = model.model_params_form or []

    return {
        p.get('field'): convert_to_int(p.get('default_value'))
        for p in params_form
        if p.get('default_value') is not None
    }
|
||||
|
||||
|
||||
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
||||
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
|
||||
|
|
@ -119,17 +134,29 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
|||
raise Exception(_('The knowledge base is inconsistent with the vector model'))
|
||||
if len(knowledge_list) == 0:
|
||||
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
|
||||
return ModelManage.get_model(str(knowledge_list[0].embedding_model_id),
|
||||
lambda _id: get_model(knowledge_list[0].embedding_model))
|
||||
|
||||
default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model)
|
||||
|
||||
return ModelManage.get_model(
|
||||
str(knowledge_list[0].embedding_model_id),
|
||||
lambda _id: get_model(knowledge_list[0].embedding_model, **{**default_params})
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_model_by_knowledge_id(knowledge_id: str):
    """Return the embedding model for the knowledge base with the given id.

    Loads the knowledge base (joining its ``embedding_model`` relation),
    collects the model's default parameters via
    ``get_embedding_model_default_params`` and requests the model from
    ``ModelManage`` under the key ``str(knowledge.embedding_model_id)``.

    Args:
        knowledge_id: Primary key of the knowledge base.

    Returns:
        Whatever ``ModelManage.get_model`` returns for that key.
    """
    knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
    # NOTE(review): `.first()` presumably yields None when no row matches, in
    # which case the attribute access below raises AttributeError; confirm
    # callers always pass an existing id.
    default_params = get_embedding_model_default_params(knowledge.embedding_model)

    # The factory builds the model with its configured defaults, so a model
    # created here does not ignore them (e.g. an integer-valued parameter).
    return ModelManage.get_model(
        str(knowledge.embedding_model_id),
        lambda _id: get_model(knowledge.embedding_model, **default_params),
    )
|
||||
|
||||
|
||||
def get_embedding_model_by_knowledge(knowledge):
    """Return the embedding model for an already-loaded knowledge base.

    Collects the default parameters of ``knowledge.embedding_model`` via
    ``get_embedding_model_default_params`` and requests the model from
    ``ModelManage`` under the key ``str(knowledge.embedding_model_id)``.

    Args:
        knowledge: Object exposing ``embedding_model`` and
            ``embedding_model_id``.

    Returns:
        Whatever ``ModelManage.get_model`` returns for that key.
    """
    default_params = get_embedding_model_default_params(knowledge.embedding_model)

    # The factory passes the configured defaults through to `get_model`.
    return ModelManage.get_model(
        str(knowledge.embedding_model_id),
        lambda _id: get_model(knowledge.embedding_model, **default_params),
    )
|
||||
|
||||
|
||||
def get_embedding_model_id_by_knowledge_id(knowledge_id):
|
||||
|
|
@ -241,7 +268,7 @@ def create_knowledge_index(knowledge_id=None, document_id=None):
|
|||
result = sql_execute(sql, [])
|
||||
if len(result) == 0:
|
||||
return
|
||||
dims = result[0]['dims']
|
||||
dims = result[0]['dims']
|
||||
sql = f"""CREATE INDEX "embedding_hnsw_idx_{k_id}" ON embedding USING hnsw ((embedding::vector({dims})) vector_cosine_ops) WHERE knowledge_id = '{k_id}'"""
|
||||
update_execute(sql, [])
|
||||
maxkb_logger.info(f'Created index for knowledge ID: {k_id}')
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK
|
|||
UpdateEmbeddingDocumentIdArgs
|
||||
from common.utils.logger import maxkb_logger
|
||||
from knowledge.models import Document, TaskType, State
|
||||
from knowledge.serializers.common import drop_knowledge_index
|
||||
from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model
|
||||
from ops import celery_app
|
||||
|
|
@ -26,21 +26,9 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
|
|||
try:
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
|
||||
def convert_to_int(value):
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
return value
|
||||
default_params = get_embedding_model_default_params(model)
|
||||
|
||||
s = {
|
||||
p.get('field'): convert_to_int(p.get('default_value'))
|
||||
for p in model.model_params_form
|
||||
if p.get('default_value') is not None
|
||||
}
|
||||
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s}))
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
|
||||
except Exception as e:
|
||||
exception_handler(e)
|
||||
raise e
|
||||
|
|
|
|||
Loading…
Reference in New Issue