diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index c6020ddd2..1a3631d36 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -216,6 +216,20 @@ const getApplicationHitTest: ( return get(`${prefix}/${application_id}/hit_test`, data, loading) } +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationModel: ( + application_id: string, + loading?: Ref +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/${application_id}/model`, loading) +} + export default { getAllAppilcation, getApplication, @@ -232,5 +246,6 @@ export default { postAppAuthentication, getProfile, putChatVote, - getApplicationHitTest + getApplicationHitTest, + getApplicationModel } diff --git a/ui/src/views/application/CreateAndSetting.vue b/ui/src/views/application/CreateAndSetting.vue index a99d05771..db13b2ca5 100644 --- a/ui/src/views/application/CreateAndSetting.vue +++ b/ui/src/views/application/CreateAndSetting.vue @@ -464,15 +464,27 @@ function getDataset() { function getModel() { loading.value = true - model - .asyncGetModel() - .then((res: any) => { - modelOptions.value = groupBy(res?.data, 'provider') - loading.value = false - }) - .catch(() => { - loading.value = false - }) + if (id) { + applicationApi + .getApplicationModel(id) + .then((res: any) => { + modelOptions.value = groupBy(res?.data, 'provider') + loading.value = false + }) + .catch(() => { + loading.value = false + }) + } else { + model + .asyncGetModel() + .then((res: any) => { + modelOptions.value = groupBy(res?.data, 'provider') + loading.value = false + }) + .catch(() => { + loading.value = false + }) + } } function getProvider() {