# Patch summary (commentary only; the patch text below is unchanged).
#
# Files touched:
#   apps/knowledge/serializers/common.py
#   apps/knowledge/task/embedding.py
#
# What the patch does:
#   * Adds get_embedding_model_default_params(model) to serializers/common.py.
#     It walks model.model_params_form and returns a dict that maps each
#     entry's 'field' to its 'default_value', skipping entries whose
#     'default_value' is None. String defaults are passed through int();
#     if int() raises ValueError the original string is kept.
#   * get_embedding_model_by_knowledge_id_list, get_embedding_model_by_knowledge_id
#     and get_embedding_model_by_knowledge now forward those defaults as keyword
#     arguments to get_model(...) inside the loader lambda handed to
#     ModelManage.get_model.
#   * task/embedding.py: get_embedding_model drops its inline convert_to_int
#     helper and dict comprehension and calls the shared helper instead
#     (imported from knowledge.serializers.common).
#   * create_knowledge_index: the `dims = result[0]['dims']` line is removed and
#     re-added; the visible difference is whitespace only.
#     NOTE(review): looks like an accidental whitespace edit - confirm.
#
# NOTE(review): get_embedding_model_by_knowledge_id and get_embedding_model
#   (task/embedding.py) pass the result of `.first()` straight to the helper,
#   which reads `model.model_params_form`. Presumably `.first()` can return
#   None when no row matches, which would raise AttributeError - confirm and
#   guard if so.
# NOTE(review): the helper iterates `model.model_params_form` directly;
#   assumes it is always an iterable of dict-like entries - TODO confirm it
#   cannot be None.
# NOTE(review): `**{**default_params}` unpacks a throwaway copy of the dict;
#   `**default_params` passes the same keyword arguments.
# NOTE(review): the defaults are now computed before every
#   ModelManage.get_model call rather than inside the loader lambda; if
#   ModelManage.get_model only invokes the loader on a cache miss (not visible
#   here), this work is wasted on hits - confirm.
# NOTE(review): in this copy the patch's line breaks have been collapsed, so
#   hunk context indentation and blank lines are not recoverable from it.
diff --git a/apps/knowledge/serializers/common.py b/apps/knowledge/serializers/common.py index a720f0aa9..2951afce0 100644 --- a/apps/knowledge/serializers/common.py +++ b/apps/knowledge/serializers/common.py @@ -112,6 +112,21 @@ class ProblemParagraphManage: ], problem_paragraph_mapping_list return result +def get_embedding_model_default_params(model): + def convert_to_int(value): + if isinstance(value, str): + try: + return int(value) + except ValueError: + return value + return value + + return { + p.get('field'): convert_to_int(p.get('default_value')) + for p in model.model_params_form + if p.get('default_value') is not None + } + def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List): knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list) @@ -119,17 +134,29 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List): raise Exception(_('The knowledge base is inconsistent with the vector model')) if len(knowledge_list) == 0: raise Exception(_('Knowledge base setting error, please reset the knowledge base')) - return ModelManage.get_model(str(knowledge_list[0].embedding_model_id), - lambda _id: get_model(knowledge_list[0].embedding_model)) + + default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model) + + return ModelManage.get_model( + str(knowledge_list[0].embedding_model_id), + lambda _id: get_model(knowledge_list[0].embedding_model, **{**default_params}) + ) def get_embedding_model_by_knowledge_id(knowledge_id: str): knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first() - return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model)) + + default_params = get_embedding_model_default_params(knowledge.embedding_model) + + return ModelManage.get_model(str(knowledge.embedding_model_id), + lambda _id: get_model(knowledge.embedding_model, **{**default_params})) def 
get_embedding_model_by_knowledge(knowledge): - return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model)) + default_params = get_embedding_model_default_params(knowledge.embedding_model) + + return ModelManage.get_model(str(knowledge.embedding_model_id), + lambda _id: get_model(knowledge.embedding_model, **{**default_params})) def get_embedding_model_id_by_knowledge_id(knowledge_id): @@ -241,7 +268,7 @@ def create_knowledge_index(knowledge_id=None, document_id=None): result = sql_execute(sql, []) if len(result) == 0: return - dims = result[0]['dims'] + dims = result[0]['dims'] sql = f"""CREATE INDEX "embedding_hnsw_idx_{k_id}" ON embedding USING hnsw ((embedding::vector({dims})) vector_cosine_ops) WHERE knowledge_id = '{k_id}'""" update_execute(sql, []) maxkb_logger.info(f'Created index for knowledge ID: {k_id}') diff --git a/apps/knowledge/task/embedding.py b/apps/knowledge/task/embedding.py index fd6e68a0e..f1315242a 100644 --- a/apps/knowledge/task/embedding.py +++ b/apps/knowledge/task/embedding.py @@ -12,7 +12,7 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK UpdateEmbeddingDocumentIdArgs from common.utils.logger import maxkb_logger from knowledge.models import Document, TaskType, State -from knowledge.serializers.common import drop_knowledge_index +from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params from models_provider.models import Model from models_provider.tools import get_model from ops import celery_app @@ -26,21 +26,9 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error try: model = QuerySet(Model).filter(id=model_id).first() - def convert_to_int(value): - if isinstance(value, str): - try: - return int(value) - except ValueError: - return value - return value + default_params = get_embedding_model_default_params(model) - s = { - p.get('field'): convert_to_int(p.get('default_value')) - 
for p in model.model_params_form - if p.get('default_value') is not None - } - - embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s})) + embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params})) except Exception as e: exception_handler(e) raise e