diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
index 96f4dda69..040a2be86 100644
--- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
+++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
@@ -28,12 +28,12 @@
 from models_provider.tools import get_model
 
 
 def get_model_by_id(_id, workspace_id):
-    model = QuerySet(Model).filter(id=_id, model_type="EMBEDDING").first()
+    model = QuerySet(Model).filter(id=_id, model_type="EMBEDDING")
+    get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
+    if get_authorized_model is not None:
+        model = get_authorized_model(model, workspace_id)
     if model is None:
         raise Exception(_("Model does not exist"))
-    if model.workspace_id is not None:
-        if model.workspace_id != workspace_id:
-            raise Exception(_("Model does not exist"))
     return model
 
diff --git a/apps/models_provider/base_model_provider.py b/apps/models_provider/base_model_provider.py
index 1d4e33ede..74f9205df 100644
--- a/apps/models_provider/base_model_provider.py
+++ b/apps/models_provider/base_model_provider.py
@@ -99,7 +99,7 @@ class MaxKBBaseModel(ABC):
     def filter_optional_params(model_kwargs):
         optional_params = {}
         for key, value in model_kwargs.items():
-            if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label']:
+            if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label', 'stream']:
                 if key == 'extra_body' and isinstance(value, dict):
                     optional_params = {**optional_params, **value}
                 else:
diff --git a/apps/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py b/apps/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py
index a4de65254..ab84934d4 100644
--- a/apps/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py
+++ b/apps/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py
@@ -29,21 +29,10 @@ model_info_list = [ModelInfo('ERNIE-Bot-4',
                   ModelInfo('ERNIE-Bot-turbo',
                             _('ERNIE-Bot-turbo is a large language model independently developed by Baidu. It covers massive Chinese data, has stronger capabilities in dialogue Q&A, content creation and generation, and has a faster response speed.'),
                             ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
-                  ModelInfo('BLOOMZ-7B',
-                            _('BLOOMZ-7B is a well-known large language model in the industry. It was developed and open sourced by BigScience and can output text in 46 languages and 13 programming languages.'),
-                            ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
-                  ModelInfo('Llama-2-7b-chat',
-                            'Llama-2-7b-chat was developed by Meta AI and is open source. It performs well in scenarios such as coding, reasoning and knowledge application. Llama-2-7b-chat is a high-performance native open source version suitable for conversation scenarios.',
-                            ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
-                  ModelInfo('Llama-2-13b-chat',
-                            _('Llama-2-13b-chat was developed by Meta AI and is open source. It performs well in scenarios such as coding, reasoning and knowledge application. Llama-2-13b-chat is a native open source version with balanced performance and effect, suitable for conversation scenarios.'),
-                            ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
-                  ModelInfo('Llama-2-70b-chat',
-                            _('Llama-2-70b-chat was developed by Meta AI and is open source. It performs well in scenarios such as coding, reasoning, and knowledge application. Llama-2-70b-chat is a native open source version with high-precision effects.'),
-                            ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel),
-                  ModelInfo('Qianfan-Chinese-Llama-2-7B',
-                            _('The Chinese enhanced version developed by the Qianfan team based on Llama-2-7b has performed well on Chinese knowledge bases such as CMMLU and C-EVAL.'),
+                  ModelInfo('qianfan-chinese-llama-2-13b',
+                            '',
                             ModelTypeConst.LLM, win_xin_llm_model_credential, QianfanChatModel)
+                  ]
 
 embedding_model_info = ModelInfo('Embedding-V1',
                                  _('Embedding-V1 is a text representation model based on Baidu Wenxin large model technology. It can convert text into a vector form represented by numerical values and can be used in text retrieval, information recommendation, knowledge mining and other scenarios. Embedding-V1 provides the Embeddings interface, which can generate corresponding vector representations based on input content. You can call this interface to input text into the model and obtain the corresponding vector representation for subsequent text processing and analysis.'),
@@ -63,6 +52,8 @@ class WenxinModelProvider(IModelProvider):
         return model_info_manage
 
     def get_model_provide_info(self):
-        return ModelProvideInfo(provider='model_wenxin_provider', name=_('Thousand sails large model'), icon=get_file_content(
-            os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'wenxin_model_provider', 'icon',
-                         'azure_icon_svg')))
+        return ModelProvideInfo(provider='model_wenxin_provider', name=_('Thousand sails large model'),
+                                icon=get_file_content(
+                                    os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl',
+                                                 'wenxin_model_provider', 'icon',
+                                                 'azure_icon_svg')))
diff --git a/apps/models_provider/tools.py b/apps/models_provider/tools.py
index 10ec4a292..0eaf24ace 100644
--- a/apps/models_provider/tools.py
+++ b/apps/models_provider/tools.py
@@ -10,6 +10,7 @@
 from django.db import connection
 from django.db.models import QuerySet
 
 from common.config.embedding_config import ModelManage
+from common.database_model_manage.database_model_manage import DatabaseModelManage
 from models_provider.models import Model
 from django.utils.translation import gettext_lazy as _
@@ -104,13 +105,12 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict
 
 
 def get_model_by_id(_id, workspace_id):
-    model = QuerySet(Model).filter(id=_id).first()
+    model = QuerySet(Model).filter(id=_id)
+    get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
+    if get_authorized_model is not None:
+        model = get_authorized_model(model, workspace_id)
     if model is None:
-        raise Exception(_('Model does not exist'))
-    if model.workspace_id:
-        if model.workspace_id != workspace_id:
-            raise Exception(_('Model does not exist'))
-
+        raise Exception(_("Model does not exist"))
     return model
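Both get_model_by_id implementations above now defer workspace authorization to a hook looked up via DatabaseModelManage.get_model("get_authorized_model") instead of comparing model.workspace_id inline. The hook itself is registered elsewhere and is not part of this diff; the following is only a minimal sketch of the call shape the new code appears to expect, assuming the hook receives the unevaluated QuerySet plus the workspace id and returns a single Model or None:

    # Hypothetical sketch, not the actual registered hook.
    from django.db.models import Q, QuerySet

    def get_authorized_model(model_query_set: QuerySet, workspace_id: str):
        # Keep models bound to this workspace, plus models with no workspace binding,
        # then resolve the queryset to a single instance (or None).
        return model_query_set.filter(
            Q(workspace_id=workspace_id) | Q(workspace_id__isnull=True)
        ).first()

The base_model_provider.py change adds 'stream' to the keys that filter_optional_params drops before forwarding kwargs to a model. An assumed example call (the argument values are invented) illustrating the effect:

    # Assumed example input; 'stream', like 'streaming' and 'model_id', is no
    # longer passed through, while 'extra_body' entries are merged into the result.
    from models_provider.base_model_provider import MaxKBBaseModel

    MaxKBBaseModel.filter_optional_params({
        'model_id': '42', 'stream': True, 'temperature': 0.7,
        'extra_body': {'top_k': 5},
    })
    # -> {'temperature': 0.7, 'top_k': 5}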