diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index b4b7c4aaf..b5abf9196 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -162,6 +162,14 @@ class Provide(APIView): , tags=["模型"]) @has_permissions(PermissionConstants.MODEL_READ) def get(self, request: Request): + model_type = request.query_params.get('model_type') + if model_type: + providers = [] + for key in ModelProvideConstants.__members__: + if len([item for item in ModelProvideConstants[key].value.get_model_type_list() if + item['value'] == model_type]) > 0: + providers.append(ModelProvideConstants[key].value.get_model_provide_info().to_dict()) + return result.success(providers) return result.success( [ModelProvideConstants[key].value.get_model_provide_info().to_dict() for key in ModelProvideConstants.__members__]) diff --git a/ui/src/api/model.ts b/ui/src/api/model.ts index 59ca0f82d..6519f1bc0 100644 --- a/ui/src/api/model.ts +++ b/ui/src/api/model.ts @@ -34,6 +34,13 @@ const getProvider: (loading?: Ref) => Promise>> return get(`${prefix_provider}`, {}, loading) } +/** + * 获得供应商列表 + */ +const getProviderByModelType: (model_type: string, loading?: Ref) => Promise>> = (model_type, loading) => { + return get(`${prefix_provider}`, {model_type}, loading) +} + /** * 获取模型创建表单 * @param provider @@ -187,5 +194,6 @@ export default { getModelMetaById, pauseDownload, getModelParamsForm, - updateModelParamsForm + updateModelParamsForm, + getProviderByModelType } diff --git a/ui/src/views/template/component/SelectProviderDialog.vue b/ui/src/views/template/component/SelectProviderDialog.vue index be5cf53b7..65fb8b4fb 100644 --- a/ui/src/views/template/component/SelectProviderDialog.vue +++ b/ui/src/views/template/component/SelectProviderDialog.vue @@ -6,9 +6,29 @@ :close-on-press-escape="false" :destroy-on-close="true" :before-close="close" - title="选择供应商" append-to-body > + @@ -25,9 +45,20 @@ import { ref } from 'vue' import ModelApi from '@/api/model' import type { Provider } from '@/api/type/model' + const loading = ref(false) const dialogVisible = ref(false) const list_provider = ref>([]) +const currentModelType = ref('') + +const modelTypeOptions = ref([ + { text: '全部模型', value:''}, + { text: '大语言模型', value:'LLM'}, + { text: '向量模型', value:'EMBEDDING'}, + { text: '重排模型', value:'RERANKER'}, + { text: '语音识别', value:'STT'}, + { text: '语音合成', value:'TTS'} +]) const open = () => { dialogVisible.value = true @@ -39,6 +70,14 @@ const open = () => { const close = () => { dialogVisible.value = false } + +const checkModelType = (model_type: string) => { + currentModelType.value = modelTypeOptions.value.filter((item) => item.value === model_type)[0].text + ModelApi.getProviderByModelType(model_type, loading).then((ok) => { + list_provider.value = ok.data + }) +} + const emit = defineEmits(['change']) const go_create = (provider: Provider) => { close()