diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py index 7264b544a..9ecfbd44f 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -22,10 +22,9 @@ from common.db.search import native_search from common.utils.common import get_file_content from knowledge.models import Paragraph, Knowledge from knowledge.models import SearchMode -from knowledge.serializers.common import get_embedding_model_default_params from maxkb.conf import PROJECT_DIR from models_provider.models import Model -from models_provider.tools import get_model, get_model_by_id +from models_provider.tools import get_model, get_model_by_id, get_model_default_params def reset_meta(meta): @@ -65,7 +64,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep): if model.model_type != "EMBEDDING": raise Exception(_("Model does not exist")) self.context['model_name'] = model.name - default_params = get_embedding_model_default_params(model) + default_params = get_model_default_params(model) embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params})) embedding_value = embedding_model.embed_query(exec_problem_text) vector = VectorStore.get_embedding_vector() diff --git a/apps/knowledge/serializers/common.py b/apps/knowledge/serializers/common.py index 2951afce0..a76769cba 100644 --- a/apps/knowledge/serializers/common.py +++ b/apps/knowledge/serializers/common.py @@ -26,7 +26,7 @@ from common.utils.logger import maxkb_logger from knowledge.models import Document from knowledge.models import Paragraph, Problem, ProblemParagraphMapping, Knowledge, File from maxkb.conf import PROJECT_DIR -from models_provider.tools import get_model +from models_provider.tools import get_model, get_model_default_params class MetaSerializer(serializers.Serializer): @@ -112,21 +112,6 @@ class ProblemParagraphManage: ], problem_paragraph_mapping_list return result -def get_embedding_model_default_params(model): - def convert_to_int(value): - if isinstance(value, str): - try: - return int(value) - except ValueError: - return value - return value - - return { - p.get('field'): convert_to_int(p.get('default_value')) - for p in model.model_params_form - if p.get('default_value') is not None - } - def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List): knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list) @@ -135,7 +120,7 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List): if len(knowledge_list) == 0: raise Exception(_('Knowledge base setting error, please reset the knowledge base')) - default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model) + default_params = get_model_default_params(knowledge_list[0].embedding_model) return ModelManage.get_model( str(knowledge_list[0].embedding_model_id), @@ -146,14 +131,14 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List): def get_embedding_model_by_knowledge_id(knowledge_id: str): knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first() - default_params = get_embedding_model_default_params(knowledge.embedding_model) + default_params = get_model_default_params(knowledge.embedding_model) return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model, **{**default_params})) def get_embedding_model_by_knowledge(knowledge): - default_params = get_embedding_model_default_params(knowledge.embedding_model) + default_params = get_model_default_params(knowledge.embedding_model) return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model, **{**default_params})) diff --git a/apps/knowledge/task/embedding.py b/apps/knowledge/task/embedding.py index f1315242a..8410118d0 100644 --- a/apps/knowledge/task/embedding.py +++ b/apps/knowledge/task/embedding.py @@ -12,9 +12,9 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK UpdateEmbeddingDocumentIdArgs from common.utils.logger import maxkb_logger from knowledge.models import Document, TaskType, State -from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params +from knowledge.serializers.common import drop_knowledge_index from models_provider.models import Model -from models_provider.tools import get_model +from models_provider.tools import get_model, get_model_default_params from ops import celery_app @@ -26,7 +26,7 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error try: model = QuerySet(Model).filter(id=model_id).first() - default_params = get_embedding_model_default_params(model) + default_params = get_model_default_params(model) embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params})) except Exception as e: diff --git a/apps/models_provider/tools.py b/apps/models_provider/tools.py index 219baab81..9b94d7800 100644 --- a/apps/models_provider/tools.py +++ b/apps/models_provider/tools.py @@ -115,6 +115,21 @@ def get_model_by_id(_id, workspace_id): raise Exception(_("Model does not exist")) return model +def get_model_default_params(model): + def convert_to_int(value): + if isinstance(value, str): + try: + return int(value) + except ValueError: + return value + return value + + return { + p.get('field'): convert_to_int(p.get('default_value')) + for p in model.model_params_form + if p.get('default_value') is not None + } + def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs): """ @@ -124,5 +139,5 @@ def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs): @return: 模型实例 """ model = get_model_by_id(model_id, workspace_id) - s = {p.get('field'): p.get('default_value') for p in model.model_params_form if p.get('default_value') is not None} + s = get_model_default_params(model) return ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s, **kwargs}))