diff --git a/pkg/modules/image.go b/pkg/modules/image.go index a4b1d306..5776ad9e 100644 --- a/pkg/modules/image.go +++ b/pkg/modules/image.go @@ -22,6 +22,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "github.com/kubesphere/kubekey/v4/pkg/utils" "io" "io/fs" "net/http" @@ -188,37 +189,73 @@ type manifestInfo struct { // pull retrieves images from a remote registry and stores them locally func (i imagePullArgs) pull(ctx context.Context, platform []string) error { - for _, img := range i.manifests { - img = normalizeImageNameSimple(img) - src, err := remote.NewRepository(img) - if err != nil { - return errors.Wrapf(err, "failed to get remote image %s", img) - } - src.Client = &auth.Client{ - Client: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: skipTLSVerifyFunc(img, i.auths, *i.skipTLSVerify), - }, - }, - }, - Cache: auth.NewCache(), - Credential: authFunc(i.auths), - } - dst, err := newLocalRepository(filepath.Join(src.Reference.Registry, src.Reference.Repository)+":"+src.Reference.Reference, i.imagesDir) - if err != nil { - return err - } - src.PlainHTTP = plainHTTPFunc(img, i.auths, false) - err = imageSrcToDst(ctx, src, dst, img, platform) - if err != nil { - return err - } + manifests := i.manifests + if len(manifests) == 0 { + return nil + } + + maxWorkers := 10 + + // 创建任务队列 + tasks := make(chan string, len(manifests)) + + worker := utils.Worker[string]{ + MaxWorkerCount: maxWorkers, + TaskChan: tasks, + ExecFunc: func(img string) error { + return i.downloadSingleImage(ctx, img, platform) + }, + } + + worker.Do(ctx) + + // 发送任务 + for _, img := range manifests { + tasks <- img + } + close(tasks) + + // 等待所有 worker 完成 + go func() { + worker.Wait() + }() + + // 收集结果 + var collectedErrors = worker.CollectedErrors() + + if len(collectedErrors) > 0 { + return fmt.Errorf("download errors: %v", strings.Join(collectedErrors, "; ")) } return nil } +func (i imagePullArgs) downloadSingleImage(ctx context.Context, img string, platform []string) error { + img = normalizeImageNameSimple(img) + src, err := remote.NewRepository(img) + if err != nil { + return errors.Wrapf(err, "failed to get remote image %s", img) + } + src.Client = &auth.Client{ + Client: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: skipTLSVerifyFunc(img, i.auths, *i.skipTLSVerify), + }, + }, + }, + Cache: auth.NewCache(), + Credential: authFunc(i.auths), + } + dst, err := newLocalRepository(filepath.Join(src.Reference.Registry, src.Reference.Repository)+":"+src.Reference.Reference, i.imagesDir) + if err != nil { + return err + } + src.PlainHTTP = plainHTTPFunc(img, i.auths, false) + + return imageSrcToDst(ctx, src, dst, img, platform) +} + func imageSrcToDst(ctx context.Context, src, dst *remote.Repository, img string, platform []string) error { var err error if len(platform) == 0 || (len(platform) == 1 && strings.TrimSpace(platform[0]) == "*") { @@ -1003,28 +1040,28 @@ func (i imageTransport) put(request *http.Request) *http.Response { filename := filepath.Join(i.baseDir, "blobs", request.URL.Query().Get("digest")) if err := os.MkdirAll(filepath.Dir(filename), os.ModePerm); err != nil { - klog.V(4).ErrorS(err, "failed to create dir", "dir", filepath.Dir(filename)) + fmt.Println(err, "failed to create dir", "dir", filepath.Dir(filename)) return responseServerError } file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) if err != nil { - klog.V(4).ErrorS(err, "failed to create file", "filename", filename) + fmt.Println(err, "failed to create file", "filename", filename) return responseServerError } defer func() { if err = file.Sync(); err != nil { - klog.V(4).ErrorS(err, "failed to sync file", "filename", filename) + fmt.Println(err, "failed to sync file", "filename", filename) } if err = file.Close(); err != nil { - klog.V(4).ErrorS(err, "failed to close file", "filename", filename) + fmt.Println(err, "failed to close file", "filename", filename) } }() if _, err = io.Copy(file, request.Body); err != nil { - klog.V(4).ErrorS(err, "failed to write file", "filename", filename) + fmt.Println(err, "failed to write file", "filename", filename) return responseServerError } @@ -1033,7 +1070,7 @@ func (i imageTransport) put(request *http.Request) *http.Response { } else if strings.HasSuffix(filepath.Dir(request.URL.Path), "/manifests") { // manifests body, err := io.ReadAll(request.Body) if err != nil { - klog.V(4).ErrorS(err, "failed to read request") + fmt.Println(err, "failed to read request") return responseServerError } @@ -1041,13 +1078,13 @@ func (i imageTransport) put(request *http.Request) *http.Response { filename := filepath.Join(i.baseDir, request.Host, strings.TrimPrefix(request.URL.Path, apiPrefix)) if err := os.MkdirAll(filepath.Dir(filename), os.ModePerm); err != nil { - klog.V(4).ErrorS(err, "failed to create dir", "dir", filepath.Dir(filename)) + fmt.Println(err, "failed to create dir", "dir", filepath.Dir(filename)) return responseServerError } if err := os.WriteFile(filename, body, os.ModePerm); err != nil { - klog.V(4).ErrorS(err, "failed to write file", "filename", filename) + fmt.Println(err, "failed to write file", "filename", filename) return responseServerError } diff --git a/pkg/utils/worker_pool.go b/pkg/utils/worker_pool.go new file mode 100644 index 00000000..a55308d6 --- /dev/null +++ b/pkg/utils/worker_pool.go @@ -0,0 +1,48 @@ +package utils + +import ( + "context" + "sync" +) + +type Worker[T any] struct { + MaxWorkerCount int + TaskChan chan T + ExecFunc func(T) error + wg sync.WaitGroup + results chan error +} + +func (worker *Worker[T]) Do(ctx context.Context) { + worker.results = make(chan error) + for w := 0; w < worker.MaxWorkerCount; w++ { + worker.wg.Add(1) + go func(workerID int) { + defer worker.wg.Done() + for thing := range worker.TaskChan { + select { + case <-ctx.Done(): + worker.results <- ctx.Err() + return + default: + worker.results <- worker.ExecFunc(thing) + } + } + }(w) + } +} + +func (worker *Worker[T]) Wait() { + worker.wg.Wait() + close(worker.results) +} + +func (worker *Worker[T]) CollectedErrors() []string { + var collectedErrors []string + for err := range worker.results { + if err != nil { + collectedErrors = append(collectedErrors, err.Error()) + } + } + return collectedErrors +}