diff --git a/cmd/push.go b/cmd/push.go index da26cb18..9e4271f6 100644 --- a/cmd/push.go +++ b/cmd/push.go @@ -20,11 +20,11 @@ import ( "context" "fmt" - "github.com/modelpack/modctl/pkg/backend" - "github.com/modelpack/modctl/pkg/config" - "github.com/spf13/cobra" "github.com/spf13/viper" + + "github.com/modelpack/modctl/pkg/backend" + "github.com/modelpack/modctl/pkg/config" ) var pushConfig = config.NewPush() @@ -48,10 +48,25 @@ var pushCmd = &cobra.Command{ // init initializes push command. func init() { flags := pushCmd.Flags() - flags.IntVar(&pushConfig.Concurrency, "concurrency", pushConfig.Concurrency, "specify the number of concurrent push operations") + flags.IntVar( + &pushConfig.Concurrency, + "concurrency", + pushConfig.Concurrency, + "specify the number of concurrent push operations", + ) flags.BoolVar(&pushConfig.PlainHTTP, "plain-http", false, "use plain HTTP instead of HTTPS") - flags.BoolVar(&pushConfig.Insecure, "insecure", false, "turning on this flag will disable TLS verification") - flags.BoolVar(&pushConfig.Nydusify, "nydusify", false, "[EXPERIMENTAL] nydusify the model artifact") + flags.BoolVar( + &pushConfig.Insecure, + "insecure", + false, + "turning on this flag will disable TLS verification", + ) + flags.BoolVar( + &pushConfig.Nydusify, + "nydusify", + false, + "[EXPERIMENTAL] nydusify the model artifact", + ) flags.MarkHidden("nydusify") if err := viper.BindPFlags(flags); err != nil { diff --git a/internal/pb/pb.go b/internal/pb/pb.go index 8303a968..3070eb36 100644 --- a/internal/pb/pb.go +++ b/internal/pb/pb.go @@ -136,6 +136,28 @@ func (p *ProgressBar) Add(prompt, name string, size int64, reader io.Reader) io. return reader } +// Placeholder creates or resets a progress bar entry without a reader. +// It is used during retry backoff to keep a visible bar for the item. +func (p *ProgressBar) Placeholder(name string, prompt string, size int64) { + if disableProgress { + return + } + + p.mu.RLock() + existing := p.bars[name] + p.mu.RUnlock() + + // If the bar already exists, just reset its message. + if existing != nil { + existing.msg = fmt.Sprintf("%s %s", prompt, name) + existing.Bar.SetCurrent(0) + return + } + + // Create a new placeholder bar. + p.Add(prompt, name, size, nil) +} + // Get returns the progress bar. func (p *ProgressBar) Get(name string) *progressBar { p.mu.RLock() diff --git a/pkg/backend/retry.go b/pkg/backend/annotation.go similarity index 50% rename from pkg/backend/retry.go rename to pkg/backend/annotation.go index c7494250..60f3d710 100644 --- a/pkg/backend/retry.go +++ b/pkg/backend/annotation.go @@ -17,14 +17,20 @@ package backend import ( - "time" - - retry "github.com/avast/retry-go/v4" + legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" + modelspec "github.com/modelpack/model-spec/specs-go/v1" ) -var defaultRetryOpts = []retry.Option{ - retry.Attempts(6), - retry.DelayType(retry.BackOffDelay), - retry.Delay(5 * time.Second), - retry.MaxDelay(60 * time.Second), +// getAnnotationFilepath returns the filepath stored on a descriptor's +// annotations, preferring the modelpack key and falling back to the legacy +// dragonflyoss key so older artifacts remain readable. Returns empty string +// when neither key is present. +func getAnnotationFilepath(annotations map[string]string) string { + if annotations == nil { + return "" + } + if path := annotations[modelspec.AnnotationFilepath]; path != "" { + return path + } + return annotations[legacymodelspec.AnnotationFilepath] } diff --git a/pkg/backend/build.go b/pkg/backend/build.go index 10f9156c..3241446d 100644 --- a/pkg/backend/build.go +++ b/pkg/backend/build.go @@ -23,7 +23,6 @@ import ( "os" "path/filepath" - retry "github.com/avast/retry-go/v4" modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" @@ -35,6 +34,7 @@ import ( "github.com/modelpack/modctl/pkg/backend/processor" "github.com/modelpack/modctl/pkg/config" "github.com/modelpack/modctl/pkg/modelfile" + "github.com/modelpack/modctl/pkg/retrypolicy" "github.com/modelpack/modctl/pkg/source" ) @@ -44,7 +44,11 @@ const ( ) // Build builds the user materials into the model artifact which follows the Model Spec. -func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target string, cfg *config.Build) error { +func (b *backend) Build( + ctx context.Context, + modelfilePath, workDir, target string, + cfg *config.Build, +) error { logrus.Infof("build: building artifact %s", target) // parse the repo name and tag name from target. ref, err := ParseReference(target) @@ -123,8 +127,8 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri var configDesc ocispec.Descriptor // Build the model config. - if err := retry.Do(func() error { - configDesc, err = builder.BuildConfig(ctx, config, hooks.NewHooks( + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + configDesc, err = builder.BuildConfig(rctx, config, hooks.NewHooks( hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader { return pb.Add(internalpb.NormalizePrompt("Building config"), name, size, reader) }), @@ -136,13 +140,16 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri }), )) return err - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + }, retrypolicy.DoOpts{ + FileSize: 0, // config is small + FileName: "config", + }); err != nil { return fmt.Errorf("failed to build model config: %w", err) } // Build the model manifest. - if err := retry.Do(func() error { - _, err = builder.BuildManifest(ctx, layers, configDesc, manifestAnnotation(modelfile), hooks.NewHooks( + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + _, err = builder.BuildManifest(rctx, layers, configDesc, manifestAnnotation(modelfile), hooks.NewHooks( hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader { return pb.Add(internalpb.NormalizePrompt("Building manifest"), name, size, reader) }), @@ -154,7 +161,10 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri }), )) return err - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + }, retrypolicy.DoOpts{ + FileSize: 0, // manifest is small + FileName: "manifest", + }); err != nil { return fmt.Errorf("failed to build model manifest: %w", err) } @@ -162,7 +172,10 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri return nil } -func (b *backend) getProcessors(modelfile modelfile.Modelfile, cfg *config.Build) []processor.Processor { +func (b *backend) getProcessors( + modelfile modelfile.Modelfile, + cfg *config.Build, +) []processor.Processor { processors := []processor.Processor{} if configs := modelfile.GetConfigs(); len(configs) > 0 { @@ -170,7 +183,10 @@ func (b *backend) getProcessors(modelfile modelfile.Modelfile, cfg *config.Build if cfg.Raw { mediaType = modelspec.MediaTypeModelWeightConfigRaw } - processors = append(processors, processor.NewModelConfigProcessor(b.store, mediaType, configs, "")) + processors = append( + processors, + processor.NewModelConfigProcessor(b.store, mediaType, configs, ""), + ) } if models := modelfile.GetModels(); len(models) > 0 { @@ -201,10 +217,23 @@ func (b *backend) getProcessors(modelfile modelfile.Modelfile, cfg *config.Build } // process walks the user work directory and process the identified files. -func (b *backend) process(ctx context.Context, builder build.Builder, workDir string, pb *internalpb.ProgressBar, cfg *config.Build, processors ...processor.Processor) ([]ocispec.Descriptor, error) { +func (b *backend) process( + ctx context.Context, + builder build.Builder, + workDir string, + pb *internalpb.ProgressBar, + cfg *config.Build, + processors ...processor.Processor, +) ([]ocispec.Descriptor, error) { descriptors := []ocispec.Descriptor{} for _, p := range processors { - descs, err := p.Process(ctx, builder, workDir, processor.WithConcurrency(cfg.Concurrency), processor.WithProgressTracker(pb)) + descs, err := p.Process( + ctx, + builder, + workDir, + processor.WithConcurrency(cfg.Concurrency), + processor.WithProgressTracker(pb), + ) if err != nil { return nil, err } diff --git a/pkg/backend/fetch.go b/pkg/backend/fetch.go index 990d1afa..c6ef90ac 100644 --- a/pkg/backend/fetch.go +++ b/pkg/backend/fetch.go @@ -19,11 +19,12 @@ package backend import ( "context" "encoding/json" + "errors" "fmt" + "sync" + "time" "github.com/bmatcuk/doublestar/v4" - legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" - modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -31,6 +32,7 @@ import ( internalpb "github.com/modelpack/modctl/internal/pb" "github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" ) // Fetch fetches partial files to the output. @@ -74,10 +76,7 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e for _, layer := range manifest.Layers { for _, pattern := range cfg.Patterns { if anno := layer.Annotations; anno != nil { - path := anno[modelspec.AnnotationFilepath] - if path == "" { - path = anno[legacymodelspec.AnnotationFilepath] - } + path := getAnnotationFilepath(anno) // Use doublestar.PathMatch for pattern matching to support ** recursive matching // PathMatch uses the system's native path separator (like filepath.Match) while // also supporting recursive patterns like **/*.json @@ -101,9 +100,12 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e pb.Start() defer pb.Stop() - g, ctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) + var mu sync.Mutex + var errs []error + logrus.Infof("fetch: fetching %d matched layers", len(layers)) for _, layer := range layers { g.Go(func() error { @@ -113,17 +115,37 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e default: } + annoFilepath := getAnnotationFilepath(layer.Annotations) + logrus.Debugf("fetch: processing layer %s", layer.Digest) - if err := pullAndExtractFromRemote(ctx, pb, internalpb.NormalizePrompt("Fetching blob"), client, cfg.Output, layer); err != nil { - return err + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + return pullAndExtractFromRemote(rctx, pb, internalpb.NormalizePrompt("Fetching blob"), client, cfg.Output, layer) + }, retrypolicy.DoOpts{ + FileSize: layer.Size, + FileName: annoFilepath, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + if bar := pb.Get(layer.Digest.String()); bar != nil { + bar.SetRefill(bar.Current()) + bar.SetCurrent(0) + bar.EwmaSetCurrent(0, time.Second) + } + }, + }); err != nil { + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } else { + logrus.Debugf("fetch: successfully processed layer %s", layer.Digest) } - - logrus.Debugf("fetch: successfully processed layer %s", layer.Digest) return nil }) } - if err := g.Wait(); err != nil { + _ = g.Wait() + if ctx.Err() != nil { + return fmt.Errorf("fetch cancelled: %w", ctx.Err()) + } + if err := errors.Join(errs...); err != nil { return err } diff --git a/pkg/backend/fetch_by_d7y.go b/pkg/backend/fetch_by_d7y.go index 098a3a53..a9ff7467 100644 --- a/pkg/backend/fetch_by_d7y.go +++ b/pkg/backend/fetch_by_d7y.go @@ -19,18 +19,18 @@ package backend import ( "context" "encoding/json" + "errors" "fmt" "io" "os" "path/filepath" "strings" + "sync" + "time" common "d7y.io/api/v2/pkg/apis/common/v2" dfdaemon "d7y.io/api/v2/pkg/apis/dfdaemon/v2" - "github.com/avast/retry-go/v4" "github.com/bmatcuk/doublestar/v4" - legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" - modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -41,6 +41,7 @@ import ( "github.com/modelpack/modctl/pkg/archiver" "github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" ) // fetchByDragonfly fetches partial files via Dragonfly gRPC service based on pattern matching. @@ -78,10 +79,7 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf for _, layer := range manifest.Layers { for _, pattern := range cfg.Patterns { if anno := layer.Annotations; anno != nil { - path := anno[modelspec.AnnotationFilepath] - if path == "" { - path = anno[legacymodelspec.AnnotationFilepath] - } + path := getAnnotationFilepath(anno) // Use doublestar.PathMatch for pattern matching to support ** recursive matching // PathMatch uses the system's native path separator (like filepath.Match) while // also supporting recursive patterns like **/*.json @@ -124,9 +122,12 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf defer pb.Stop() // Process layers concurrently. - g, ctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) + var mu sync.Mutex + var errs []error + logrus.Infof("fetch: fetching %d matched layers via dragonfly", len(layers)) for _, layer := range layers { g.Go(func() error { @@ -138,14 +139,21 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf logrus.Debugf("fetch: processing layer %s via dragonfly", layer.Digest) if err := fetchLayerByDragonfly(ctx, pb, dfdaemon.NewDfdaemonDownloadClient(conn), ref, manifest, layer, authToken, cfg); err != nil { - return err + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } else { + logrus.Debugf("fetch: successfully processed layer %s via dragonfly", layer.Digest) } - logrus.Debugf("fetch: successfully processed layer %s via dragonfly", layer.Digest) return nil }) } - if err := g.Wait(); err != nil { + _ = g.Wait() + if ctx.Err() != nil { + return fmt.Errorf("fetch cancelled: %w", ctx.Err()) + } + if err := errors.Join(errs...); err != nil { return err } @@ -155,10 +163,12 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf // fetchLayerByDragonfly handles downloading and extracting a single layer via Dragonfly. func fetchLayerByDragonfly(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Fetch) error { - err := retry.Do(func() error { + annoFilepath := getAnnotationFilepath(desc.Annotations) + + err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("fetch: processing layer %s", desc.Digest) cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook - err := downloadAndExtractFetchLayer(ctx, pb, client, ref, desc, authToken, cfg) + err := downloadAndExtractFetchLayer(rctx, pb, client, ref, desc, authToken, cfg) cfg.Hooks.AfterPullLayer(desc, err) // Call after hook if err != nil { err = fmt.Errorf("pull: failed to download and extract layer %s: %w", desc.Digest, err) @@ -166,7 +176,17 @@ func fetchLayerByDragonfly(ctx context.Context, pb *internalpb.ProgressBar, clie } return err - }, append(defaultRetryOpts, retry.Context(ctx))...) + }, retrypolicy.DoOpts{ + FileSize: desc.Size, + FileName: annoFilepath, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + if bar := pb.Get(desc.Digest.String()); bar != nil { + bar.SetRefill(bar.Current()) + bar.SetCurrent(0) + bar.EwmaSetCurrent(0, time.Second) + } + }, + }) if err != nil { err = fmt.Errorf("fetch: failed to download and extract layer %s: %w", desc.Digest, err) @@ -184,14 +204,7 @@ func downloadAndExtractFetchLayer(ctx context.Context, pb *internalpb.ProgressBa return fmt.Errorf("failed to resolve output dir: %w", err) } - var annoFilepath string - if desc.Annotations != nil { - if desc.Annotations[modelspec.AnnotationFilepath] != "" { - annoFilepath = desc.Annotations[modelspec.AnnotationFilepath] - } else { - annoFilepath = desc.Annotations[legacymodelspec.AnnotationFilepath] - } - } + annoFilepath := getAnnotationFilepath(desc.Annotations) if annoFilepath == "" { return fmt.Errorf("missing annotation filepath") diff --git a/pkg/backend/processor/base.go b/pkg/backend/processor/base.go index 579e61f2..4e002724 100644 --- a/pkg/backend/processor/base.go +++ b/pkg/backend/processor/base.go @@ -18,6 +18,7 @@ package processor import ( "context" + "errors" "fmt" "io" "os" @@ -26,7 +27,6 @@ import ( "strings" "sync" - "github.com/avast/retry-go/v4" legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -36,6 +36,7 @@ import ( internalpb "github.com/modelpack/modctl/internal/pb" "github.com/modelpack/modctl/pkg/backend/build" "github.com/modelpack/modctl/pkg/backend/build/hooks" + "github.com/modelpack/modctl/pkg/retrypolicy" "github.com/modelpack/modctl/pkg/storage" ) @@ -105,14 +106,11 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin var ( mu sync.Mutex - eg *errgroup.Group descriptors []ocispec.Descriptor + errs []error ) - // Initialize errgroup with a context can be canceled. - ctx, cancel := context.WithCancel(ctx) - defer cancel() - eg, ctx = errgroup.WithContext(ctx) + eg := new(errgroup.Group) // Set default concurrency limit to 1 if not specified. if processOpts.concurrency > 0 { @@ -141,7 +139,13 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin default: } - if err := retry.Do(func() error { + // Get file size for dynamic retry parameters. + var fileSize int64 + if fi, err := os.Stat(path); err == nil { + fileSize = fi.Size() + } + + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("processor: processing %s file %s", b.name, path) var destPath string @@ -149,7 +153,7 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin destPath = filepath.Join(b.destDir, filepath.Base(path)) } - desc, err := builder.BuildLayer(ctx, b.mediaType, workDir, path, destPath, hooks.NewHooks( + desc, err := builder.BuildLayer(rctx, b.mediaType, workDir, path, destPath, hooks.NewHooks( hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader { return tracker.Add(internalpb.NormalizePrompt("Building layer"), name, size, reader) }), @@ -170,19 +174,33 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin mu.Unlock() return nil - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + }, retrypolicy.DoOpts{ + FileSize: fileSize, + FileName: filepath.Base(path), + }); err != nil { logrus.Error(err) - // Cancel manually to abort other tasks because if one fails, - // we should abort all to avoid useless waiting. - cancel() - return err + mu.Lock() + errs = append(errs, err) + mu.Unlock() } return nil }) } - if err := eg.Wait(); err != nil { + if werr := eg.Wait(); werr != nil { + // Surface cancellation from skipped workers so a partially-built + // artifact cannot be emitted as if it were complete. + mu.Lock() + errs = append(errs, werr) + mu.Unlock() + } + + if ctx.Err() != nil { + return nil, fmt.Errorf("processing cancelled: %w", ctx.Err()) + } + + if err := errors.Join(errs...); err != nil { return nil, err } diff --git a/pkg/backend/processor/options.go b/pkg/backend/processor/options.go index 0126410e..4558f514 100644 --- a/pkg/backend/processor/options.go +++ b/pkg/backend/processor/options.go @@ -17,10 +17,6 @@ package processor import ( - "time" - - retry "github.com/avast/retry-go/v4" - "github.com/modelpack/modctl/internal/pb" ) @@ -44,10 +40,3 @@ func WithProgressTracker(tracker *pb.ProgressBar) ProcessOption { o.progressTracker = tracker } } - -var defaultRetryOpts = []retry.Option{ - retry.Attempts(6), - retry.DelayType(retry.BackOffDelay), - retry.Delay(5 * time.Second), - retry.MaxDelay(60 * time.Second), -} diff --git a/pkg/backend/pull.go b/pkg/backend/pull.go index 3cb6a1b0..768e732b 100644 --- a/pkg/backend/pull.go +++ b/pkg/backend/pull.go @@ -22,8 +22,9 @@ import ( "errors" "fmt" "io" + "sync" + "time" - retry "github.com/avast/retry-go/v4" sha256 "github.com/minio/sha256-simd" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" @@ -33,6 +34,7 @@ import ( "github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/codec" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" "github.com/modelpack/modctl/pkg/storage" ) @@ -90,17 +92,20 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err // copy the layers. dst := b.store - g, gctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) - var fn func(desc ocispec.Descriptor) error + var mu sync.Mutex + var errs []error + + var fn func(ctx context.Context, desc ocispec.Descriptor) error if cfg.ExtractFromRemote { - fn = func(desc ocispec.Descriptor) error { - return pullAndExtractFromRemote(gctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, cfg.ExtractDir, desc) + fn = func(ctx context.Context, desc ocispec.Descriptor) error { + return pullAndExtractFromRemote(ctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, cfg.ExtractDir, desc) } } else { - fn = func(desc ocispec.Descriptor) error { - return pullIfNotExist(gctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, dst, desc, repo, tag) + fn = func(ctx context.Context, desc ocispec.Descriptor) error { + return pullIfNotExist(ctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, dst, desc, repo, tag) } } @@ -108,16 +113,16 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err for _, layer := range manifest.Layers { g.Go(func() error { select { - case <-gctx.Done(): - return gctx.Err() + case <-ctx.Done(): + return ctx.Err() default: } - return retry.Do(func() error { + retryErr := retrypolicy.Do(ctx, func(retryCtx context.Context) error { logrus.Debugf("pull: processing layer %s", layer.Digest) // call the before hook. cfg.Hooks.BeforePullLayer(layer, manifest) - err := fn(layer) + err := fn(retryCtx, layer) // call the after hook. cfg.Hooks.AfterPullLayer(layer, err) if err != nil { @@ -126,12 +131,35 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err } return err - }, append(defaultRetryOpts, retry.Context(gctx))...) + }, retrypolicy.DoOpts{ + FileSize: layer.Size, + FileName: layer.Digest.String(), + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + pb.Placeholder(layer.Digest.String(), internalpb.NormalizePrompt("Pulling blob"), layer.Size) + }, + }) + if retryErr != nil { + mu.Lock() + errs = append(errs, retryErr) + mu.Unlock() + } + + return nil }) } - if err := g.Wait(); err != nil { - return fmt.Errorf("failed to pull blob to local: %w", err) + if werr := g.Wait(); werr != nil { + // Surface cancellation from worker goroutines so a cancelled batch + // never slips through as an apparently successful pull. + mu.Lock() + errs = append(errs, werr) + mu.Unlock() + } + if ctx.Err() != nil { + return fmt.Errorf("pull cancelled: %w", ctx.Err()) + } + if len(errs) > 0 { + return fmt.Errorf("failed to pull blob to local: %w", errors.Join(errs...)) } logrus.Infof("pull: layers pulled [count: %d]", len(manifest.Layers)) @@ -143,16 +171,28 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err } // copy the config. - if err := retry.Do(func() error { - return pullIfNotExist(ctx, pb, internalpb.NormalizePrompt("Pulling config"), src, dst, manifest.Config, repo, tag) - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + if err := retrypolicy.Do(ctx, func(retryCtx context.Context) error { + return pullIfNotExist(retryCtx, pb, internalpb.NormalizePrompt("Pulling config"), src, dst, manifest.Config, repo, tag) + }, retrypolicy.DoOpts{ + FileSize: manifest.Config.Size, + FileName: "config", + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + pb.Placeholder(manifest.Config.Digest.String(), internalpb.NormalizePrompt("Pulling config"), manifest.Config.Size) + }, + }); err != nil { return fmt.Errorf("failed to pull config to local: %w", err) } // copy the manifest. - if err := retry.Do(func() error { - return pullIfNotExist(ctx, pb, internalpb.NormalizePrompt("Pulling manifest"), src, dst, manifestDesc, repo, tag) - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + if err := retrypolicy.Do(ctx, func(retryCtx context.Context) error { + return pullIfNotExist(retryCtx, pb, internalpb.NormalizePrompt("Pulling manifest"), src, dst, manifestDesc, repo, tag) + }, retrypolicy.DoOpts{ + FileSize: manifestDesc.Size, + FileName: "manifest", + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + pb.Placeholder(manifestDesc.Digest.String(), internalpb.NormalizePrompt("Pulling manifest"), manifestDesc.Size) + }, + }); err != nil { return fmt.Errorf("failed to pull manifest to local: %w", err) } diff --git a/pkg/backend/pull_by_d7y.go b/pkg/backend/pull_by_d7y.go index a7ca3b47..170a60c6 100644 --- a/pkg/backend/pull_by_d7y.go +++ b/pkg/backend/pull_by_d7y.go @@ -19,17 +19,17 @@ package backend import ( "context" "encoding/json" + "errors" "fmt" "io" "os" "path/filepath" "strings" + "sync" + "time" common "d7y.io/api/v2/pkg/apis/common/v2" dfdaemon "d7y.io/api/v2/pkg/apis/dfdaemon/v2" - "github.com/avast/retry-go/v4" - legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" - modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -41,6 +41,7 @@ import ( "github.com/modelpack/modctl/pkg/archiver" "github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" ) const ( @@ -100,9 +101,12 @@ func (b *backend) pullByDragonfly(ctx context.Context, target string, cfg *confi defer pb.Stop() // Process layers concurrently. - g, ctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) + var mu sync.Mutex + var errs []error + logrus.Infof("pull: pulling %d layers via dragonfly", len(manifest.Layers)) for _, layer := range manifest.Layers { g.Go(func() error { @@ -114,14 +118,21 @@ func (b *backend) pullByDragonfly(ctx context.Context, target string, cfg *confi logrus.Debugf("pull: processing layer %s via dragonfly", layer.Digest) if err := processLayer(ctx, pb, dfdaemon.NewDfdaemonDownloadClient(conn), ref, manifest, layer, authToken, cfg); err != nil { - return err + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } else { + logrus.Debugf("pull: successfully processed layer %s via dragonfly", layer.Digest) } - logrus.Debugf("pull: successfully processed layer %s via dragonfly", layer.Digest) return nil }) } - if err := g.Wait(); err != nil { + _ = g.Wait() + if ctx.Err() != nil { + return fmt.Errorf("pull cancelled: %w", ctx.Err()) + } + if err := errors.Join(errs...); err != nil { return err } @@ -179,10 +190,12 @@ func buildBlobURL(ref Referencer, plainHTTP bool, digest string) string { // processLayer handles downloading and extracting a single layer. func processLayer(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Pull) error { - err := retry.Do(func() error { + annoFilepath := getAnnotationFilepath(desc.Annotations) + + err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("pull: processing layer %s", desc.Digest) cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook - err := downloadAndExtractLayer(ctx, pb, client, ref, desc, authToken, cfg) + err := downloadAndExtractLayer(rctx, pb, client, ref, desc, authToken, cfg) cfg.Hooks.AfterPullLayer(desc, err) // Call after hook if err != nil { err = fmt.Errorf("pull: failed to download and extract layer %s: %w", desc.Digest, err) @@ -190,7 +203,17 @@ func processLayer(ctx context.Context, pb *internalpb.ProgressBar, client dfdaem } return err - }, append(defaultRetryOpts, retry.Context(ctx))...) + }, retrypolicy.DoOpts{ + FileSize: desc.Size, + FileName: annoFilepath, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + if bar := pb.Get(desc.Digest.String()); bar != nil { + bar.SetRefill(bar.Current()) + bar.SetCurrent(0) + bar.EwmaSetCurrent(0, time.Second) + } + }, + }) return err } @@ -203,14 +226,7 @@ func downloadAndExtractLayer(ctx context.Context, pb *internalpb.ProgressBar, cl return fmt.Errorf("failed to resolve extract dir: %w", err) } - var annoFilepath string - if desc.Annotations != nil { - if desc.Annotations[modelspec.AnnotationFilepath] != "" { - annoFilepath = desc.Annotations[modelspec.AnnotationFilepath] - } else { - annoFilepath = desc.Annotations[legacymodelspec.AnnotationFilepath] - } - } + annoFilepath := getAnnotationFilepath(desc.Annotations) if annoFilepath == "" { return fmt.Errorf("missing annotation filepath") diff --git a/pkg/backend/push.go b/pkg/backend/push.go index 2674cee2..c7d1c2fa 100644 --- a/pkg/backend/push.go +++ b/pkg/backend/push.go @@ -20,10 +20,12 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "sync" + "time" - retry "github.com/avast/retry-go/v4" godigest "github.com/opencontainers/go-digest" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" @@ -32,6 +34,7 @@ import ( internalpb "github.com/modelpack/modctl/internal/pb" "github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" "github.com/modelpack/modctl/pkg/storage" ) @@ -77,49 +80,92 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err // note: the order is important, manifest should be pushed at last. // copy the layers. - g, gctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) + var mu sync.Mutex + var errs []error logrus.Infof("push: pushing %d layers for %s", len(manifest.Layers), target) for _, layer := range manifest.Layers { g.Go(func() error { select { - case <-gctx.Done(): - return gctx.Err() + case <-ctx.Done(): + return ctx.Err() default: } - return retry.Do(func() error { + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("push: processing layer %s", layer.Digest) - if err := pushIfNotExist(gctx, pb, internalpb.NormalizePrompt("Copying blob"), src, dst, layer, repo, tag); err != nil { + if err := pushIfNotExist(rctx, pb, internalpb.NormalizePrompt("Copying blob"), src, dst, layer, repo, tag); err != nil { return err } logrus.Debugf("push: successfully processed layer %s", layer.Digest) return nil - }, append(defaultRetryOpts, retry.Context(gctx))...) + }, retrypolicy.DoOpts{ + FileSize: layer.Size, + FileName: layer.Digest.String(), + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", + internalpb.NormalizePrompt("Copying blob"), attempt, reason, backoff.Truncate(time.Second)) + pb.Placeholder(layer.Digest.String(), prompt, layer.Size) + }, + }); err != nil { + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } + return nil // never return error to errgroup }) } - if err := g.Wait(); err != nil { - return fmt.Errorf("failed to push blob to remote: %w", err) + if werr := g.Wait(); werr != nil { + // Surface cancellation returned from worker goroutines so we never + // fall through to the config/manifest push with an incomplete set + // of layer uploads. + mu.Lock() + errs = append(errs, werr) + mu.Unlock() + } + if ctx.Err() != nil { + return fmt.Errorf("push cancelled: %w", ctx.Err()) + } + if len(errs) > 0 { + return fmt.Errorf("failed to push blob to remote: %w", errors.Join(errs...)) } // copy the config. - if err := retry.Do(func() error { - return pushIfNotExist(ctx, pb, internalpb.NormalizePrompt("Copying config"), src, dst, manifest.Config, repo, tag) - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + return pushIfNotExist(rctx, pb, internalpb.NormalizePrompt("Copying config"), src, dst, manifest.Config, repo, tag) + }, retrypolicy.DoOpts{ + FileSize: manifest.Config.Size, + FileName: "config", + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", + internalpb.NormalizePrompt("Copying config"), attempt, reason, backoff.Truncate(time.Second)) + pb.Placeholder(manifest.Config.Digest.String(), prompt, manifest.Config.Size) + }, + }); err != nil { return fmt.Errorf("failed to push config to remote: %w", err) } // copy the manifest. - if err := retry.Do(func() error { - return pushIfNotExist(ctx, pb, internalpb.NormalizePrompt("Copying manifest"), src, dst, ocispec.Descriptor{ - MediaType: manifest.MediaType, - Size: int64(len(manifestRaw)), - Digest: godigest.FromBytes(manifestRaw), - Data: manifestRaw, - }, repo, tag) - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + manifestDesc := ocispec.Descriptor{ + MediaType: manifest.MediaType, + Size: int64(len(manifestRaw)), + Digest: godigest.FromBytes(manifestRaw), + Data: manifestRaw, + } + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + return pushIfNotExist(rctx, pb, internalpb.NormalizePrompt("Copying manifest"), src, dst, manifestDesc, repo, tag) + }, retrypolicy.DoOpts{ + FileSize: manifestDesc.Size, + FileName: "manifest", + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", + internalpb.NormalizePrompt("Copying manifest"), attempt, reason, backoff.Truncate(time.Second)) + pb.Placeholder(manifestDesc.Digest.String(), prompt, manifestDesc.Size) + }, + }); err != nil { return fmt.Errorf("failed to push manifest to remote: %w", err) } diff --git a/pkg/backend/retry_test.go b/pkg/backend/retry_test.go deleted file mode 100644 index f4069ac1..00000000 --- a/pkg/backend/retry_test.go +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright 2025 The CNAI Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package backend - -import ( - "context" - "errors" - "testing" - - retry "github.com/avast/retry-go/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// testRetryOpts creates retry options with zero delay so tests run fast and deterministically. -func testRetryOpts(ctx context.Context, maxAttempts uint) []retry.Option { - return []retry.Option{ - retry.Attempts(maxAttempts), - retry.Delay(0), - retry.MaxDelay(0), - retry.Context(ctx), - } -} - -func TestRetrySuccessAfterFailures(t *testing.T) { - ctx := context.Background() - attempts := 0 - - err := retry.Do(func() error { - attempts++ - if attempts < 3 { - return errors.New("temporary failure") - } - return nil - }, testRetryOpts(ctx, 6)...) - - require.NoError(t, err) - assert.Equal(t, 3, attempts) -} - -func TestRetryMaxAttemptsExceeded(t *testing.T) { - ctx := context.Background() - attempts := 0 - maxAttempts := uint(4) - - err := retry.Do(func() error { - attempts++ - return errors.New("persistent failure") - }, testRetryOpts(ctx, maxAttempts)...) - - assert.Error(t, err) - assert.Equal(t, int(maxAttempts), attempts) -} - -func TestRetryStopsOnContextCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - attempts := 0 - - err := retry.Do(func() error { - attempts++ - if attempts == 2 { - cancel() - } - return errors.New("temporary failure") - }, testRetryOpts(ctx, 10)...) - - assert.ErrorIs(t, err, context.Canceled) - assert.Equal(t, 2, attempts) -} - -func TestRetrySucceedsOnFirstAttempt(t *testing.T) { - ctx := context.Background() - attempts := 0 - - err := retry.Do(func() error { - attempts++ - return nil - }, testRetryOpts(ctx, 6)...) - - require.NoError(t, err) - assert.Equal(t, 1, attempts) -} diff --git a/pkg/config/build.go b/pkg/config/build.go index 27dd23d5..d3ec828c 100644 --- a/pkg/config/build.go +++ b/pkg/config/build.go @@ -16,7 +16,9 @@ package config -import "fmt" +import ( + "fmt" +) const ( // defaultBuildConcurrency is the default number of concurrent builds. diff --git a/pkg/config/push.go b/pkg/config/push.go index c5fa8e8f..dc75b199 100644 --- a/pkg/config/push.go +++ b/pkg/config/push.go @@ -16,7 +16,9 @@ package config -import "fmt" +import ( + "fmt" +) const ( // defaultPushConcurrency is the default number of concurrent push operations. diff --git a/pkg/retrypolicy/retrypolicy.go b/pkg/retrypolicy/retrypolicy.go new file mode 100644 index 00000000..9908795d --- /dev/null +++ b/pkg/retrypolicy/retrypolicy.go @@ -0,0 +1,386 @@ +/* + * Copyright 2024 The ModelPack Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package retrypolicy provides retry behavior for blob transfer operations. +// +// The package decouples two timing concerns that are commonly conflated: +// +// - Per-attempt timeout: how long a single transfer attempt may take. +// This scales with file size, since larger files need longer to transfer. +// +// - Retry policy: how many attempts to make and how long to wait between +// them. This is bounded by attempt count and per-sleep cap; it does not +// scale with file size, because transient-failure recovery time is +// independent of payload size. +// +// Earlier designs used a single wall-clock budget covering both transfer +// time and retry waits, which made retries scarce precisely when networks +// were slow. The current design fixes that by giving each concern its own +// knob. +package retrypolicy + +import ( + "context" + "errors" + "fmt" + "math" + "regexp" + "strings" + "time" + + retry "github.com/avast/retry-go/v4" + log "github.com/sirupsen/logrus" +) + +const ( + // DefaultMaxAttempts is the total number of attempts (initial + retries). + DefaultMaxAttempts = 6 + + // DefaultInitialDelay is the first sleep between attempts. + DefaultInitialDelay = 5 * time.Second + + // DefaultMaxBackoff caps a single sleep between attempts. It does not + // scale with file size: transient outages have a payload-independent + // duration distribution. + DefaultMaxBackoff = 2 * time.Minute + + // DefaultMaxJitter is the upper bound on randomized jitter added to each + // sleep. + DefaultMaxJitter = 5 * time.Second + + // minThroughput is the assumed worst-case usable throughput when + // computing per-attempt timeouts. Networks slower than this are out of + // scope; users on such links should set Config.PerAttemptTimeout + // explicitly. + minThroughput = 10 * (1 << 20) // 10 MiB/s + + // safetyFactor multiplies the ideal transfer time when sizing + // per-attempt timeouts, leaving headroom for protocol overhead and + // short-lived speed dips. + safetyFactor = 2 + + // minPerAttemptTimeout is the floor for derived per-attempt timeouts — + // small blobs (manifests, configs) still need enough time for TLS, + // auth, and slow-start. + minPerAttemptTimeout = 5 * time.Minute + + // maxPerAttemptTimeout is the ceiling. Very large files (e.g. 100GB+ + // LLM shards) will still hit this cap, after which the user should + // override via Config.PerAttemptTimeout. + maxPerAttemptTimeout = 8 * time.Hour +) + +// Config holds user-configurable retry parameters from CLI flags. +// +// The zero value is valid and yields production defaults: 6 attempts, with +// per-attempt timeout derived from file size and exponential backoff up to +// DefaultMaxBackoff. +type Config struct { + // MaxAttempts is the total number of attempts (initial + retries). + // 0 means "use DefaultMaxAttempts". Set to 1 to disable retries + // (single attempt, no retry on failure). + MaxAttempts int + + // PerAttemptTimeout is the maximum duration for a single attempt. + // 0 → derive from file size (see ComputePerAttemptTimeout). + // <0 → no per-attempt timeout (caller fully controls deadlines). + // >0 → use this value verbatim. + PerAttemptTimeout time.Duration + + // InitialDelay overrides the first inter-attempt sleep. 0 = default. + // Primarily for tests. + InitialDelay time.Duration + + // MaxBackoff overrides the per-sleep cap. 0 = default. + MaxBackoff time.Duration + + // MaxJitter: -1 = no jitter, 0 = default, >0 = override. For tests. + MaxJitter time.Duration +} + +// DoOpts configures a single Do call. +type DoOpts struct { + // FileSize sizes the per-attempt timeout when Config.PerAttemptTimeout + // is unset. May be 0 for non-blob operations (manifest, config); the + // timeout will then clamp to minPerAttemptTimeout. + FileSize int64 + + // FileName is logged on each retry. + FileName string + + // Config is the user-supplied policy. nil means defaults. + Config *Config + + // OnRetry is invoked before each sleep (after a failed attempt) so + // callers can update progress UI. attempt is 1-based; reason is a + // short label from ShortReason. + OnRetry func(attempt uint, reason string, backoff time.Duration) +} + +// Do executes fn with retry. Each attempt runs under its own deadline +// derived from PerAttemptTimeout (or file size); retries are bounded by +// MaxAttempts and exponential backoff capped at MaxBackoff. The parent ctx +// is honored for user-initiated cancellation only — its expiry is not +// coupled to per-attempt transfer time. +func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) error { + cfg := opts.Config + if cfg == nil { + cfg = &Config{} + } + + perAttemptTimeout := cfg.PerAttemptTimeout + switch { + case perAttemptTimeout == 0: + perAttemptTimeout = ComputePerAttemptTimeout(opts.FileSize) + case perAttemptTimeout < 0: + perAttemptTimeout = 0 // disabled + } + + // runAttempt applies the per-attempt deadline. retry-go calls this + // for each attempt; if MaxAttempts == 1 the loop exits after one + // invocation (equivalent to "no retry"). + runAttempt := func() error { + attemptCtx := ctx + if perAttemptTimeout > 0 { + var cancel context.CancelFunc + attemptCtx, cancel = context.WithTimeout(ctx, perAttemptTimeout) + defer cancel() + } + return fn(attemptCtx) + } + + maxAttempts := cfg.MaxAttempts + if maxAttempts <= 0 { + maxAttempts = DefaultMaxAttempts + } + + initialDelay := cfg.InitialDelay + if initialDelay <= 0 { + initialDelay = DefaultInitialDelay + } + + maxBackoff := cfg.MaxBackoff + if maxBackoff <= 0 { + maxBackoff = DefaultMaxBackoff + } + + jitter := DefaultMaxJitter + if cfg.MaxJitter < 0 { + jitter = 0 + } else if cfg.MaxJitter > 0 { + jitter = cfg.MaxJitter + } + + sizeStr := humanizeBytes(opts.FileSize) + startTime := time.Now() + + return retry.Do( + runAttempt, + retry.Attempts(uint(maxAttempts)), + retry.Context(ctx), + retry.DelayType(retry.BackOffDelay), + retry.Delay(initialDelay), + retry.MaxDelay(maxBackoff), + retry.MaxJitter(jitter), + retry.LastErrorOnly(true), + retry.RetryIf(func(err error) bool { + // Per-attempt timeout fired but parent ctx is alive: this is a + // transient transfer timeout, not a user cancellation. Retry. + if errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil { + return true + } + retryable := IsRetryable(err) + if !retryable { + log.WithFields(log.Fields{ + "file": opts.FileName, + "size": sizeStr, + "error": err.Error(), + }).Error("[RETRY] non-retryable error, not retrying") + } + return retryable + }), + retry.OnRetry(func(n uint, err error) { + // retry-go calls OnRetry with n = 0-based retry index. Convert + // to 1-based for both the log and the user callback. + attempt := n + 1 + backoff := computeBackoff(attempt, initialDelay, maxBackoff) + elapsed := time.Since(startTime) + + log.WithFields(log.Fields{ + "file": opts.FileName, + "size": sizeStr, + "error": err.Error(), + "max_attempts": maxAttempts, + "max_backoff": maxBackoff.String(), + "per_attempt_to": perAttemptTimeout.String(), + "next_retry_in": backoff.Truncate(time.Second).String(), + "elapsed": elapsed.Truncate(time.Second).String(), + }).Warnf("[RETRY] attempt %d/%d for %q (%s)", attempt, maxAttempts, opts.FileName, sizeStr) + + if opts.OnRetry != nil { + reason := ShortReason(err) + opts.OnRetry(attempt, reason, backoff) + } + }), + ) +} + +// ComputePerAttemptTimeout estimates a single-attempt transfer deadline from +// file size, assuming minThroughput as a worst-case-but-usable rate and +// applying safetyFactor for headroom. The result is clamped to +// [minPerAttemptTimeout, maxPerAttemptTimeout]. +// +// Examples (rounded): +// +// 1 GB → 5 min (floor) +// 10 GB → 34 min +// 70 GB → ~4 h +// 140 GB → ~8 h (ceiling) +// +// fileSize <= 0 returns minPerAttemptTimeout. +func ComputePerAttemptTimeout(fileSize int64) time.Duration { + if fileSize <= 0 { + return minPerAttemptTimeout + } + secs := float64(fileSize) / float64(minThroughput) * safetyFactor + t := time.Duration(secs * float64(time.Second)) + if t < minPerAttemptTimeout { + return minPerAttemptTimeout + } + if t > maxPerAttemptTimeout { + return maxPerAttemptTimeout + } + return t +} + +// computeBackoff estimates the backoff duration for display purposes. +// It mirrors retry-go's exponential schedule (without jitter). +// attempt is 1-based: the first sleep (after attempt 1 fails) is +// initialDelay, the second is 2*initialDelay, capped at maxDelay. +func computeBackoff(attempt uint, initial, maxDelay time.Duration) time.Duration { + if attempt == 0 { + return initial + } + backoff := time.Duration(float64(initial) * math.Pow(2, float64(attempt-1))) + if backoff > maxDelay { + backoff = maxDelay + } + return backoff +} + +// httpStatusPattern matches ORAS-style error messages that embed HTTP status codes. +var httpStatusPattern = regexp.MustCompile(`response status code (\d{3})`) + +// IsRetryable returns true for transient errors that warrant a retry. +// +// context.Canceled and bare context.DeadlineExceeded are not retryable here: +// the Do loop independently re-classifies a per-attempt timeout while the +// parent context is still alive as retryable, so this function only sees +// genuine cancellation. +func IsRetryable(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + + errMsg := err.Error() + + if matches := httpStatusPattern.FindStringSubmatch(errMsg); len(matches) == 2 { + code := matches[1] + if code[0] == '5' { + return true + } + if code == "408" || code == "429" { + return true + } + return false + } + + if strings.Contains(errMsg, "i/o timeout") || + strings.Contains(errMsg, "connection reset by peer") || + strings.Contains(errMsg, "connection refused") || + strings.Contains(errMsg, "broken pipe") || + strings.Contains(errMsg, "EOF") { + return true + } + + if strings.Contains(errMsg, "permission denied") || + strings.Contains(errMsg, "no space left on device") || + strings.Contains(errMsg, "file exists") || + strings.Contains(errMsg, "not a directory") || + strings.Contains(errMsg, "is a directory") || + strings.Contains(errMsg, "no such file or directory") || + strings.Contains(errMsg, "invalid argument") { + return false + } + + log.WithField("error", errMsg).Warn("[RETRY] unknown error treated as retryable") + return true +} + +// ShortReason extracts a brief human-readable label from an error for +// progress bar display. +func ShortReason(err error) string { + if err == nil { + return "" + } + + errMsg := err.Error() + + if matches := httpStatusPattern.FindStringSubmatch(errMsg); len(matches) == 2 { + return "HTTP " + matches[1] + } + + switch { + case strings.Contains(errMsg, "i/o timeout"): + return "i/o timeout" + case strings.Contains(errMsg, "connection reset by peer"): + return "conn reset" + case strings.Contains(errMsg, "connection refused"): + return "conn refused" + case strings.Contains(errMsg, "broken pipe"): + return "broken pipe" + case strings.Contains(errMsg, "EOF"): + return "EOF" + case errors.Is(err, context.DeadlineExceeded): + return "attempt timeout" + } + + return "unknown error" +} + +// humanizeBytes converts a byte count to a human-readable string. +func humanizeBytes(b int64) string { + const ( + kb = 1024 + mb = 1024 * kb + gb = 1024 * mb + ) + + switch { + case b >= gb: + return fmt.Sprintf("%.1f GB", float64(b)/float64(gb)) + case b >= mb: + return fmt.Sprintf("%.1f MB", float64(b)/float64(mb)) + case b >= kb: + return fmt.Sprintf("%.1f KB", float64(b)/float64(kb)) + default: + return fmt.Sprintf("%d B", b) + } +} diff --git a/pkg/retrypolicy/retrypolicy_test.go b/pkg/retrypolicy/retrypolicy_test.go new file mode 100644 index 00000000..b2a4ba7b --- /dev/null +++ b/pkg/retrypolicy/retrypolicy_test.go @@ -0,0 +1,525 @@ +/* + * Copyright 2024 The ModelPack Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retrypolicy + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "testing" + "time" +) + +// --- ComputePerAttemptTimeout --- + +func TestComputePerAttemptTimeout(t *testing.T) { + const ( + oneMB = int64(1) << 20 + oneGB = int64(1) << 30 + tenGB = int64(10) << 30 + hundGB = int64(100) << 30 + ) + tests := []struct { + name string + size int64 + want time.Duration + }{ + {"zero size clamps to floor", 0, minPerAttemptTimeout}, + {"100 MB clamps to floor", 100 * oneMB, minPerAttemptTimeout}, + {"1 GB clamps to floor", oneGB, minPerAttemptTimeout}, + // 5 GB / 10 MB/s * 2 = 1024s ≈ 17 min, above the 5min floor + { + "5 GB scales above floor", + 5 * oneGB, + time.Duration(5*oneGB/minThroughput*safetyFactor) * time.Second, + }, + // 10 GB / 10 MB/s * 2 = 2048s ≈ 34 min + { + "10 GB scales linearly", + tenGB, + time.Duration(tenGB/minThroughput*safetyFactor) * time.Second, + }, + // 100 GB / 10 MB/s * 2 = 20480s ≈ 5.7h, still under the 8h ceiling + { + "100 GB still under ceiling", + hundGB, + time.Duration(hundGB/minThroughput*safetyFactor) * time.Second, + }, + // 200 GB hits ceiling + {"200 GB clamps to ceiling", 2 * hundGB, maxPerAttemptTimeout}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ComputePerAttemptTimeout(tt.size) + if got != tt.want { + t.Errorf("ComputePerAttemptTimeout(%d) = %v, want %v", tt.size, got, tt.want) + } + }) + } +} + +// --- IsRetryable --- + +func TestIsRetryable(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"context.Canceled", context.Canceled, false}, + {"context.DeadlineExceeded", context.DeadlineExceeded, false}, + {"5xx server error", errors.New("response status code 500"), true}, + { + "503 service unavailable", + errors.New("response status code 503: Service Unavailable"), + true, + }, + {"408 request timeout", errors.New("response status code 408"), true}, + {"429 too many requests", errors.New("response status code 429"), true}, + {"401 unauthorized", errors.New("response status code 401"), false}, + {"403 forbidden", errors.New("response status code 403"), false}, + {"404 not found", errors.New("response status code 404"), false}, + {"i/o timeout", errors.New("dial tcp: i/o timeout"), true}, + {"connection reset", errors.New("read tcp: connection reset by peer"), true}, + {"connection refused", errors.New("dial tcp: connection refused"), true}, + {"broken pipe", errors.New("write tcp: broken pipe"), true}, + {"EOF", errors.New("unexpected EOF"), true}, + {"permission denied", errors.New("open /etc/foo: permission denied"), false}, + {"no space left", errors.New("write /tmp/x: no space left on device"), false}, + {"file exists", errors.New("link /a /b: file exists"), false}, + {"no such file", errors.New("open /no/such: no such file or directory"), false}, + {"unknown defaults retryable", errors.New("some weird error"), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsRetryable(tt.err); got != tt.want { + t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +// --- ShortReason --- + +func TestShortReason(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + {"nil", nil, ""}, + {"5xx", errors.New("response status code 503"), "HTTP 503"}, + {"i/o timeout", errors.New("dial tcp: i/o timeout"), "i/o timeout"}, + {"conn reset", errors.New("read tcp: connection reset by peer"), "conn reset"}, + {"conn refused", errors.New("dial: connection refused"), "conn refused"}, + {"broken pipe", errors.New("write: broken pipe"), "broken pipe"}, + {"EOF", errors.New("unexpected EOF"), "EOF"}, + {"DeadlineExceeded", context.DeadlineExceeded, "attempt timeout"}, + {"unknown", errors.New("totally unrelated"), "unknown error"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ShortReason(tt.err); got != tt.want { + t.Errorf("ShortReason(%v) = %q, want %q", tt.err, got, tt.want) + } + }) + } +} + +// --- computeBackoff --- + +func TestComputeBackoff(t *testing.T) { + const initial = 1 * time.Second + const cap_ = 10 * time.Second + tests := []struct { + attempt uint + want time.Duration + }{ + {1, 1 * time.Second}, // first sleep + {2, 2 * time.Second}, // doubled + {3, 4 * time.Second}, + {4, 8 * time.Second}, + {5, 10 * time.Second}, // capped + {20, 10 * time.Second}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("attempt=%d", tt.attempt), func(t *testing.T) { + got := computeBackoff(tt.attempt, initial, cap_) + if got != tt.want { + t.Errorf("computeBackoff(%d) = %v, want %v", tt.attempt, got, tt.want) + } + }) + } +} + +// --- Do: defaults & success path --- + +func TestDo_SuccessFirstAttempt(t *testing.T) { + calls := 0 + err := Do(context.Background(), func(ctx context.Context) error { + calls++ + return nil + }, DoOpts{FileName: "ok"}) + if err != nil { + t.Fatalf("Do returned %v", err) + } + if calls != 1 { + t.Errorf("calls = %d, want 1", calls) + } +} + +// --- Do: MaxAttempts=1 is the canonical "no retry" knob --- + +func TestDo_MaxAttempts1(t *testing.T) { + calls := 0 + transient := errors.New("response status code 503") + err := Do(context.Background(), func(ctx context.Context) error { + calls++ + return transient + }, DoOpts{ + FileName: "single-attempt", + Config: &Config{MaxAttempts: 1}, + }) + if err == nil { + t.Fatal("Do returned nil, want transient error returned verbatim") + } + if calls != 1 { + t.Errorf("calls = %d, want 1 (MaxAttempts=1)", calls) + } +} + +// --- Do: single attempt still honors per-attempt timeout --- +// +// Even when retries are disabled (MaxAttempts=1), a hung transfer must still +// terminate; otherwise users get a stalled CLI with no failure signal. +func TestDo_SingleAttemptHonorsPerAttemptTimeout(t *testing.T) { + var calls int32 + start := time.Now() + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + <-ctx.Done() + return ctx.Err() + }, DoOpts{ + FileName: "single-but-bounded", + Config: &Config{ + MaxAttempts: 1, + PerAttemptTimeout: 30 * time.Millisecond, + }, + }) + elapsed := time.Since(start) + if err == nil { + t.Fatal("Do returned nil, want context deadline error") + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("calls = %d, want 1", got) + } + if elapsed > 500*time.Millisecond { + t.Errorf("Do hung for %v, want quick exit via per-attempt timeout", elapsed) + } +} + +// --- Do: MaxAttempts caps the number of tries --- + +func TestDo_MaxAttempts(t *testing.T) { + var calls int32 + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + return errors.New("response status code 503") + }, DoOpts{ + FileName: "always-fails", + Config: &Config{ + MaxAttempts: 3, + InitialDelay: 1 * time.Millisecond, + MaxBackoff: 1 * time.Millisecond, + MaxJitter: -1, + }, + }) + if err == nil { + t.Fatal("Do returned nil, want error after attempts exhausted") + } + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("calls = %d, want 3", got) + } +} + +// --- Do: non-retryable error stops immediately --- + +func TestDo_NonRetryableStopsImmediately(t *testing.T) { + var calls int32 + permErr := errors.New("permission denied") + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + return permErr + }, DoOpts{ + FileName: "perm", + Config: &Config{ + MaxAttempts: 5, + InitialDelay: 1 * time.Millisecond, + MaxJitter: -1, + }, + }) + if err == nil { + t.Fatal("Do returned nil, want non-retryable error") + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("calls = %d, want 1 (non-retryable)", got) + } +} + +// --- Do: per-attempt timeout fires & retries continue --- + +func TestDo_PerAttemptTimeoutTriggersRetry(t *testing.T) { + var calls int32 + const succeededOn int32 = 3 + err := Do(context.Background(), func(ctx context.Context) error { + n := atomic.AddInt32(&calls, 1) + if n < succeededOn { + <-ctx.Done() + return ctx.Err() + } + return nil + }, DoOpts{ + FileName: "slow", + Config: &Config{ + MaxAttempts: 5, + PerAttemptTimeout: 30 * time.Millisecond, + InitialDelay: 1 * time.Millisecond, + MaxBackoff: 1 * time.Millisecond, + MaxJitter: -1, + }, + }) + if err != nil { + t.Fatalf("Do returned %v, want success after retries", err) + } + if got := atomic.LoadInt32(&calls); got != succeededOn { + t.Errorf("calls = %d, want %d", got, succeededOn) + } +} + +// --- Do: per-attempt timeout exhausts after MaxAttempts --- + +func TestDo_PerAttemptTimeoutExhausts(t *testing.T) { + var calls int32 + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + <-ctx.Done() + return ctx.Err() + }, DoOpts{ + FileName: "always-times-out", + Config: &Config{ + MaxAttempts: 3, + PerAttemptTimeout: 20 * time.Millisecond, + InitialDelay: 1 * time.Millisecond, + MaxBackoff: 1 * time.Millisecond, + MaxJitter: -1, + }, + }) + if err == nil { + t.Fatal("Do returned nil, want error after exhausting attempts") + } + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("calls = %d, want 3", got) + } +} + +// --- Do: parent context cancellation aborts retries --- + +func TestDo_ParentContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var calls int32 + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + err := Do(ctx, func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + return errors.New("response status code 503") + }, DoOpts{ + FileName: "user-cancels", + Config: &Config{ + MaxAttempts: 100, + InitialDelay: 5 * time.Millisecond, + MaxBackoff: 5 * time.Millisecond, + MaxJitter: -1, + }, + }) + if err == nil { + t.Fatal("Do returned nil, want context cancellation error") + } + if got := atomic.LoadInt32(&calls); got > 50 { + t.Errorf( + "calls = %d, want significantly fewer than MaxAttempts (parent ctx cancelled)", + got, + ) + } +} + +// --- Do: OnRetry callback invoked with 1-based attempt --- + +func TestDo_OnRetryCallback(t *testing.T) { + var attempts []uint + var reasons []string + var calls int32 + err := Do(context.Background(), func(ctx context.Context) error { + n := atomic.AddInt32(&calls, 1) + if n < 3 { + return errors.New("response status code 500") + } + return nil + }, DoOpts{ + FileName: "cb", + Config: &Config{ + MaxAttempts: 5, + InitialDelay: 1 * time.Millisecond, + MaxBackoff: 1 * time.Millisecond, + MaxJitter: -1, + }, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + attempts = append(attempts, attempt) + reasons = append(reasons, reason) + }, + }) + if err != nil { + t.Fatalf("Do returned %v", err) + } + want := []uint{1, 2} + if fmt.Sprintf("%v", attempts) != fmt.Sprintf("%v", want) { + t.Errorf("attempts = %v, want %v", attempts, want) + } + for _, r := range reasons { + if r != "HTTP 500" { + t.Errorf("reason = %q, want \"HTTP 500\"", r) + } + } +} + +// --- Do: retry budget invariance — same total time regardless of file size --- +// +// The whole point of the redesign: with PerAttemptTimeout disabled and a +// fixed backoff schedule, total wall-clock for retries does not depend on +// FileSize. (The old API's MaxRetryTime scaled with size and changed this.) +func TestDo_RetryBudgetIndependentOfFileSize(t *testing.T) { + measure := func(fileSize int64) time.Duration { + var calls int32 + start := time.Now() + _ = Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + return errors.New("response status code 503") + }, DoOpts{ + FileName: "size-invariance", + FileSize: fileSize, + Config: &Config{ + MaxAttempts: 4, + PerAttemptTimeout: -1, + InitialDelay: 10 * time.Millisecond, + MaxBackoff: 10 * time.Millisecond, + MaxJitter: -1, + }, + }) + return time.Since(start) + } + + small := measure(1 << 20) // 1 MB + huge := measure(int64(1) << 40) // 1 TB + diff := small - huge + if diff < 0 { + diff = -diff + } + if diff > 100*time.Millisecond { + t.Errorf( + "retry budget varied with file size: small=%v huge=%v (diff=%v)", + small, + huge, + diff, + ) + } +} + +// --- Do: PerAttemptTimeout < 0 disables the deadline --- + +func TestDo_PerAttemptTimeoutDisabled(t *testing.T) { + var seenDeadline bool + err := Do(context.Background(), func(ctx context.Context) error { + _, ok := ctx.Deadline() + seenDeadline = ok + return nil + }, DoOpts{ + FileName: "no-deadline", + FileSize: 1 << 30, + Config: &Config{ + PerAttemptTimeout: -1, + }, + }) + if err != nil { + t.Fatalf("Do returned %v", err) + } + if seenDeadline { + t.Error("attempt context had a deadline; expected none when PerAttemptTimeout < 0") + } +} + +// --- Do: PerAttemptTimeout = 0 derives from file size --- + +func TestDo_PerAttemptTimeoutDerivedFromSize(t *testing.T) { + var seenTimeout time.Duration + const size = int64(50) << 30 // 50 GB → > floor, < ceiling + err := Do(context.Background(), func(ctx context.Context) error { + dl, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline") + } + seenTimeout = time.Until(dl) + return nil + }, DoOpts{ + FileName: "derived", + FileSize: size, + }) + if err != nil { + t.Fatalf("Do returned %v", err) + } + want := ComputePerAttemptTimeout(size) + if seenTimeout > want || seenTimeout < want-time.Second { + t.Errorf("derived timeout = %v, want ~%v", seenTimeout, want) + } +} + +// --- humanizeBytes --- + +func TestHumanizeBytes(t *testing.T) { + tests := []struct { + size int64 + want string + }{ + {0, "0 B"}, + {500, "500 B"}, + {2048, "2.0 KB"}, + {int64(5) << 20, "5.0 MB"}, + {int64(7) << 30, "7.0 GB"}, + } + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := humanizeBytes(tt.size); got != tt.want { + t.Errorf("humanizeBytes(%d) = %q, want %q", tt.size, got, tt.want) + } + }) + } +}