mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
refactor: replace get_embedding_model_default_params with get_model_default_params
This commit is contained in:
parent
1b08643b98
commit
d0cd6d657a
|
|
@ -22,10 +22,9 @@ from common.db.search import native_search
|
|||
from common.utils.common import get_file_content
|
||||
from knowledge.models import Paragraph, Knowledge
|
||||
from knowledge.models import SearchMode
|
||||
from knowledge.serializers.common import get_embedding_model_default_params
|
||||
from maxkb.conf import PROJECT_DIR
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model, get_model_by_id
|
||||
from models_provider.tools import get_model, get_model_by_id, get_model_default_params
|
||||
|
||||
|
||||
def reset_meta(meta):
|
||||
|
|
@ -65,7 +64,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
|
|||
if model.model_type != "EMBEDDING":
|
||||
raise Exception(_("Model does not exist"))
|
||||
self.context['model_name'] = model.name
|
||||
default_params = get_embedding_model_default_params(model)
|
||||
default_params = get_model_default_params(model)
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
|
||||
embedding_value = embedding_model.embed_query(exec_problem_text)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from common.utils.logger import maxkb_logger
|
|||
from knowledge.models import Document
|
||||
from knowledge.models import Paragraph, Problem, ProblemParagraphMapping, Knowledge, File
|
||||
from maxkb.conf import PROJECT_DIR
|
||||
from models_provider.tools import get_model
|
||||
from models_provider.tools import get_model, get_model_default_params
|
||||
|
||||
|
||||
class MetaSerializer(serializers.Serializer):
|
||||
|
|
@ -112,21 +112,6 @@ class ProblemParagraphManage:
|
|||
], problem_paragraph_mapping_list
|
||||
return result
|
||||
|
||||
def get_embedding_model_default_params(model):
|
||||
def convert_to_int(value):
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
return value
|
||||
|
||||
return {
|
||||
p.get('field'): convert_to_int(p.get('default_value'))
|
||||
for p in model.model_params_form
|
||||
if p.get('default_value') is not None
|
||||
}
|
||||
|
||||
|
||||
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
||||
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
|
||||
|
|
@ -135,7 +120,7 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
|||
if len(knowledge_list) == 0:
|
||||
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
|
||||
|
||||
default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model)
|
||||
default_params = get_model_default_params(knowledge_list[0].embedding_model)
|
||||
|
||||
return ModelManage.get_model(
|
||||
str(knowledge_list[0].embedding_model_id),
|
||||
|
|
@ -146,14 +131,14 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
|
|||
def get_embedding_model_by_knowledge_id(knowledge_id: str):
|
||||
knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
|
||||
|
||||
default_params = get_embedding_model_default_params(knowledge.embedding_model)
|
||||
default_params = get_model_default_params(knowledge.embedding_model)
|
||||
|
||||
return ModelManage.get_model(str(knowledge.embedding_model_id),
|
||||
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
|
||||
|
||||
|
||||
def get_embedding_model_by_knowledge(knowledge):
|
||||
default_params = get_embedding_model_default_params(knowledge.embedding_model)
|
||||
default_params = get_model_default_params(knowledge.embedding_model)
|
||||
|
||||
return ModelManage.get_model(str(knowledge.embedding_model_id),
|
||||
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
|
||||
|
|
|
|||
|
|
@ -12,9 +12,9 @@ from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingK
|
|||
UpdateEmbeddingDocumentIdArgs
|
||||
from common.utils.logger import maxkb_logger
|
||||
from knowledge.models import Document, TaskType, State
|
||||
from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params
|
||||
from knowledge.serializers.common import drop_knowledge_index
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model
|
||||
from models_provider.tools import get_model, get_model_default_params
|
||||
from ops import celery_app
|
||||
|
||||
|
||||
|
|
@ -26,7 +26,7 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
|
|||
try:
|
||||
model = QuerySet(Model).filter(id=model_id).first()
|
||||
|
||||
default_params = get_embedding_model_default_params(model)
|
||||
default_params = get_model_default_params(model)
|
||||
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -115,6 +115,21 @@ def get_model_by_id(_id, workspace_id):
|
|||
raise Exception(_("Model does not exist"))
|
||||
return model
|
||||
|
||||
def get_model_default_params(model):
|
||||
def convert_to_int(value):
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
return value
|
||||
|
||||
return {
|
||||
p.get('field'): convert_to_int(p.get('default_value'))
|
||||
for p in model.model_params_form
|
||||
if p.get('default_value') is not None
|
||||
}
|
||||
|
||||
|
||||
def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
|
||||
"""
|
||||
|
|
@ -124,5 +139,5 @@ def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
|
|||
@return: 模型实例
|
||||
"""
|
||||
model = get_model_by_id(model_id, workspace_id)
|
||||
s = {p.get('field'): p.get('default_value') for p in model.model_params_form if p.get('default_value') is not None}
|
||||
s = get_model_default_params(model)
|
||||
return ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s, **kwargs}))
|
||||
|
|
|
|||
Loading…
Reference in New Issue