diff --git a/apps/setting/models/model_management.py b/apps/setting/models/model_management.py index 42136a48e..4c47eadfc 100644 --- a/apps/setting/models/model_management.py +++ b/apps/setting/models/model_management.py @@ -22,6 +22,8 @@ class Status(models.TextChoices): DOWNLOAD = "DOWNLOAD", '下载中' + PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载' + class PermissionType(models.TextChoices): PUBLIC = "PUBLIC", '公开' diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index 2a4a147f2..a0f28e326 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -38,6 +38,9 @@ class ModelPullManage: for chunk in response: down_model_chunk[chunk.digest] = chunk.to_dict() if time.time() - timestamp > 5: + model_new = QuerySet(Model).filter(id=model.id).first() + if model_new.status == Status.PAUSE_DOWNLOAD: + return QuerySet(Model).filter(id=model.id).update( meta={"down_model_chunk": list(down_model_chunk.values())}) timestamp = time.time() @@ -238,6 +241,12 @@ class ModelSerializer(serializers.Serializer): QuerySet(Model).filter(id=self.data.get('id')).delete() return True + def pause_download(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD) + return True + def edit(self, instance: Dict, user_id: str, with_valid=True): if with_valid: self.is_valid(raise_exception=True) diff --git a/apps/setting/urls.py b/apps/setting/urls.py index 2a7cdd68a..650e2a2b3 100644 --- a/apps/setting/urls.py +++ b/apps/setting/urls.py @@ -16,6 +16,7 @@ urlpatterns = [ name="provider/model_form"), path('model', views.Model.as_view(), name='model'), path('model/', views.Model.Operate.as_view(), name='model/operate'), + path('model//pause_download', views.Model.PauseDownload.as_view(), name='model/operate'), path('model//meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'), path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting'), path('valid//', views.Valid.as_view()) diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index 7ba0304fc..9108aa15a 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -69,6 +69,18 @@ class Model(APIView): return result.success( ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True)) + class PauseDownload(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="暂停模型下载", + operation_id="暂停模型下载", + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_CREATE) + def put(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download()) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/model.ts b/ui/src/api/model.ts index bb98984f8..15f6517ff 100644 --- a/ui/src/api/model.ts +++ b/ui/src/api/model.ts @@ -130,7 +130,18 @@ const getModelMetaById: (model_id: string, loading?: Ref) => Promise { return get(`${prefix}/${model_id}/meta`, {}, loading) } - +/** + * 暂停下载 + * @param model_id 模型id + * @param loading 加载器 + * @returns + */ +const pauseDownload: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading +) => { + return put(`${prefix}/${model_id}/pause_download`, undefined, {}, loading) +} const deleteModel: (model_id: string, loading?: Ref) => Promise> = ( model_id, loading @@ -147,5 +158,6 @@ export default { updateModel, deleteModel, getModelById, - getModelMetaById + getModelMetaById, + pauseDownload } diff --git a/ui/src/api/type/model.ts b/ui/src/api/type/model.ts index 487cd439a..fcca438d5 100644 --- a/ui/src/api/type/model.ts +++ b/ui/src/api/type/model.ts @@ -69,7 +69,7 @@ interface Model { /** * 状态 */ - status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' + status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' | 'PAUSE_DOWNLOAD' /** * 元数据 */ diff --git a/ui/src/views/template/component/ModelCard.vue b/ui/src/views/template/component/ModelCard.vue index 380443e5f..96e0956ff 100644 --- a/ui/src/views/template/component/ModelCard.vue +++ b/ui/src/views/template/component/ModelCard.vue @@ -21,6 +21,12 @@ +
+ 暂停下载 + + + +
@@ -39,17 +45,6 @@
-
正在下载中 @@ -64,7 +59,13 @@ - + @@ -111,27 +112,6 @@ const errMessage = computed(() => { } return '' }) -// const progress = computed(() => { -// if (currentModel.value) { -// const down_model_chunk = currentModel.value.meta['down_model_chunk'] -// if (down_model_chunk) { -// const maxObj = down_model_chunk -// .filter((chunk: any) => chunk.index > 1) -// .reduce( -// (prev: any, current: any) => { -// return (prev.index || 0) > (current.index || 0) ? prev : current -// }, -// { progress: 0 } -// ) -// if (maxObj) { -// return parseFloat(maxObj.progress?.toFixed(1)) -// } -// return 0 -// } -// return 0 -// } -// return 0 -// }) const emit = defineEmits(['change', 'update:model']) const eidtModelRef = ref>() let interval: any @@ -148,7 +128,12 @@ const deleteModel = () => { .catch(() => {}) } -const cancelDownload = () => {} +const cancelDownload = () => { + ModelApi.pauseDownload(props.model.id).then(() => { + downModel.value = undefined + emit('change') + }) +} const openEditModel = () => { const provider = props.provider_list.find((p) => p.provider === props.model.provider) if (provider) {