From 4e658584ed1ada9d67452a665df75fe237a1e1ca Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Mon, 23 Mar 2026 18:15:17 +0800 Subject: [PATCH 01/16] feat(retrypolicy): add dynamic retry package with file-size-based parameters Signed-off-by: Zhao Chen --- pkg/retrypolicy/retrypolicy.go | 276 +++++++++++++++ pkg/retrypolicy/retrypolicy_test.go | 520 ++++++++++++++++++++++++++++ 2 files changed, 796 insertions(+) create mode 100644 pkg/retrypolicy/retrypolicy.go create mode 100644 pkg/retrypolicy/retrypolicy_test.go diff --git a/pkg/retrypolicy/retrypolicy.go b/pkg/retrypolicy/retrypolicy.go new file mode 100644 index 00000000..b7761712 --- /dev/null +++ b/pkg/retrypolicy/retrypolicy.go @@ -0,0 +1,276 @@ +/* + * Copyright 2024 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retrypolicy + +import ( + "context" + "errors" + "fmt" + "math" + "regexp" + "strings" + "time" + + retry "github.com/avast/retry-go/v4" + log "github.com/sirupsen/logrus" +) + +const ( + oneGB = 1 << 30 // 1 GiB in bytes + tenGB = 10 << 30 + nineGB = tenGB - oneGB + + minMaxRetryTime = 10 * time.Minute + maxMaxRetryTime = 60 * time.Minute + + minMaxBackoff = 1 * time.Minute + maxMaxBackoff = 10 * time.Minute + + initialDelay = 5 * time.Second + maxJitter = 5 * time.Second + + // maxBackoff cap when derived from user-specified MaxRetryTime + absoluteMaxBackoff = 10 * time.Minute +) + +// Config holds user-configurable retry parameters from CLI flags. 
+type Config struct { + MaxRetryTime time.Duration // 0 = dynamic based on file size + NoRetry bool // disable retry entirely +} + +// DoOpts configures a single Do call. +type DoOpts struct { + FileSize int64 // for dynamic parameter calculation + FileName string // for logging + Config *Config + OnRetry func(attempt uint, reason string, backoff time.Duration) +} + +// Do executes fn with retry. It computes dynamic retry parameters from fileSize, +// creates an internal deadline context (and defers its cancel to prevent leak), +// sets up retry logging, and calls retry.Do. +// The parent ctx is only used for user-initiated cancellation. +func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) error { + cfg := opts.Config + if cfg == nil { + cfg = &Config{} + } + + // NoRetry: call fn once with the parent context, return the result. + if cfg.NoRetry { + return fn(ctx) + } + + maxRetryTime, maxBackoff := computeDynamicParams(opts.FileSize) + + // Override with user-specified MaxRetryTime if set. 
+ if cfg.MaxRetryTime > 0 { + maxRetryTime = cfg.MaxRetryTime + maxBackoff = cfg.MaxRetryTime / 6 + if maxBackoff > absoluteMaxBackoff { + maxBackoff = absoluteMaxBackoff + } + } + + startTime := time.Now() + deadlineCtx, deadlineCancel := context.WithDeadline(ctx, startTime.Add(maxRetryTime)) + defer deadlineCancel() + + sizeStr := humanizeBytes(opts.FileSize) + + return retry.Do( + func() error { + return fn(deadlineCtx) + }, + retry.Attempts(0), + retry.Context(deadlineCtx), + retry.DelayType(retry.BackOffDelay), + retry.Delay(initialDelay), + retry.MaxDelay(maxBackoff), + retry.MaxJitter(maxJitter), + retry.LastErrorOnly(true), + retry.WrapContextErrorWithLastError(true), + retry.RetryIf(func(err error) bool { + retryable := IsRetryable(err) + if !retryable { + log.WithFields(log.Fields{ + "file": opts.FileName, + "size": sizeStr, + "error": err.Error(), + }).Error("[RETRY] non-retryable error, not retrying") + } + return retryable + }), + retry.OnRetry(func(n uint, err error) { + backoff := computeBackoff(n+1, initialDelay, maxBackoff) + elapsed := time.Since(startTime) + + log.WithFields(log.Fields{ + "file": opts.FileName, + "size": sizeStr, + "error": err.Error(), + "max_retry_time": maxRetryTime.String(), + "max_backoff": maxBackoff.String(), + "next_retry_in": backoff.Truncate(time.Second).String(), + "elapsed": fmt.Sprintf("%s / %s", elapsed.Truncate(time.Second), maxRetryTime), + }).Warnf("[RETRY] attempt %d for %q (%s)", n+1, opts.FileName, sizeStr) + + if opts.OnRetry != nil { + reason := ShortReason(err) + opts.OnRetry(n+1, reason, backoff) + } + }), + ) +} + +// computeDynamicParams calculates maxRetryTime and maxBackoff based on file size. +// +// For files <= 1 GB: maxRetryTime=10min, maxBackoff=1min +// For files >= 10 GB: maxRetryTime=60min, maxBackoff=10min +// Linear interpolation between. 
+func computeDynamicParams(fileSize int64) (time.Duration, time.Duration) { + ratio := float64(fileSize-oneGB) / float64(nineGB) + if ratio < 0 { + ratio = 0 + } + if ratio > 1 { + ratio = 1 + } + + maxRetryTime := minMaxRetryTime + time.Duration(ratio*float64(maxMaxRetryTime-minMaxRetryTime)) + maxBackoff := minMaxBackoff + time.Duration(ratio*float64(maxMaxBackoff-minMaxBackoff)) + + return maxRetryTime, maxBackoff +} + +// computeBackoff estimates the backoff duration for display purposes. +// It mirrors the exponential backoff calculation without jitter. +func computeBackoff(attempt uint, initial, maxDelay time.Duration) time.Duration { + if attempt == 0 { + return initial + } + backoff := time.Duration(float64(initial) * math.Pow(2, float64(attempt-1))) + if backoff > maxDelay { + backoff = maxDelay + } + return backoff +} + +// httpStatusPattern matches ORAS-style error messages that embed HTTP status codes. +var httpStatusPattern = regexp.MustCompile(`response status code (\d{3})`) + +// IsRetryable returns true for transient errors that warrant a retry. +func IsRetryable(err error) bool { + if err == nil { + return false + } + + // context.Canceled is never retryable — it means user/system cancellation. + if errors.Is(err, context.Canceled) { + return false + } + + errMsg := err.Error() + + // Check for HTTP status codes embedded in error messages (ORAS style). + if matches := httpStatusPattern.FindStringSubmatch(errMsg); len(matches) == 2 { + code := matches[1] + // 5xx server errors are retryable. + if code[0] == '5' { + return true + } + // 408 (Request Timeout) and 429 (Too Many Requests) are retryable. + if code == "408" || code == "429" { + return true + } + // Other 4xx are not retryable (401, 403, 404, etc.) + return false + } + + // Network-level transient errors. 
+ if strings.Contains(errMsg, "i/o timeout") { + return true + } + if strings.Contains(errMsg, "connection reset by peer") { + return true + } + if strings.Contains(errMsg, "connection refused") { + return true + } + if strings.Contains(errMsg, "broken pipe") { + return true + } + if strings.Contains(errMsg, "EOF") { + return true + } + + // Unknown errors default to retryable. + return true +} + +// ShortReason extracts a brief human-readable label from an error for progress bar display. +func ShortReason(err error) string { + if err == nil { + return "" + } + + errMsg := err.Error() + + // Check for HTTP status codes. + if matches := httpStatusPattern.FindStringSubmatch(errMsg); len(matches) == 2 { + return "HTTP " + matches[1] + } + + if strings.Contains(errMsg, "i/o timeout") { + return "i/o timeout" + } + if strings.Contains(errMsg, "connection reset by peer") { + return "conn reset" + } + if strings.Contains(errMsg, "connection refused") { + return "conn refused" + } + if strings.Contains(errMsg, "broken pipe") { + return "broken pipe" + } + if strings.Contains(errMsg, "EOF") { + return "EOF" + } + + return "unknown error" +} + +// humanizeBytes converts a byte count to a human-readable string. +func humanizeBytes(b int64) string { + const ( + kb = 1024 + mb = 1024 * kb + gb = 1024 * mb + ) + + switch { + case b >= gb: + return fmt.Sprintf("%.1f GB", float64(b)/float64(gb)) + case b >= mb: + return fmt.Sprintf("%.1f MB", float64(b)/float64(mb)) + case b >= kb: + return fmt.Sprintf("%.1f KB", float64(b)/float64(kb)) + default: + return fmt.Sprintf("%d B", b) + } +} diff --git a/pkg/retrypolicy/retrypolicy_test.go b/pkg/retrypolicy/retrypolicy_test.go new file mode 100644 index 00000000..3e29226f --- /dev/null +++ b/pkg/retrypolicy/retrypolicy_test.go @@ -0,0 +1,520 @@ +/* + * Copyright 2024 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retrypolicy + +import ( + "context" + "errors" + "fmt" + "net" + "sync/atomic" + "testing" + "time" +) + +// --- helpers for tests --- + +// timeoutError implements net.Error with Timeout() returning true. +type timeoutError struct { + msg string +} + +func (e *timeoutError) Error() string { return e.msg } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +// --- computeDynamicParams tests --- + +func TestComputeDynamicParams(t *testing.T) { + tests := []struct { + name string + fileSize int64 + wantRetryTime time.Duration + wantMaxBackoff time.Duration + }{ + { + name: "zero bytes - clamped to minimum", + fileSize: 0, + wantRetryTime: 10 * time.Minute, + wantMaxBackoff: 1 * time.Minute, + }, + { + name: "500 MB - below 1 GB, clamped to minimum", + fileSize: 500 * 1024 * 1024, + wantRetryTime: 10 * time.Minute, + wantMaxBackoff: 1 * time.Minute, + }, + { + name: "1 GB - boundary, ratio=0, minimum values", + fileSize: 1 << 30, + wantRetryTime: 10 * time.Minute, + wantMaxBackoff: 1 * time.Minute, + }, + { + name: "5.5 GB - midpoint, interpolated values", + fileSize: int64(5.5 * float64(1<<30)), + wantRetryTime: 35 * time.Minute, + wantMaxBackoff: 5*time.Minute + 30*time.Second, + }, + { + name: "10 GB - boundary, ratio=1, maximum values", + fileSize: 10 << 30, + wantRetryTime: 60 * time.Minute, + wantMaxBackoff: 10 * time.Minute, + }, + { + name: "20 GB - above 10 GB, clamped to maximum", + fileSize: 20 << 30, + wantRetryTime: 60 * time.Minute, + wantMaxBackoff: 10 * 
time.Minute, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotRetryTime, gotMaxBackoff := computeDynamicParams(tt.fileSize) + + // Allow 1 second tolerance for floating point interpolation. + retryTimeDiff := absDuration(gotRetryTime - tt.wantRetryTime) + if retryTimeDiff > time.Second { + t.Errorf("maxRetryTime = %v, want %v (diff %v)", gotRetryTime, tt.wantRetryTime, retryTimeDiff) + } + + backoffDiff := absDuration(gotMaxBackoff - tt.wantMaxBackoff) + if backoffDiff > time.Second { + t.Errorf("maxBackoff = %v, want %v (diff %v)", gotMaxBackoff, tt.wantMaxBackoff, backoffDiff) + } + }) + } +} + +func absDuration(d time.Duration) time.Duration { + if d < 0 { + return -d + } + return d +} + +// --- IsRetryable tests --- + +func TestIsRetryable(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "nil error", + err: nil, + want: false, + }, + { + name: "context.Canceled", + err: context.Canceled, + want: false, + }, + { + name: "wrapped context.Canceled", + err: fmt.Errorf("operation failed: %w", context.Canceled), + want: false, + }, + { + name: "HTTP 500 server error", + err: fmt.Errorf("PUT /blobs/uploads: response status code 500: internal server error"), + want: true, + }, + { + name: "HTTP 502 bad gateway", + err: fmt.Errorf("response status code 502: bad gateway"), + want: true, + }, + { + name: "HTTP 503 service unavailable", + err: fmt.Errorf("response status code 503: service unavailable"), + want: true, + }, + { + name: "HTTP 408 request timeout", + err: fmt.Errorf("response status code 408: request timeout"), + want: true, + }, + { + name: "HTTP 429 too many requests", + err: fmt.Errorf("response status code 429: too many requests"), + want: true, + }, + { + name: "HTTP 401 unauthorized - not retryable", + err: fmt.Errorf("response status code 401: unauthorized"), + want: false, + }, + { + name: "HTTP 403 forbidden - not retryable", + err: fmt.Errorf("response status code 403: 
access denied"), + want: false, + }, + { + name: "HTTP 404 not found - not retryable", + err: fmt.Errorf("response status code 404: not found"), + want: false, + }, + { + name: "i/o timeout", + err: &net.OpError{ + Op: "read", + Net: "tcp", + Err: &timeoutError{msg: "i/o timeout"}, + }, + want: true, + }, + { + name: "i/o timeout in wrapped error message", + err: fmt.Errorf("read tcp 10.0.0.1:1234->10.0.0.2:443: i/o timeout"), + want: true, + }, + { + name: "connection reset by peer", + err: fmt.Errorf("read tcp: connection reset by peer"), + want: true, + }, + { + name: "connection refused", + err: fmt.Errorf("dial tcp 10.0.0.1:443: connection refused"), + want: true, + }, + { + name: "broken pipe", + err: fmt.Errorf("write tcp: broken pipe"), + want: true, + }, + { + name: "EOF", + err: fmt.Errorf("unexpected EOF"), + want: true, + }, + { + name: "unknown error - defaults to retryable", + err: errors.New("something totally unexpected happened"), + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsRetryable(tt.err) + if got != tt.want { + t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +// --- ShortReason tests --- + +func TestShortReason(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + { + name: "nil error", + err: nil, + want: "", + }, + { + name: "HTTP 500", + err: fmt.Errorf("response status code 500: internal server error"), + want: "HTTP 500", + }, + { + name: "HTTP 502", + err: fmt.Errorf("response status code 502: bad gateway"), + want: "HTTP 502", + }, + { + name: "HTTP 429", + err: fmt.Errorf("response status code 429: too many requests"), + want: "HTTP 429", + }, + { + name: "HTTP 408", + err: fmt.Errorf("response status code 408: request timeout"), + want: "HTTP 408", + }, + { + name: "i/o timeout", + err: fmt.Errorf("read tcp: i/o timeout"), + want: "i/o timeout", + }, + { + name: "connection reset", + err: fmt.Errorf("read tcp: 
connection reset by peer"), + want: "conn reset", + }, + { + name: "connection refused", + err: fmt.Errorf("dial tcp: connection refused"), + want: "conn refused", + }, + { + name: "broken pipe", + err: fmt.Errorf("write tcp: broken pipe"), + want: "broken pipe", + }, + { + name: "EOF", + err: fmt.Errorf("unexpected EOF"), + want: "EOF", + }, + { + name: "unknown error", + err: errors.New("some weird error"), + want: "unknown error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ShortReason(tt.err) + if got != tt.want { + t.Errorf("ShortReason(%v) = %q, want %q", tt.err, got, tt.want) + } + }) + } +} + +// --- Do tests --- + +func TestDo_SuccessFirstAttempt(t *testing.T) { + callCount := int32(0) + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&callCount, 1) + return nil + }, DoOpts{ + FileSize: 100, + FileName: "test.bin", + }) + + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if atomic.LoadInt32(&callCount) != 1 { + t.Fatalf("expected fn to be called once, got %d", callCount) + } +} + +func TestDo_RetryOnTransientError(t *testing.T) { + callCount := int32(0) + err := Do(context.Background(), func(ctx context.Context) error { + n := atomic.AddInt32(&callCount, 1) + if n < 3 { + return fmt.Errorf("response status code 500: internal server error") + } + return nil + }, DoOpts{ + FileSize: 100, + FileName: "test.bin", + }) + + if err != nil { + t.Fatalf("expected nil error after retries, got %v", err) + } + if atomic.LoadInt32(&callCount) != 3 { + t.Fatalf("expected fn to be called 3 times, got %d", callCount) + } +} + +func TestDo_NoRetryOnPermanentError(t *testing.T) { + callCount := int32(0) + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&callCount, 1) + return fmt.Errorf("response status code 404: not found") + }, DoOpts{ + FileSize: 100, + FileName: "test.bin", + }) + + if err == nil { + t.Fatal("expected error, got nil") + 
} + if atomic.LoadInt32(&callCount) != 1 { + t.Fatalf("expected fn to be called once for non-retryable error, got %d", callCount) + } +} + +func TestDo_NoRetryConfig(t *testing.T) { + callCount := int32(0) + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&callCount, 1) + return fmt.Errorf("response status code 500: internal server error") + }, DoOpts{ + FileSize: 100, + FileName: "test.bin", + Config: &Config{NoRetry: true}, + }) + + if err == nil { + t.Fatal("expected error, got nil") + } + if atomic.LoadInt32(&callCount) != 1 { + t.Fatalf("expected fn to be called once with NoRetry, got %d", callCount) + } +} + +func TestDo_ParentContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + callCount := int32(0) + err := Do(ctx, func(ctx context.Context) error { + n := atomic.AddInt32(&callCount, 1) + if n == 1 { + // Cancel the parent context after the first attempt. + cancel() + return fmt.Errorf("response status code 500: internal server error") + } + return nil + }, DoOpts{ + FileSize: 100, + FileName: "test.bin", + }) + + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } +} + +func TestDo_OnRetryCallback(t *testing.T) { + var retryAttempts []uint + var retryReasons []string + + callCount := int32(0) + err := Do(context.Background(), func(ctx context.Context) error { + n := atomic.AddInt32(&callCount, 1) + if n < 3 { + return fmt.Errorf("response status code 500: internal server error") + } + return nil + }, DoOpts{ + FileSize: 100, + FileName: "test.bin", + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + retryAttempts = append(retryAttempts, attempt) + retryReasons = append(retryReasons, reason) + }, + }) + + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if len(retryAttempts) != 2 { + t.Fatalf("expected 2 OnRetry calls, got %d", len(retryAttempts)) + } + if retryAttempts[0] != 1 || retryAttempts[1] != 2 { + 
t.Errorf("expected attempts [1, 2], got %v", retryAttempts) + } + if retryReasons[0] != "HTTP 500" || retryReasons[1] != "HTTP 500" { + t.Errorf("expected reasons [HTTP 500, HTTP 500], got %v", retryReasons) + } +} + +func TestDo_ConfigMaxRetryTimeOverride(t *testing.T) { + // Use a very short MaxRetryTime to ensure the retry loop terminates quickly. + callCount := int32(0) + start := time.Now() + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&callCount, 1) + return fmt.Errorf("response status code 500: internal server error") + }, DoOpts{ + FileSize: 100, + FileName: "test.bin", + Config: &Config{ + MaxRetryTime: 8 * time.Second, + }, + }) + + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected error after retry timeout, got nil") + } + // Should have run for approximately MaxRetryTime (8s), not the dynamic default (10min). + if elapsed > 30*time.Second { + t.Errorf("expected retry to terminate within ~8s, but elapsed %v", elapsed) + } + if atomic.LoadInt32(&callCount) < 2 { + t.Errorf("expected at least 2 attempts, got %d", callCount) + } +} + +// --- humanizeBytes tests --- + +func TestHumanizeBytes(t *testing.T) { + tests := []struct { + input int64 + want string + }{ + {0, "0 B"}, + {512, "512 B"}, + {1024, "1.0 KB"}, + {1536, "1.5 KB"}, + {1048576, "1.0 MB"}, + {1073741824, "1.0 GB"}, + {int64(5.5 * float64(1<<30)), "5.5 GB"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := humanizeBytes(tt.input) + if got != tt.want { + t.Errorf("humanizeBytes(%d) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +// --- computeBackoff tests --- + +func TestComputeBackoff(t *testing.T) { + initial := 5 * time.Second + maxDelay := 1 * time.Minute + + // attempt 0 => initial + if got := computeBackoff(0, initial, maxDelay); got != initial { + t.Errorf("attempt 0: got %v, want %v", got, initial) + } + + // attempt 1 => 5s * 2^0 = 5s + if got := computeBackoff(1, initial, maxDelay); 
got != 5*time.Second { + t.Errorf("attempt 1: got %v, want %v", got, 5*time.Second) + } + + // attempt 2 => 5s * 2^1 = 10s + if got := computeBackoff(2, initial, maxDelay); got != 10*time.Second { + t.Errorf("attempt 2: got %v, want %v", got, 10*time.Second) + } + + // attempt 4 => 5s * 2^3 = 40s + if got := computeBackoff(4, initial, maxDelay); got != 40*time.Second { + t.Errorf("attempt 4: got %v, want %v", got, 40*time.Second) + } + + // attempt 5 => 5s * 2^4 = 80s, but capped at 60s + if got := computeBackoff(5, initial, maxDelay); got != maxDelay { + t.Errorf("attempt 5: got %v, want %v (capped)", got, maxDelay) + } +} From 92c9ca3abcba98f88ad88004482303745b0fdfde Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Mon, 23 Mar 2026 18:21:21 +0800 Subject: [PATCH 02/16] feat(cli): add --no-retry and --retry-max-time flags to push/pull/build/fetch Signed-off-by: Zhao Chen --- cmd/build.go | 2 ++ cmd/fetch.go | 2 ++ cmd/pull.go | 2 ++ cmd/push.go | 2 ++ pkg/config/build.go | 7 ++++++- pkg/config/fetch.go | 3 +++ pkg/config/pull.go | 3 +++ pkg/config/push.go | 7 ++++++- 8 files changed, 26 insertions(+), 2 deletions(-) diff --git a/cmd/build.go b/cmd/build.go index c41c0628..e4284fee 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -61,6 +61,8 @@ func init() { flags.BoolVar(&buildConfig.Raw, "raw", true, "turning on this flag will build model artifact layers in raw format") flags.BoolVar(&buildConfig.Reasoning, "reasoning", false, "turning on this flag will mark this model as reasoning model in the config") flags.BoolVar(&buildConfig.NoCreationTime, "no-creation-time", false, "turning on this flag will not set createdAt in the config, which will be helpful for repeated builds") + flags.BoolVar(&buildConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") + flags.DurationVar(&buildConfig.RetryConfig.MaxRetryTime, "retry-max-time", 0, "Max total retry time per file (0 = dynamic based on file size)") if err := viper.BindPFlags(flags); err != nil 
{ panic(fmt.Errorf("bind build flags to viper: %w", err)) diff --git a/cmd/fetch.go b/cmd/fetch.go index e13de379..3a2cf661 100644 --- a/cmd/fetch.go +++ b/cmd/fetch.go @@ -55,6 +55,8 @@ func init() { flags.StringVar(&fetchConfig.Output, "output", "", "specify the directory for fetching the model artifact") flags.StringSliceVar(&fetchConfig.Patterns, "patterns", []string{}, "specify the patterns for fetching the model artifact") flags.StringVar(&fetchConfig.DragonflyEndpoint, "dragonfly-endpoint", "", "specify the dragonfly endpoint for the pull operation, which will download and hardlink the blob by dragonfly GRPC service.") + flags.BoolVar(&fetchConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") + flags.DurationVar(&fetchConfig.RetryConfig.MaxRetryTime, "retry-max-time", 0, "Max total retry time per file (0 = dynamic based on file size)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind fetch flags to viper: %w", err)) diff --git a/cmd/pull.go b/cmd/pull.go index f4df1fa9..f6912c9d 100644 --- a/cmd/pull.go +++ b/cmd/pull.go @@ -55,6 +55,8 @@ func init() { flags.StringVar(&pullConfig.ExtractDir, "extract-dir", "", "specify the extract dir for extracting the model artifact") flags.BoolVar(&pullConfig.ExtractFromRemote, "extract-from-remote", false, "turning on this flag will pull and extract the data from remote registry and no longer store model artifact locally, so user must specify extract-dir as the output directory") flags.StringVar(&pullConfig.DragonflyEndpoint, "dragonfly-endpoint", "", "specify the dragonfly endpoint for the pull operation, which will download and hardlink the blob by dragonfly GRPC service, this mode requires extract-from-remote must be true") + flags.BoolVar(&pullConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") + flags.DurationVar(&pullConfig.RetryConfig.MaxRetryTime, "retry-max-time", 0, "Max total retry time per file (0 = dynamic based on file 
size)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind pull flags to viper: %w", err)) diff --git a/cmd/push.go b/cmd/push.go index da26cb18..4f5f5961 100644 --- a/cmd/push.go +++ b/cmd/push.go @@ -53,6 +53,8 @@ func init() { flags.BoolVar(&pushConfig.Insecure, "insecure", false, "turning on this flag will disable TLS verification") flags.BoolVar(&pushConfig.Nydusify, "nydusify", false, "[EXPERIMENTAL] nydusify the model artifact") flags.MarkHidden("nydusify") + flags.BoolVar(&pushConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") + flags.DurationVar(&pushConfig.RetryConfig.MaxRetryTime, "retry-max-time", 0, "Max total retry time per file (0 = dynamic based on file size)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind push flags to viper: %w", err)) diff --git a/pkg/config/build.go b/pkg/config/build.go index 27dd23d5..861cc002 100644 --- a/pkg/config/build.go +++ b/pkg/config/build.go @@ -16,7 +16,11 @@ package config -import "fmt" +import ( + "fmt" + + "github.com/modelpack/modctl/pkg/retrypolicy" +) const ( // defaultBuildConcurrency is the default number of concurrent builds. 
@@ -36,6 +40,7 @@ type Build struct { Raw bool Reasoning bool NoCreationTime bool + RetryConfig retrypolicy.Config } func NewBuild() *Build { diff --git a/pkg/config/fetch.go b/pkg/config/fetch.go index 472bb6a7..7fb8f899 100644 --- a/pkg/config/fetch.go +++ b/pkg/config/fetch.go @@ -20,6 +20,8 @@ import ( "fmt" "io" "os" + + "github.com/modelpack/modctl/pkg/retrypolicy" ) const ( @@ -38,6 +40,7 @@ type Fetch struct { ProgressWriter io.Writer DisableProgress bool Hooks PullHooks + RetryConfig retrypolicy.Config } func NewFetch() *Fetch { diff --git a/pkg/config/pull.go b/pkg/config/pull.go index 6d33715a..3c290e2c 100644 --- a/pkg/config/pull.go +++ b/pkg/config/pull.go @@ -22,6 +22,8 @@ import ( "os" ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/modelpack/modctl/pkg/retrypolicy" ) const ( @@ -40,6 +42,7 @@ type Pull struct { ProgressWriter io.Writer DisableProgress bool DragonflyEndpoint string + RetryConfig retrypolicy.Config } func NewPull() *Pull { diff --git a/pkg/config/push.go b/pkg/config/push.go index c5fa8e8f..9ba596bd 100644 --- a/pkg/config/push.go +++ b/pkg/config/push.go @@ -16,7 +16,11 @@ package config -import "fmt" +import ( + "fmt" + + "github.com/modelpack/modctl/pkg/retrypolicy" +) const ( // defaultPushConcurrency is the default number of concurrent push operations. 
@@ -28,6 +32,7 @@ type Push struct { PlainHTTP bool Insecure bool Nydusify bool + RetryConfig retrypolicy.Config } func NewPush() *Push { From cf1bcf24c316530796c791dcffca3b66139be141 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Mon, 23 Mar 2026 18:27:02 +0800 Subject: [PATCH 03/16] feat(backend): use retrypolicy.Do with independent retry, remove cascading cancellation Signed-off-by: Zhao Chen --- internal/pb/pb.go | 22 +++++++++++ pkg/backend/build.go | 24 ++++++++---- pkg/backend/fetch.go | 44 ++++++++++++++++++--- pkg/backend/fetch_by_d7y.go | 44 +++++++++++++++++---- pkg/backend/processor/base.go | 39 +++++++++++-------- pkg/backend/processor/options.go | 16 ++++---- pkg/backend/pull.go | 66 ++++++++++++++++++++++++-------- pkg/backend/pull_by_d7y.go | 44 +++++++++++++++++---- pkg/backend/push.go | 58 ++++++++++++++++++++-------- 9 files changed, 274 insertions(+), 83 deletions(-) diff --git a/internal/pb/pb.go b/internal/pb/pb.go index 8303a968..3070eb36 100644 --- a/internal/pb/pb.go +++ b/internal/pb/pb.go @@ -136,6 +136,28 @@ func (p *ProgressBar) Add(prompt, name string, size int64, reader io.Reader) io. return reader } +// Placeholder creates or resets a progress bar entry without a reader. +// It is used during retry backoff to keep a visible bar for the item. +func (p *ProgressBar) Placeholder(name string, prompt string, size int64) { + if disableProgress { + return + } + + p.mu.RLock() + existing := p.bars[name] + p.mu.RUnlock() + + // If the bar already exists, just reset its message. + if existing != nil { + existing.msg = fmt.Sprintf("%s %s", prompt, name) + existing.Bar.SetCurrent(0) + return + } + + // Create a new placeholder bar. + p.Add(prompt, name, size, nil) +} + // Get returns the progress bar. 
func (p *ProgressBar) Get(name string) *progressBar { p.mu.RLock() diff --git a/pkg/backend/build.go b/pkg/backend/build.go index 10f9156c..346c44e4 100644 --- a/pkg/backend/build.go +++ b/pkg/backend/build.go @@ -23,7 +23,6 @@ import ( "os" "path/filepath" - retry "github.com/avast/retry-go/v4" modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" @@ -35,6 +34,7 @@ import ( "github.com/modelpack/modctl/pkg/backend/processor" "github.com/modelpack/modctl/pkg/config" "github.com/modelpack/modctl/pkg/modelfile" + "github.com/modelpack/modctl/pkg/retrypolicy" "github.com/modelpack/modctl/pkg/source" ) @@ -123,8 +123,8 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri var configDesc ocispec.Descriptor // Build the model config. - if err := retry.Do(func() error { - configDesc, err = builder.BuildConfig(ctx, config, hooks.NewHooks( + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + configDesc, err = builder.BuildConfig(rctx, config, hooks.NewHooks( hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader { return pb.Add(internalpb.NormalizePrompt("Building config"), name, size, reader) }), @@ -136,13 +136,17 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri }), )) return err - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + }, retrypolicy.DoOpts{ + FileSize: 0, // config is small + FileName: "config", + Config: &cfg.RetryConfig, + }); err != nil { return fmt.Errorf("failed to build model config: %w", err) } // Build the model manifest. 
- if err := retry.Do(func() error { - _, err = builder.BuildManifest(ctx, layers, configDesc, manifestAnnotation(modelfile), hooks.NewHooks( + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + _, err = builder.BuildManifest(rctx, layers, configDesc, manifestAnnotation(modelfile), hooks.NewHooks( hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader { return pb.Add(internalpb.NormalizePrompt("Building manifest"), name, size, reader) }), @@ -154,7 +158,11 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri }), )) return err - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + }, retrypolicy.DoOpts{ + FileSize: 0, // manifest is small + FileName: "manifest", + Config: &cfg.RetryConfig, + }); err != nil { return fmt.Errorf("failed to build model manifest: %w", err) } @@ -204,7 +212,7 @@ func (b *backend) getProcessors(modelfile modelfile.Modelfile, cfg *config.Build func (b *backend) process(ctx context.Context, builder build.Builder, workDir string, pb *internalpb.ProgressBar, cfg *config.Build, processors ...processor.Processor) ([]ocispec.Descriptor, error) { descriptors := []ocispec.Descriptor{} for _, p := range processors { - descs, err := p.Process(ctx, builder, workDir, processor.WithConcurrency(cfg.Concurrency), processor.WithProgressTracker(pb)) + descs, err := p.Process(ctx, builder, workDir, processor.WithConcurrency(cfg.Concurrency), processor.WithProgressTracker(pb), processor.WithRetryConfig(cfg.RetryConfig)) if err != nil { return nil, err } diff --git a/pkg/backend/fetch.go b/pkg/backend/fetch.go index 990d1afa..37b88678 100644 --- a/pkg/backend/fetch.go +++ b/pkg/backend/fetch.go @@ -19,7 +19,10 @@ package backend import ( "context" "encoding/json" + "errors" "fmt" + "sync" + "time" "github.com/bmatcuk/doublestar/v4" legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" @@ -31,6 +34,7 @@ import ( internalpb "github.com/modelpack/modctl/internal/pb" 
"github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" ) // Fetch fetches partial files to the output. @@ -101,9 +105,12 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e pb.Start() defer pb.Stop() - g, ctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) + var mu sync.Mutex + var errs []error + logrus.Infof("fetch: fetching %d matched layers", len(layers)) for _, layer := range layers { g.Go(func() error { @@ -113,17 +120,42 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e default: } - logrus.Debugf("fetch: processing layer %s", layer.Digest) - if err := pullAndExtractFromRemote(ctx, pb, internalpb.NormalizePrompt("Fetching blob"), client, cfg.Output, layer); err != nil { - return err + var annoFilepath string + if layer.Annotations != nil { + if layer.Annotations[modelspec.AnnotationFilepath] != "" { + annoFilepath = layer.Annotations[modelspec.AnnotationFilepath] + } else { + annoFilepath = layer.Annotations[legacymodelspec.AnnotationFilepath] + } } - logrus.Debugf("fetch: successfully processed layer %s", layer.Digest) + logrus.Debugf("fetch: processing layer %s", layer.Digest) + if err := retrypolicy.Do(ctx, func(ctx context.Context) error { + return pullAndExtractFromRemote(ctx, pb, internalpb.NormalizePrompt("Fetching blob"), client, cfg.Output, layer) + }, retrypolicy.DoOpts{ + FileSize: layer.Size, + FileName: annoFilepath, + Config: &cfg.RetryConfig, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + if bar := pb.Get(layer.Digest.String()); bar != nil { + bar.SetRefill(bar.Current()) + bar.SetCurrent(0) + bar.EwmaSetCurrent(0, time.Second) + } + }, + }); err != nil { + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } else { + logrus.Debugf("fetch: successfully processed layer %s", layer.Digest) + } return nil }) } - if err := g.Wait(); err != 
nil { + _ = g.Wait() + if err := errors.Join(errs...); err != nil { return err } diff --git a/pkg/backend/fetch_by_d7y.go b/pkg/backend/fetch_by_d7y.go index 098a3a53..94c24d32 100644 --- a/pkg/backend/fetch_by_d7y.go +++ b/pkg/backend/fetch_by_d7y.go @@ -19,15 +19,17 @@ package backend import ( "context" "encoding/json" + "errors" "fmt" "io" "os" "path/filepath" "strings" + "sync" + "time" common "d7y.io/api/v2/pkg/apis/common/v2" dfdaemon "d7y.io/api/v2/pkg/apis/dfdaemon/v2" - "github.com/avast/retry-go/v4" "github.com/bmatcuk/doublestar/v4" legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" modelspec "github.com/modelpack/model-spec/specs-go/v1" @@ -41,6 +43,7 @@ import ( "github.com/modelpack/modctl/pkg/archiver" "github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" ) // fetchByDragonfly fetches partial files via Dragonfly gRPC service based on pattern matching. @@ -124,9 +127,12 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf defer pb.Stop() // Process layers concurrently. 
- g, ctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) + var mu sync.Mutex + var errs []error + logrus.Infof("fetch: fetching %d matched layers via dragonfly", len(layers)) for _, layer := range layers { g.Go(func() error { @@ -138,14 +144,18 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf logrus.Debugf("fetch: processing layer %s via dragonfly", layer.Digest) if err := fetchLayerByDragonfly(ctx, pb, dfdaemon.NewDfdaemonDownloadClient(conn), ref, manifest, layer, authToken, cfg); err != nil { - return err + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } else { + logrus.Debugf("fetch: successfully processed layer %s via dragonfly", layer.Digest) } - logrus.Debugf("fetch: successfully processed layer %s via dragonfly", layer.Digest) return nil }) } - if err := g.Wait(); err != nil { + _ = g.Wait() + if err := errors.Join(errs...); err != nil { return err } @@ -155,7 +165,16 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf // fetchLayerByDragonfly handles downloading and extracting a single layer via Dragonfly. 
func fetchLayerByDragonfly(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Fetch) error { - err := retry.Do(func() error { + var annoFilepath string + if desc.Annotations != nil { + if desc.Annotations[modelspec.AnnotationFilepath] != "" { + annoFilepath = desc.Annotations[modelspec.AnnotationFilepath] + } else { + annoFilepath = desc.Annotations[legacymodelspec.AnnotationFilepath] + } + } + + err := retrypolicy.Do(ctx, func(ctx context.Context) error { logrus.Debugf("fetch: processing layer %s", desc.Digest) cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook err := downloadAndExtractFetchLayer(ctx, pb, client, ref, desc, authToken, cfg) @@ -166,7 +185,18 @@ func fetchLayerByDragonfly(ctx context.Context, pb *internalpb.ProgressBar, clie } return err - }, append(defaultRetryOpts, retry.Context(ctx))...) + }, retrypolicy.DoOpts{ + FileSize: desc.Size, + FileName: annoFilepath, + Config: &cfg.RetryConfig, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + if bar := pb.Get(desc.Digest.String()); bar != nil { + bar.SetRefill(bar.Current()) + bar.SetCurrent(0) + bar.EwmaSetCurrent(0, time.Second) + } + }, + }) if err != nil { err = fmt.Errorf("fetch: failed to download and extract layer %s: %w", desc.Digest, err) diff --git a/pkg/backend/processor/base.go b/pkg/backend/processor/base.go index 579e61f2..fb2fa81a 100644 --- a/pkg/backend/processor/base.go +++ b/pkg/backend/processor/base.go @@ -18,6 +18,7 @@ package processor import ( "context" + "errors" "fmt" "io" "os" @@ -26,7 +27,6 @@ import ( "strings" "sync" - "github.com/avast/retry-go/v4" legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -36,6 +36,7 @@ import ( internalpb "github.com/modelpack/modctl/internal/pb" 
"github.com/modelpack/modctl/pkg/backend/build" "github.com/modelpack/modctl/pkg/backend/build/hooks" + "github.com/modelpack/modctl/pkg/retrypolicy" "github.com/modelpack/modctl/pkg/storage" ) @@ -105,14 +106,11 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin var ( mu sync.Mutex - eg *errgroup.Group descriptors []ocispec.Descriptor + errs []error ) - // Initialize errgroup with a context can be canceled. - ctx, cancel := context.WithCancel(ctx) - defer cancel() - eg, ctx = errgroup.WithContext(ctx) + eg := new(errgroup.Group) // Set default concurrency limit to 1 if not specified. if processOpts.concurrency > 0 { @@ -137,11 +135,17 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin eg.Go(func() error { select { case <-ctx.Done(): - return ctx.Err() + return nil default: } - if err := retry.Do(func() error { + // Get file size for dynamic retry parameters. + var fileSize int64 + if fi, err := os.Stat(path); err == nil { + fileSize = fi.Size() + } + + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("processor: processing %s file %s", b.name, path) var destPath string @@ -149,7 +153,7 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin destPath = filepath.Join(b.destDir, filepath.Base(path)) } - desc, err := builder.BuildLayer(ctx, b.mediaType, workDir, path, destPath, hooks.NewHooks( + desc, err := builder.BuildLayer(rctx, b.mediaType, workDir, path, destPath, hooks.NewHooks( hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader { return tracker.Add(internalpb.NormalizePrompt("Building layer"), name, size, reader) }), @@ -170,19 +174,24 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin mu.Unlock() return nil - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + }, retrypolicy.DoOpts{ + FileSize: fileSize, + FileName: filepath.Base(path), + Config: 
processOpts.retryConfig, + }); err != nil { logrus.Error(err) - // Cancel manually to abort other tasks because if one fails, - // we should abort all to avoid useless waiting. - cancel() - return err + mu.Lock() + errs = append(errs, err) + mu.Unlock() } return nil }) } - if err := eg.Wait(); err != nil { + _ = eg.Wait() + + if err := errors.Join(errs...); err != nil { return nil, err } diff --git a/pkg/backend/processor/options.go b/pkg/backend/processor/options.go index 0126410e..43d5d82e 100644 --- a/pkg/backend/processor/options.go +++ b/pkg/backend/processor/options.go @@ -17,11 +17,8 @@ package processor import ( - "time" - - retry "github.com/avast/retry-go/v4" - "github.com/modelpack/modctl/internal/pb" + "github.com/modelpack/modctl/pkg/retrypolicy" ) type ProcessOption func(*processOptions) @@ -31,6 +28,8 @@ type processOptions struct { concurrency int // progressTracker is the progress bar to use for tracking progress. progressTracker *pb.ProgressBar + // retryConfig is the retry configuration to use for processing. 
+ retryConfig *retrypolicy.Config } func WithConcurrency(concurrency int) ProcessOption { @@ -45,9 +44,8 @@ func WithProgressTracker(tracker *pb.ProgressBar) ProcessOption { } } -var defaultRetryOpts = []retry.Option{ - retry.Attempts(6), - retry.DelayType(retry.BackOffDelay), - retry.Delay(5 * time.Second), - retry.MaxDelay(60 * time.Second), +func WithRetryConfig(cfg retrypolicy.Config) ProcessOption { + return func(o *processOptions) { + o.retryConfig = &cfg + } } diff --git a/pkg/backend/pull.go b/pkg/backend/pull.go index 3cb6a1b0..6c95de0b 100644 --- a/pkg/backend/pull.go +++ b/pkg/backend/pull.go @@ -22,8 +22,9 @@ import ( "errors" "fmt" "io" + "sync" + "time" - retry "github.com/avast/retry-go/v4" sha256 "github.com/minio/sha256-simd" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" @@ -33,6 +34,7 @@ import ( "github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/codec" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" "github.com/modelpack/modctl/pkg/storage" ) @@ -90,17 +92,20 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err // copy the layers. 
dst := b.store - g, gctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) + var mu sync.Mutex + var errs []error + var fn func(desc ocispec.Descriptor) error if cfg.ExtractFromRemote { fn = func(desc ocispec.Descriptor) error { - return pullAndExtractFromRemote(gctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, cfg.ExtractDir, desc) + return pullAndExtractFromRemote(ctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, cfg.ExtractDir, desc) } } else { fn = func(desc ocispec.Descriptor) error { - return pullIfNotExist(gctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, dst, desc, repo, tag) + return pullIfNotExist(ctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, dst, desc, repo, tag) } } @@ -108,12 +113,12 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err for _, layer := range manifest.Layers { g.Go(func() error { select { - case <-gctx.Done(): - return gctx.Err() + case <-ctx.Done(): + return ctx.Err() default: } - return retry.Do(func() error { + retryErr := retrypolicy.Do(ctx, func(retryCtx context.Context) error { logrus.Debugf("pull: processing layer %s", layer.Digest) // call the before hook. cfg.Hooks.BeforePullLayer(layer, manifest) @@ -126,12 +131,27 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err } return err - }, append(defaultRetryOpts, retry.Context(gctx))...) 
+ }, retrypolicy.DoOpts{ + FileSize: layer.Size, + FileName: layer.Digest.String(), + Config: &cfg.RetryConfig, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + pb.Placeholder(layer.Digest.String(), internalpb.NormalizePrompt("Pulling blob"), layer.Size) + }, + }) + if retryErr != nil { + mu.Lock() + errs = append(errs, retryErr) + mu.Unlock() + } + + return nil }) } - if err := g.Wait(); err != nil { - return fmt.Errorf("failed to pull blob to local: %w", err) + _ = g.Wait() + if len(errs) > 0 { + return fmt.Errorf("failed to pull blob to local: %w", errors.Join(errs...)) } logrus.Infof("pull: layers pulled [count: %d]", len(manifest.Layers)) @@ -143,16 +163,30 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err } // copy the config. - if err := retry.Do(func() error { - return pullIfNotExist(ctx, pb, internalpb.NormalizePrompt("Pulling config"), src, dst, manifest.Config, repo, tag) - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + if err := retrypolicy.Do(ctx, func(retryCtx context.Context) error { + return pullIfNotExist(retryCtx, pb, internalpb.NormalizePrompt("Pulling config"), src, dst, manifest.Config, repo, tag) + }, retrypolicy.DoOpts{ + FileSize: manifest.Config.Size, + FileName: "config", + Config: &cfg.RetryConfig, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + pb.Placeholder(manifest.Config.Digest.String(), internalpb.NormalizePrompt("Pulling config"), manifest.Config.Size) + }, + }); err != nil { return fmt.Errorf("failed to pull config to local: %w", err) } // copy the manifest. 
- if err := retry.Do(func() error { - return pullIfNotExist(ctx, pb, internalpb.NormalizePrompt("Pulling manifest"), src, dst, manifestDesc, repo, tag) - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + if err := retrypolicy.Do(ctx, func(retryCtx context.Context) error { + return pullIfNotExist(retryCtx, pb, internalpb.NormalizePrompt("Pulling manifest"), src, dst, manifestDesc, repo, tag) + }, retrypolicy.DoOpts{ + FileSize: manifestDesc.Size, + FileName: "manifest", + Config: &cfg.RetryConfig, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + pb.Placeholder(manifestDesc.Digest.String(), internalpb.NormalizePrompt("Pulling manifest"), manifestDesc.Size) + }, + }); err != nil { return fmt.Errorf("failed to pull manifest to local: %w", err) } diff --git a/pkg/backend/pull_by_d7y.go b/pkg/backend/pull_by_d7y.go index a7ca3b47..467db254 100644 --- a/pkg/backend/pull_by_d7y.go +++ b/pkg/backend/pull_by_d7y.go @@ -19,15 +19,17 @@ package backend import ( "context" "encoding/json" + "errors" "fmt" "io" "os" "path/filepath" "strings" + "sync" + "time" common "d7y.io/api/v2/pkg/apis/common/v2" dfdaemon "d7y.io/api/v2/pkg/apis/dfdaemon/v2" - "github.com/avast/retry-go/v4" legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" @@ -41,6 +43,7 @@ import ( "github.com/modelpack/modctl/pkg/archiver" "github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" ) const ( @@ -100,9 +103,12 @@ func (b *backend) pullByDragonfly(ctx context.Context, target string, cfg *confi defer pb.Stop() // Process layers concurrently. 
- g, ctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) + var mu sync.Mutex + var errs []error + logrus.Infof("pull: pulling %d layers via dragonfly", len(manifest.Layers)) for _, layer := range manifest.Layers { g.Go(func() error { @@ -114,14 +120,18 @@ func (b *backend) pullByDragonfly(ctx context.Context, target string, cfg *confi logrus.Debugf("pull: processing layer %s via dragonfly", layer.Digest) if err := processLayer(ctx, pb, dfdaemon.NewDfdaemonDownloadClient(conn), ref, manifest, layer, authToken, cfg); err != nil { - return err + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } else { + logrus.Debugf("pull: successfully processed layer %s via dragonfly", layer.Digest) } - logrus.Debugf("pull: successfully processed layer %s via dragonfly", layer.Digest) return nil }) } - if err := g.Wait(); err != nil { + _ = g.Wait() + if err := errors.Join(errs...); err != nil { return err } @@ -179,7 +189,16 @@ func buildBlobURL(ref Referencer, plainHTTP bool, digest string) string { // processLayer handles downloading and extracting a single layer. 
func processLayer(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Pull) error { - err := retry.Do(func() error { + var annoFilepath string + if desc.Annotations != nil { + if desc.Annotations[modelspec.AnnotationFilepath] != "" { + annoFilepath = desc.Annotations[modelspec.AnnotationFilepath] + } else { + annoFilepath = desc.Annotations[legacymodelspec.AnnotationFilepath] + } + } + + err := retrypolicy.Do(ctx, func(ctx context.Context) error { logrus.Debugf("pull: processing layer %s", desc.Digest) cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook err := downloadAndExtractLayer(ctx, pb, client, ref, desc, authToken, cfg) @@ -190,7 +209,18 @@ func processLayer(ctx context.Context, pb *internalpb.ProgressBar, client dfdaem } return err - }, append(defaultRetryOpts, retry.Context(ctx))...) + }, retrypolicy.DoOpts{ + FileSize: desc.Size, + FileName: annoFilepath, + Config: &cfg.RetryConfig, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + if bar := pb.Get(desc.Digest.String()); bar != nil { + bar.SetRefill(bar.Current()) + bar.SetCurrent(0) + bar.EwmaSetCurrent(0, time.Second) + } + }, + }) return err } diff --git a/pkg/backend/push.go b/pkg/backend/push.go index 2674cee2..a320f531 100644 --- a/pkg/backend/push.go +++ b/pkg/backend/push.go @@ -20,10 +20,12 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "sync" + "time" - retry "github.com/avast/retry-go/v4" godigest "github.com/opencontainers/go-digest" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" @@ -32,6 +34,7 @@ import ( internalpb "github.com/modelpack/modctl/internal/pb" "github.com/modelpack/modctl/pkg/backend/remote" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/retrypolicy" "github.com/modelpack/modctl/pkg/storage" ) @@ -77,49 +80,74 @@ func (b 
*backend) Push(ctx context.Context, target string, cfg *config.Push) err // note: the order is important, manifest should be pushed at last. // copy the layers. - g, gctx := errgroup.WithContext(ctx) + g := new(errgroup.Group) g.SetLimit(cfg.Concurrency) + var mu sync.Mutex + var errs []error logrus.Infof("push: pushing %d layers for %s", len(manifest.Layers), target) for _, layer := range manifest.Layers { g.Go(func() error { select { - case <-gctx.Done(): - return gctx.Err() + case <-ctx.Done(): + return ctx.Err() default: } - return retry.Do(func() error { + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("push: processing layer %s", layer.Digest) - if err := pushIfNotExist(gctx, pb, internalpb.NormalizePrompt("Copying blob"), src, dst, layer, repo, tag); err != nil { + if err := pushIfNotExist(rctx, pb, internalpb.NormalizePrompt("Copying blob"), src, dst, layer, repo, tag); err != nil { return err } logrus.Debugf("push: successfully processed layer %s", layer.Digest) return nil - }, append(defaultRetryOpts, retry.Context(gctx))...) + }, retrypolicy.DoOpts{ + FileSize: layer.Size, + FileName: layer.Digest.String(), + Config: &cfg.RetryConfig, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", + internalpb.NormalizePrompt("Copying blob"), attempt, reason, backoff.Truncate(time.Second)) + pb.Add(prompt, layer.Digest.String(), layer.Size, nil) + }, + }); err != nil { + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } + return nil // never return error to errgroup }) } - if err := g.Wait(); err != nil { - return fmt.Errorf("failed to push blob to remote: %w", err) + g.Wait() + if len(errs) > 0 { + return fmt.Errorf("failed to push blob to remote: %w", errors.Join(errs...)) } // copy the config. 
- if err := retry.Do(func() error { - return pushIfNotExist(ctx, pb, internalpb.NormalizePrompt("Copying config"), src, dst, manifest.Config, repo, tag) - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + return pushIfNotExist(rctx, pb, internalpb.NormalizePrompt("Copying config"), src, dst, manifest.Config, repo, tag) + }, retrypolicy.DoOpts{ + FileSize: manifest.Config.Size, + FileName: "config", + Config: &cfg.RetryConfig, + }); err != nil { return fmt.Errorf("failed to push config to remote: %w", err) } // copy the manifest. - if err := retry.Do(func() error { - return pushIfNotExist(ctx, pb, internalpb.NormalizePrompt("Copying manifest"), src, dst, ocispec.Descriptor{ + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + return pushIfNotExist(rctx, pb, internalpb.NormalizePrompt("Copying manifest"), src, dst, ocispec.Descriptor{ MediaType: manifest.MediaType, Size: int64(len(manifestRaw)), Digest: godigest.FromBytes(manifestRaw), Data: manifestRaw, }, repo, tag) - }, append(defaultRetryOpts, retry.Context(ctx))...); err != nil { + }, retrypolicy.DoOpts{ + FileSize: int64(len(manifestRaw)), + FileName: "manifest", + Config: &cfg.RetryConfig, + }); err != nil { return fmt.Errorf("failed to push manifest to remote: %w", err) } From be391fdbc5d116641394e19f7aa682db2590999b Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Mon, 23 Mar 2026 18:27:07 +0800 Subject: [PATCH 04/16] refactor: remove legacy defaultRetryOpts, delegate to retrypolicy Signed-off-by: Zhao Chen --- pkg/backend/retry.go | 30 ------------ pkg/backend/retry_test.go | 96 --------------------------------------- 2 files changed, 126 deletions(-) delete mode 100644 pkg/backend/retry.go delete mode 100644 pkg/backend/retry_test.go diff --git a/pkg/backend/retry.go b/pkg/backend/retry.go deleted file mode 100644 index c7494250..00000000 --- a/pkg/backend/retry.go +++ /dev/null @@ -1,30 +0,0 @@ -/* - * 
Copyright 2025 The CNAI Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package backend - -import ( - "time" - - retry "github.com/avast/retry-go/v4" -) - -var defaultRetryOpts = []retry.Option{ - retry.Attempts(6), - retry.DelayType(retry.BackOffDelay), - retry.Delay(5 * time.Second), - retry.MaxDelay(60 * time.Second), -} diff --git a/pkg/backend/retry_test.go b/pkg/backend/retry_test.go deleted file mode 100644 index f4069ac1..00000000 --- a/pkg/backend/retry_test.go +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright 2025 The CNAI Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package backend - -import ( - "context" - "errors" - "testing" - - retry "github.com/avast/retry-go/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// testRetryOpts creates retry options with zero delay so tests run fast and deterministically. 
-func testRetryOpts(ctx context.Context, maxAttempts uint) []retry.Option { - return []retry.Option{ - retry.Attempts(maxAttempts), - retry.Delay(0), - retry.MaxDelay(0), - retry.Context(ctx), - } -} - -func TestRetrySuccessAfterFailures(t *testing.T) { - ctx := context.Background() - attempts := 0 - - err := retry.Do(func() error { - attempts++ - if attempts < 3 { - return errors.New("temporary failure") - } - return nil - }, testRetryOpts(ctx, 6)...) - - require.NoError(t, err) - assert.Equal(t, 3, attempts) -} - -func TestRetryMaxAttemptsExceeded(t *testing.T) { - ctx := context.Background() - attempts := 0 - maxAttempts := uint(4) - - err := retry.Do(func() error { - attempts++ - return errors.New("persistent failure") - }, testRetryOpts(ctx, maxAttempts)...) - - assert.Error(t, err) - assert.Equal(t, int(maxAttempts), attempts) -} - -func TestRetryStopsOnContextCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - attempts := 0 - - err := retry.Do(func() error { - attempts++ - if attempts == 2 { - cancel() - } - return errors.New("temporary failure") - }, testRetryOpts(ctx, 10)...) - - assert.ErrorIs(t, err, context.Canceled) - assert.Equal(t, 2, attempts) -} - -func TestRetrySucceedsOnFirstAttempt(t *testing.T) { - ctx := context.Background() - attempts := 0 - - err := retry.Do(func() error { - attempts++ - return nil - }, testRetryOpts(ctx, 6)...) 
- - require.NoError(t, err) - assert.Equal(t, 1, attempts) -} From 2fcdba2cfbac534d37585217e01574623604d31a Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Mon, 23 Mar 2026 21:13:42 +0800 Subject: [PATCH 05/16] fix(push): add OnRetry handler for config and manifest retry Signed-off-by: Zhao Chen --- pkg/backend/push.go | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/pkg/backend/push.go b/pkg/backend/push.go index a320f531..f2073358 100644 --- a/pkg/backend/push.go +++ b/pkg/backend/push.go @@ -131,22 +131,33 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err FileSize: manifest.Config.Size, FileName: "config", Config: &cfg.RetryConfig, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", + internalpb.NormalizePrompt("Copying config"), attempt, reason, backoff.Truncate(time.Second)) + pb.Add(prompt, manifest.Config.Digest.String(), manifest.Config.Size, nil) + }, }); err != nil { return fmt.Errorf("failed to push config to remote: %w", err) } // copy the manifest. 
+ manifestDesc := ocispec.Descriptor{ + MediaType: manifest.MediaType, + Size: int64(len(manifestRaw)), + Digest: godigest.FromBytes(manifestRaw), + Data: manifestRaw, + } if err := retrypolicy.Do(ctx, func(rctx context.Context) error { - return pushIfNotExist(rctx, pb, internalpb.NormalizePrompt("Copying manifest"), src, dst, ocispec.Descriptor{ - MediaType: manifest.MediaType, - Size: int64(len(manifestRaw)), - Digest: godigest.FromBytes(manifestRaw), - Data: manifestRaw, - }, repo, tag) + return pushIfNotExist(rctx, pb, internalpb.NormalizePrompt("Copying manifest"), src, dst, manifestDesc, repo, tag) }, retrypolicy.DoOpts{ - FileSize: int64(len(manifestRaw)), + FileSize: manifestDesc.Size, FileName: "manifest", Config: &cfg.RetryConfig, + OnRetry: func(attempt uint, reason string, backoff time.Duration) { + prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", + internalpb.NormalizePrompt("Copying manifest"), attempt, reason, backoff.Truncate(time.Second)) + pb.Add(prompt, manifestDesc.Digest.String(), manifestDesc.Size, nil) + }, }); err != nil { return fmt.Errorf("failed to push manifest to remote: %w", err) } From 4cdd8ae89b22aefd2940c19df3b0202bbd3b43c7 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Mon, 23 Mar 2026 21:54:51 +0800 Subject: [PATCH 06/16] fix: add non-retryable local errors and thread retryCtx in pull Signed-off-by: Zhao Chen --- pkg/backend/pull.go | 8 ++++---- pkg/retrypolicy/retrypolicy.go | 14 +++++++++++++- pkg/retrypolicy/retrypolicy_test.go | 15 +++++++++++++++ 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/pkg/backend/pull.go b/pkg/backend/pull.go index 6c95de0b..6ffa3baf 100644 --- a/pkg/backend/pull.go +++ b/pkg/backend/pull.go @@ -98,13 +98,13 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err var mu sync.Mutex var errs []error - var fn func(desc ocispec.Descriptor) error + var fn func(ctx context.Context, desc ocispec.Descriptor) error if cfg.ExtractFromRemote { - fn = func(desc 
ocispec.Descriptor) error { + fn = func(ctx context.Context, desc ocispec.Descriptor) error { return pullAndExtractFromRemote(ctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, cfg.ExtractDir, desc) } } else { - fn = func(desc ocispec.Descriptor) error { + fn = func(ctx context.Context, desc ocispec.Descriptor) error { return pullIfNotExist(ctx, pb, internalpb.NormalizePrompt("Pulling blob"), src, dst, desc, repo, tag) } } @@ -122,7 +122,7 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err logrus.Debugf("pull: processing layer %s", layer.Digest) // call the before hook. cfg.Hooks.BeforePullLayer(layer, manifest) - err := fn(layer) + err := fn(retryCtx, layer) // call the after hook. cfg.Hooks.AfterPullLayer(layer, err) if err != nil { diff --git a/pkg/retrypolicy/retrypolicy.go b/pkg/retrypolicy/retrypolicy.go index b7761712..c53cf0ea 100644 --- a/pkg/retrypolicy/retrypolicy.go +++ b/pkg/retrypolicy/retrypolicy.go @@ -219,7 +219,19 @@ func IsRetryable(err error) bool { return true } - // Unknown errors default to retryable. + // Local / permanent errors are not retryable. + if strings.Contains(errMsg, "permission denied") || + strings.Contains(errMsg, "no space left on device") || + strings.Contains(errMsg, "file exists") || + strings.Contains(errMsg, "not a directory") || + strings.Contains(errMsg, "is a directory") || + strings.Contains(errMsg, "no such file or directory") || + strings.Contains(errMsg, "invalid argument") { + return false + } + + // Unknown errors default to retryable with a warning. 
+ log.WithField("error", errMsg).Warn("[RETRY] unknown error treated as retryable") return true } diff --git a/pkg/retrypolicy/retrypolicy_test.go b/pkg/retrypolicy/retrypolicy_test.go index 3e29226f..20601963 100644 --- a/pkg/retrypolicy/retrypolicy_test.go +++ b/pkg/retrypolicy/retrypolicy_test.go @@ -206,6 +206,21 @@ func TestIsRetryable(t *testing.T) { err: fmt.Errorf("unexpected EOF"), want: true, }, + { + name: "permission denied - not retryable", + err: fmt.Errorf("open /data/model.bin: permission denied"), + want: false, + }, + { + name: "no space left on device - not retryable", + err: fmt.Errorf("write /data/model.bin: no space left on device"), + want: false, + }, + { + name: "no such file or directory - not retryable", + err: fmt.Errorf("open /data/model.bin: no such file or directory"), + want: false, + }, { name: "unknown error - defaults to retryable", err: errors.New("something totally unexpected happened"), From 5047acc8211f01c468f3ee84de2e31d1fbfbc0ae Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Mon, 23 Mar 2026 22:15:31 +0800 Subject: [PATCH 07/16] fix: propagate context cancellation after Wait to prevent incomplete artifacts Signed-off-by: Zhao Chen --- pkg/backend/fetch.go | 3 +++ pkg/backend/fetch_by_d7y.go | 3 +++ pkg/backend/processor/base.go | 4 ++++ pkg/backend/pull.go | 3 +++ pkg/backend/pull_by_d7y.go | 3 +++ pkg/backend/push.go | 3 +++ 6 files changed, 19 insertions(+) diff --git a/pkg/backend/fetch.go b/pkg/backend/fetch.go index 37b88678..dbc9d956 100644 --- a/pkg/backend/fetch.go +++ b/pkg/backend/fetch.go @@ -155,6 +155,9 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e } _ = g.Wait() + if ctx.Err() != nil { + return fmt.Errorf("fetch cancelled: %w", ctx.Err()) + } if err := errors.Join(errs...); err != nil { return err } diff --git a/pkg/backend/fetch_by_d7y.go b/pkg/backend/fetch_by_d7y.go index 94c24d32..8d9035f3 100644 --- a/pkg/backend/fetch_by_d7y.go +++ b/pkg/backend/fetch_by_d7y.go @@ 
-155,6 +155,9 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf } _ = g.Wait() + if ctx.Err() != nil { + return fmt.Errorf("fetch cancelled: %w", ctx.Err()) + } if err := errors.Join(errs...); err != nil { return err } diff --git a/pkg/backend/processor/base.go b/pkg/backend/processor/base.go index fb2fa81a..e62863bc 100644 --- a/pkg/backend/processor/base.go +++ b/pkg/backend/processor/base.go @@ -191,6 +191,10 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin _ = eg.Wait() + if ctx.Err() != nil { + return nil, fmt.Errorf("processing cancelled: %w", ctx.Err()) + } + if err := errors.Join(errs...); err != nil { return nil, err } diff --git a/pkg/backend/pull.go b/pkg/backend/pull.go index 6ffa3baf..2038f99c 100644 --- a/pkg/backend/pull.go +++ b/pkg/backend/pull.go @@ -150,6 +150,9 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err } _ = g.Wait() + if ctx.Err() != nil { + return fmt.Errorf("pull cancelled: %w", ctx.Err()) + } if len(errs) > 0 { return fmt.Errorf("failed to pull blob to local: %w", errors.Join(errs...)) } diff --git a/pkg/backend/pull_by_d7y.go b/pkg/backend/pull_by_d7y.go index 467db254..fa462596 100644 --- a/pkg/backend/pull_by_d7y.go +++ b/pkg/backend/pull_by_d7y.go @@ -131,6 +131,9 @@ func (b *backend) pullByDragonfly(ctx context.Context, target string, cfg *confi } _ = g.Wait() + if ctx.Err() != nil { + return fmt.Errorf("pull cancelled: %w", ctx.Err()) + } if err := errors.Join(errs...); err != nil { return err } diff --git a/pkg/backend/push.go b/pkg/backend/push.go index f2073358..d227f79c 100644 --- a/pkg/backend/push.go +++ b/pkg/backend/push.go @@ -120,6 +120,9 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err } g.Wait() + if ctx.Err() != nil { + return fmt.Errorf("push cancelled: %w", ctx.Err()) + } if len(errs) > 0 { return fmt.Errorf("failed to push blob to remote: %w", errors.Join(errs...)) } From 
fcdbcb0813c523fe3af5c818209beef682c03518 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Mon, 23 Mar 2026 22:40:59 +0800 Subject: [PATCH 08/16] fix: handle DeadlineExceeded in IsRetryable, fix ctx shadow, speed up tests Signed-off-by: Zhao Chen --- pkg/backend/fetch.go | 4 ++-- pkg/backend/fetch_by_d7y.go | 4 ++-- pkg/backend/pull_by_d7y.go | 4 ++-- pkg/retrypolicy/retrypolicy.go | 21 +++++++++++++++++---- pkg/retrypolicy/retrypolicy_test.go | 23 +++++++++++++++++++---- 5 files changed, 42 insertions(+), 14 deletions(-) diff --git a/pkg/backend/fetch.go b/pkg/backend/fetch.go index dbc9d956..aebbe947 100644 --- a/pkg/backend/fetch.go +++ b/pkg/backend/fetch.go @@ -130,8 +130,8 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e } logrus.Debugf("fetch: processing layer %s", layer.Digest) - if err := retrypolicy.Do(ctx, func(ctx context.Context) error { - return pullAndExtractFromRemote(ctx, pb, internalpb.NormalizePrompt("Fetching blob"), client, cfg.Output, layer) + if err := retrypolicy.Do(ctx, func(rctx context.Context) error { + return pullAndExtractFromRemote(rctx, pb, internalpb.NormalizePrompt("Fetching blob"), client, cfg.Output, layer) }, retrypolicy.DoOpts{ FileSize: layer.Size, FileName: annoFilepath, diff --git a/pkg/backend/fetch_by_d7y.go b/pkg/backend/fetch_by_d7y.go index 8d9035f3..39a1248d 100644 --- a/pkg/backend/fetch_by_d7y.go +++ b/pkg/backend/fetch_by_d7y.go @@ -177,10 +177,10 @@ func fetchLayerByDragonfly(ctx context.Context, pb *internalpb.ProgressBar, clie } } - err := retrypolicy.Do(ctx, func(ctx context.Context) error { + err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("fetch: processing layer %s", desc.Digest) cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook - err := downloadAndExtractFetchLayer(ctx, pb, client, ref, desc, authToken, cfg) + err := downloadAndExtractFetchLayer(rctx, pb, client, ref, desc, authToken, cfg) cfg.Hooks.AfterPullLayer(desc, err) // Call 
after hook if err != nil { err = fmt.Errorf("pull: failed to download and extract layer %s: %w", desc.Digest, err) diff --git a/pkg/backend/pull_by_d7y.go b/pkg/backend/pull_by_d7y.go index fa462596..32782597 100644 --- a/pkg/backend/pull_by_d7y.go +++ b/pkg/backend/pull_by_d7y.go @@ -201,10 +201,10 @@ func processLayer(ctx context.Context, pb *internalpb.ProgressBar, client dfdaem } } - err := retrypolicy.Do(ctx, func(ctx context.Context) error { + err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("pull: processing layer %s", desc.Digest) cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook - err := downloadAndExtractLayer(ctx, pb, client, ref, desc, authToken, cfg) + err := downloadAndExtractLayer(rctx, pb, client, ref, desc, authToken, cfg) cfg.Hooks.AfterPullLayer(desc, err) // Call after hook if err != nil { err = fmt.Errorf("pull: failed to download and extract layer %s: %w", desc.Digest, err) diff --git a/pkg/retrypolicy/retrypolicy.go b/pkg/retrypolicy/retrypolicy.go index c53cf0ea..f43cce1c 100644 --- a/pkg/retrypolicy/retrypolicy.go +++ b/pkg/retrypolicy/retrypolicy.go @@ -51,6 +51,8 @@ const ( type Config struct { MaxRetryTime time.Duration // 0 = dynamic based on file size NoRetry bool // disable retry entirely + InitialDelay time.Duration // 0 = use default (5s), for testing + MaxJitter time.Duration // -1 = no jitter, 0 = use default (5s), for testing } // DoOpts configures a single Do call. 
@@ -93,6 +95,17 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er sizeStr := humanizeBytes(opts.FileSize) + delay := initialDelay + jitter := maxJitter + if cfg.InitialDelay > 0 { + delay = cfg.InitialDelay + } + if cfg.MaxJitter < 0 { + jitter = 0 + } else if cfg.MaxJitter > 0 { + jitter = cfg.MaxJitter + } + return retry.Do( func() error { return fn(deadlineCtx) @@ -100,9 +113,9 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er retry.Attempts(0), retry.Context(deadlineCtx), retry.DelayType(retry.BackOffDelay), - retry.Delay(initialDelay), + retry.Delay(delay), retry.MaxDelay(maxBackoff), - retry.MaxJitter(maxJitter), + retry.MaxJitter(jitter), retry.LastErrorOnly(true), retry.WrapContextErrorWithLastError(true), retry.RetryIf(func(err error) bool { @@ -180,8 +193,8 @@ func IsRetryable(err error) bool { return false } - // context.Canceled is never retryable — it means user/system cancellation. - if errors.Is(err, context.Canceled) { + // context.Canceled and context.DeadlineExceeded are never retryable. 
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } diff --git a/pkg/retrypolicy/retrypolicy_test.go b/pkg/retrypolicy/retrypolicy_test.go index 20601963..eeabb44f 100644 --- a/pkg/retrypolicy/retrypolicy_test.go +++ b/pkg/retrypolicy/retrypolicy_test.go @@ -132,6 +132,16 @@ func TestIsRetryable(t *testing.T) { err: fmt.Errorf("operation failed: %w", context.Canceled), want: false, }, + { + name: "context.DeadlineExceeded", + err: context.DeadlineExceeded, + want: false, + }, + { + name: "wrapped context.DeadlineExceeded", + err: fmt.Errorf("operation timed out: %w", context.DeadlineExceeded), + want: false, + }, { name: "HTTP 500 server error", err: fmt.Errorf("PUT /blobs/uploads: response status code 500: internal server error"), @@ -344,6 +354,7 @@ func TestDo_RetryOnTransientError(t *testing.T) { }, DoOpts{ FileSize: 100, FileName: "test.bin", + Config: &Config{InitialDelay: 10 * time.Millisecond, MaxJitter: -1}, }) if err != nil { @@ -406,6 +417,7 @@ func TestDo_ParentContextCancel(t *testing.T) { }, DoOpts{ FileSize: 100, FileName: "test.bin", + Config: &Config{InitialDelay: 10 * time.Millisecond, MaxJitter: -1}, }) if err == nil { @@ -427,6 +439,7 @@ func TestDo_OnRetryCallback(t *testing.T) { }, DoOpts{ FileSize: 100, FileName: "test.bin", + Config: &Config{InitialDelay: 10 * time.Millisecond, MaxJitter: -1}, OnRetry: func(attempt uint, reason string, backoff time.Duration) { retryAttempts = append(retryAttempts, attempt) retryReasons = append(retryReasons, reason) @@ -458,7 +471,9 @@ func TestDo_ConfigMaxRetryTimeOverride(t *testing.T) { FileSize: 100, FileName: "test.bin", Config: &Config{ - MaxRetryTime: 8 * time.Second, + MaxRetryTime: 1 * time.Second, + InitialDelay: 50 * time.Millisecond, + MaxJitter: -1, }, }) @@ -467,9 +482,9 @@ func TestDo_ConfigMaxRetryTimeOverride(t *testing.T) { if err == nil { t.Fatal("expected error after retry timeout, got nil") } - // Should have run for approximately 
MaxRetryTime (8s), not the dynamic default (10min). - if elapsed > 30*time.Second { - t.Errorf("expected retry to terminate within ~8s, but elapsed %v", elapsed) + // Should have run for approximately MaxRetryTime (1s), not the dynamic default (10min). + if elapsed > 5*time.Second { + t.Errorf("expected retry to terminate within ~1s, but elapsed %v", elapsed) } if atomic.LoadInt32(&callCount) < 2 { t.Errorf("expected at least 2 attempts, got %d", callCount) From f155ccb41a37299f5446e3e49ec4367223ed3ffa Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 23 Apr 2026 12:50:30 +0800 Subject: [PATCH 09/16] refactor(retrypolicy): fix off-by-one in computeBackoff logging retry-go's OnRetry callback supplies a 1-based retry attempt number, so computeBackoff(n+1, ...) logged a delay that was one doubling ahead of the actual backoff used by the retry loop. Pass n directly. Also drop the unreachable attempt == 0 branch in computeBackoff (retry-go never supplies 0) and skip that case in the test. Addresses review feedback on PR #468 (gemini). Signed-off-by: Zhao Chen --- pkg/retrypolicy/retrypolicy.go | 7 +++---- pkg/retrypolicy/retrypolicy_test.go | 6 ++---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pkg/retrypolicy/retrypolicy.go b/pkg/retrypolicy/retrypolicy.go index f43cce1c..54709744 100644 --- a/pkg/retrypolicy/retrypolicy.go +++ b/pkg/retrypolicy/retrypolicy.go @@ -130,7 +130,7 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er return retryable }), retry.OnRetry(func(n uint, err error) { - backoff := computeBackoff(n+1, initialDelay, maxBackoff) + backoff := computeBackoff(n, initialDelay, maxBackoff) elapsed := time.Since(startTime) log.WithFields(log.Fields{ @@ -173,10 +173,9 @@ func computeDynamicParams(fileSize int64) (time.Duration, time.Duration) { // computeBackoff estimates the backoff duration for display purposes. // It mirrors the exponential backoff calculation without jitter. 
+// The attempt parameter is the 1-based retry number as provided by +// retry-go's OnRetry callback, so it is always >= 1. func computeBackoff(attempt uint, initial, maxDelay time.Duration) time.Duration { - if attempt == 0 { - return initial - } backoff := time.Duration(float64(initial) * math.Pow(2, float64(attempt-1))) if backoff > maxDelay { backoff = maxDelay diff --git a/pkg/retrypolicy/retrypolicy_test.go b/pkg/retrypolicy/retrypolicy_test.go index eeabb44f..8fa9eb5c 100644 --- a/pkg/retrypolicy/retrypolicy_test.go +++ b/pkg/retrypolicy/retrypolicy_test.go @@ -523,10 +523,8 @@ func TestComputeBackoff(t *testing.T) { initial := 5 * time.Second maxDelay := 1 * time.Minute - // attempt 0 => initial - if got := computeBackoff(0, initial, maxDelay); got != initial { - t.Errorf("attempt 0: got %v, want %v", got, initial) - } + // retry-go's OnRetry always supplies a 1-based attempt number, so 0 is + // not a value the function is ever called with in production. Start from 1. // attempt 1 => 5s * 2^0 = 5s if got := computeBackoff(1, initial, maxDelay); got != 5*time.Second { From 514f604a32c831795ae68f39075ae497dda3f63a Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 23 Apr 2026 12:50:30 +0800 Subject: [PATCH 10/16] refactor(backend): extract getAnnotationFilepath helper The 'prefer modelspec.AnnotationFilepath, fall back to the legacy dragonflyoss key' pattern was duplicated across fetch.go, fetch_by_d7y.go, and pull_by_d7y.go (six call sites total). Centralize it in backend.getAnnotationFilepath so each caller is a one-liner and future changes to the annotation resolution live in one place. Addresses review feedback on PR #468 (gemini). 
Signed-off-by: Zhao Chen --- pkg/backend/annotation.go | 36 ++++++++++++++++++++++++++++++++++++ pkg/backend/fetch.go | 16 ++-------------- pkg/backend/fetch_by_d7y.go | 25 +++---------------------- pkg/backend/pull_by_d7y.go | 20 ++------------------ 4 files changed, 43 insertions(+), 54 deletions(-) create mode 100644 pkg/backend/annotation.go diff --git a/pkg/backend/annotation.go b/pkg/backend/annotation.go new file mode 100644 index 00000000..60f3d710 --- /dev/null +++ b/pkg/backend/annotation.go @@ -0,0 +1,36 @@ +/* + * Copyright 2025 The CNAI Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package backend + +import ( + legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" + modelspec "github.com/modelpack/model-spec/specs-go/v1" +) + +// getAnnotationFilepath returns the filepath stored on a descriptor's +// annotations, preferring the modelpack key and falling back to the legacy +// dragonflyoss key so older artifacts remain readable. Returns empty string +// when neither key is present. 
+func getAnnotationFilepath(annotations map[string]string) string { + if annotations == nil { + return "" + } + if path := annotations[modelspec.AnnotationFilepath]; path != "" { + return path + } + return annotations[legacymodelspec.AnnotationFilepath] +} diff --git a/pkg/backend/fetch.go b/pkg/backend/fetch.go index aebbe947..e62f5c67 100644 --- a/pkg/backend/fetch.go +++ b/pkg/backend/fetch.go @@ -25,8 +25,6 @@ import ( "time" "github.com/bmatcuk/doublestar/v4" - legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" - modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -78,10 +76,7 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e for _, layer := range manifest.Layers { for _, pattern := range cfg.Patterns { if anno := layer.Annotations; anno != nil { - path := anno[modelspec.AnnotationFilepath] - if path == "" { - path = anno[legacymodelspec.AnnotationFilepath] - } + path := getAnnotationFilepath(anno) // Use doublestar.PathMatch for pattern matching to support ** recursive matching // PathMatch uses the system's native path separator (like filepath.Match) while // also supporting recursive patterns like **/*.json @@ -120,14 +115,7 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e default: } - var annoFilepath string - if layer.Annotations != nil { - if layer.Annotations[modelspec.AnnotationFilepath] != "" { - annoFilepath = layer.Annotations[modelspec.AnnotationFilepath] - } else { - annoFilepath = layer.Annotations[legacymodelspec.AnnotationFilepath] - } - } + annoFilepath := getAnnotationFilepath(layer.Annotations) logrus.Debugf("fetch: processing layer %s", layer.Digest) if err := retrypolicy.Do(ctx, func(rctx context.Context) error { diff --git a/pkg/backend/fetch_by_d7y.go b/pkg/backend/fetch_by_d7y.go index 39a1248d..e6ff9433 100644 --- 
a/pkg/backend/fetch_by_d7y.go +++ b/pkg/backend/fetch_by_d7y.go @@ -31,8 +31,6 @@ import ( common "d7y.io/api/v2/pkg/apis/common/v2" dfdaemon "d7y.io/api/v2/pkg/apis/dfdaemon/v2" "github.com/bmatcuk/doublestar/v4" - legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" - modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -81,10 +79,7 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf for _, layer := range manifest.Layers { for _, pattern := range cfg.Patterns { if anno := layer.Annotations; anno != nil { - path := anno[modelspec.AnnotationFilepath] - if path == "" { - path = anno[legacymodelspec.AnnotationFilepath] - } + path := getAnnotationFilepath(anno) // Use doublestar.PathMatch for pattern matching to support ** recursive matching // PathMatch uses the system's native path separator (like filepath.Match) while // also supporting recursive patterns like **/*.json @@ -168,14 +163,7 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf // fetchLayerByDragonfly handles downloading and extracting a single layer via Dragonfly. 
func fetchLayerByDragonfly(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Fetch) error { - var annoFilepath string - if desc.Annotations != nil { - if desc.Annotations[modelspec.AnnotationFilepath] != "" { - annoFilepath = desc.Annotations[modelspec.AnnotationFilepath] - } else { - annoFilepath = desc.Annotations[legacymodelspec.AnnotationFilepath] - } - } + annoFilepath := getAnnotationFilepath(desc.Annotations) err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("fetch: processing layer %s", desc.Digest) @@ -217,14 +205,7 @@ func downloadAndExtractFetchLayer(ctx context.Context, pb *internalpb.ProgressBa return fmt.Errorf("failed to resolve output dir: %w", err) } - var annoFilepath string - if desc.Annotations != nil { - if desc.Annotations[modelspec.AnnotationFilepath] != "" { - annoFilepath = desc.Annotations[modelspec.AnnotationFilepath] - } else { - annoFilepath = desc.Annotations[legacymodelspec.AnnotationFilepath] - } - } + annoFilepath := getAnnotationFilepath(desc.Annotations) if annoFilepath == "" { return fmt.Errorf("missing annotation filepath") diff --git a/pkg/backend/pull_by_d7y.go b/pkg/backend/pull_by_d7y.go index 32782597..138be9ea 100644 --- a/pkg/backend/pull_by_d7y.go +++ b/pkg/backend/pull_by_d7y.go @@ -30,8 +30,6 @@ import ( common "d7y.io/api/v2/pkg/apis/common/v2" dfdaemon "d7y.io/api/v2/pkg/apis/dfdaemon/v2" - legacymodelspec "github.com/dragonflyoss/model-spec/specs-go/v1" - modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -192,14 +190,7 @@ func buildBlobURL(ref Referencer, plainHTTP bool, digest string) string { // processLayer handles downloading and extracting a single layer. 
func processLayer(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Pull) error { - var annoFilepath string - if desc.Annotations != nil { - if desc.Annotations[modelspec.AnnotationFilepath] != "" { - annoFilepath = desc.Annotations[modelspec.AnnotationFilepath] - } else { - annoFilepath = desc.Annotations[legacymodelspec.AnnotationFilepath] - } - } + annoFilepath := getAnnotationFilepath(desc.Annotations) err := retrypolicy.Do(ctx, func(rctx context.Context) error { logrus.Debugf("pull: processing layer %s", desc.Digest) @@ -236,14 +227,7 @@ func downloadAndExtractLayer(ctx context.Context, pb *internalpb.ProgressBar, cl return fmt.Errorf("failed to resolve extract dir: %w", err) } - var annoFilepath string - if desc.Annotations != nil { - if desc.Annotations[modelspec.AnnotationFilepath] != "" { - annoFilepath = desc.Annotations[modelspec.AnnotationFilepath] - } else { - annoFilepath = desc.Annotations[legacymodelspec.AnnotationFilepath] - } - } + annoFilepath := getAnnotationFilepath(desc.Annotations) if annoFilepath == "" { return fmt.Errorf("missing annotation filepath") From 973f5084f6f0e6fc96a612bdd33abedb1e7e42b4 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 23 Apr 2026 12:50:30 +0800 Subject: [PATCH 11/16] fix(backend): surface errgroup cancellation and use Placeholder on retry Three call sites (push, pull, processor) used 'g.Wait()' without capturing its return, so a cancelled worker's ctx.Err() could be discarded. In the edge case where all running uploads/layers succeed but queued-but-unstarted workers exit via the ctx.Done() select, errs stayed empty and the operation continued to config/manifest push (or finalized a partial descriptor set), potentially publishing an incomplete artifact. Capture the Wait() result into errs and rely on the existing ctx.Err()/len(errs) checks to return cleanly. 
Also change processor/base.go's select from 'return nil' to 'return ctx.Err()' so the same propagation path applies there. Additionally, align push.go's OnRetry with pull.go/fetch.go by using pb.Placeholder (designed to reset an existing bar's message and progress) instead of pb.Add for layer, config, and manifest retries. Addresses review feedback on PR #468 (codex P1, gemini). Signed-off-by: Zhao Chen --- pkg/backend/processor/base.go | 10 ++++++++-- pkg/backend/pull.go | 8 +++++++- pkg/backend/push.go | 15 +++++++++++---- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/pkg/backend/processor/base.go b/pkg/backend/processor/base.go index e62863bc..87549011 100644 --- a/pkg/backend/processor/base.go +++ b/pkg/backend/processor/base.go @@ -135,7 +135,7 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin eg.Go(func() error { select { case <-ctx.Done(): - return nil + return ctx.Err() default: } @@ -189,7 +189,13 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin }) } - _ = eg.Wait() + if werr := eg.Wait(); werr != nil { + // Surface cancellation from skipped workers so a partially-built + // artifact cannot be emitted as if it were complete. + mu.Lock() + errs = append(errs, werr) + mu.Unlock() + } if ctx.Err() != nil { return nil, fmt.Errorf("processing cancelled: %w", ctx.Err()) diff --git a/pkg/backend/pull.go b/pkg/backend/pull.go index 2038f99c..3666cf27 100644 --- a/pkg/backend/pull.go +++ b/pkg/backend/pull.go @@ -149,7 +149,13 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err }) } - _ = g.Wait() + if werr := g.Wait(); werr != nil { + // Surface cancellation from worker goroutines so a cancelled batch + // never slips through as an apparently successful pull. 
+ mu.Lock() + errs = append(errs, werr) + mu.Unlock() + } if ctx.Err() != nil { return fmt.Errorf("pull cancelled: %w", ctx.Err()) } diff --git a/pkg/backend/push.go b/pkg/backend/push.go index d227f79c..7ea09adb 100644 --- a/pkg/backend/push.go +++ b/pkg/backend/push.go @@ -108,7 +108,7 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err OnRetry: func(attempt uint, reason string, backoff time.Duration) { prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", internalpb.NormalizePrompt("Copying blob"), attempt, reason, backoff.Truncate(time.Second)) - pb.Add(prompt, layer.Digest.String(), layer.Size, nil) + pb.Placeholder(layer.Digest.String(), prompt, layer.Size) }, }); err != nil { mu.Lock() @@ -119,7 +119,14 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err }) } - g.Wait() + if werr := g.Wait(); werr != nil { + // Surface cancellation returned from worker goroutines so we never + // fall through to the config/manifest push with an incomplete set + // of layer uploads. 
+ mu.Lock() + errs = append(errs, werr) + mu.Unlock() + } if ctx.Err() != nil { return fmt.Errorf("push cancelled: %w", ctx.Err()) } @@ -137,7 +144,7 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err OnRetry: func(attempt uint, reason string, backoff time.Duration) { prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", internalpb.NormalizePrompt("Copying config"), attempt, reason, backoff.Truncate(time.Second)) - pb.Add(prompt, manifest.Config.Digest.String(), manifest.Config.Size, nil) + pb.Placeholder(manifest.Config.Digest.String(), prompt, manifest.Config.Size) }, }); err != nil { return fmt.Errorf("failed to push config to remote: %w", err) @@ -159,7 +166,7 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err OnRetry: func(attempt uint, reason string, backoff time.Duration) { prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", internalpb.NormalizePrompt("Copying manifest"), attempt, reason, backoff.Truncate(time.Second)) - pb.Add(prompt, manifestDesc.Digest.String(), manifestDesc.Size, nil) + pb.Placeholder(manifestDesc.Digest.String(), prompt, manifestDesc.Size) }, }); err != nil { return fmt.Errorf("failed to push manifest to remote: %w", err) From fc3ac13156c30620e9d16952e4ca88a131680b56 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Wed, 6 May 2026 16:32:01 +0800 Subject: [PATCH 12/16] feat(retrypolicy): decouple per-attempt timeout from retry budget MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous design used a single MaxRetryTime budget to bound both the in-flight transfer and the inter-attempt sleeps. With a wall-clock that scales with file size, a slow first attempt can consume the whole budget, leaving no room for retries — exactly when retries matter most. Split into two independent knobs: * PerAttemptTimeout: derived from file size assuming a 10 MiB/s minimum throughput with a 2x safety factor, clamped to [5min, 8h]. 
Each attempt gets its own context.WithTimeout, cancelled per attempt. * MaxAttempts + MaxBackoff: bound retry-only behavior. Defaults are constants (6 attempts, 2min backoff cap) and do not scale with file size — transient-failure recovery time is payload-independent. The retry loop classifies a per-attempt DeadlineExceeded under a live parent context as retryable, so a single transfer timeout no longer short-circuits the loop. User cancellation (parent ctx) still aborts immediately. CLI: --retry-max-time removed (semantics conflated transfer and retry) --retry-attempts new (int, total attempts including initial) --per-attempt-timeout new (Duration; 0 = derive from size, <0 = off) --no-retry unchanged Tests cover the size→timeout mapping with clamps, the per-attempt deadline retry loop, parent-cancel propagation, and a size-invariance test that pins the design's core property: total retry wall-clock does not depend on file size. Signed-off-by: Zhao Chen --- cmd/build.go | 3 +- cmd/fetch.go | 3 +- cmd/pull.go | 3 +- cmd/push.go | 3 +- pkg/retrypolicy/retrypolicy.go | 311 +++++++----- pkg/retrypolicy/retrypolicy_test.go | 715 +++++++++++++--------------- 6 files changed, 527 insertions(+), 511 deletions(-) diff --git a/cmd/build.go b/cmd/build.go index e4284fee..0eb8cbc5 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -62,7 +62,8 @@ func init() { flags.BoolVar(&buildConfig.Reasoning, "reasoning", false, "turning on this flag will mark this model as reasoning model in the config") flags.BoolVar(&buildConfig.NoCreationTime, "no-creation-time", false, "turning on this flag will not set createdAt in the config, which will be helpful for repeated builds") flags.BoolVar(&buildConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") - flags.DurationVar(&buildConfig.RetryConfig.MaxRetryTime, "retry-max-time", 0, "Max total retry time per file (0 = dynamic based on file size)") + flags.IntVar(&buildConfig.RetryConfig.MaxAttempts, "retry-attempts", 
0, "Max total attempts per file (initial + retries; 0 = use default of 6)") + flags.DurationVar(&buildConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind build flags to viper: %w", err)) diff --git a/cmd/fetch.go b/cmd/fetch.go index 3a2cf661..ce589c48 100644 --- a/cmd/fetch.go +++ b/cmd/fetch.go @@ -56,7 +56,8 @@ func init() { flags.StringSliceVar(&fetchConfig.Patterns, "patterns", []string{}, "specify the patterns for fetching the model artifact") flags.StringVar(&fetchConfig.DragonflyEndpoint, "dragonfly-endpoint", "", "specify the dragonfly endpoint for the pull operation, which will download and hardlink the blob by dragonfly GRPC service.") flags.BoolVar(&fetchConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") - flags.DurationVar(&fetchConfig.RetryConfig.MaxRetryTime, "retry-max-time", 0, "Max total retry time per file (0 = dynamic based on file size)") + flags.IntVar(&fetchConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6)") + flags.DurationVar(&fetchConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind fetch flags to viper: %w", err)) diff --git a/cmd/pull.go b/cmd/pull.go index f6912c9d..0d70e43e 100644 --- a/cmd/pull.go +++ b/cmd/pull.go @@ -56,7 +56,8 @@ func init() { flags.BoolVar(&pullConfig.ExtractFromRemote, "extract-from-remote", false, "turning on this flag will pull and extract the data from remote registry and no longer store model artifact locally, so user must specify extract-dir as the output directory") flags.StringVar(&pullConfig.DragonflyEndpoint, "dragonfly-endpoint", "", "specify the dragonfly endpoint for 
the pull operation, which will download and hardlink the blob by dragonfly GRPC service, this mode requires extract-from-remote must be true") flags.BoolVar(&pullConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") - flags.DurationVar(&pullConfig.RetryConfig.MaxRetryTime, "retry-max-time", 0, "Max total retry time per file (0 = dynamic based on file size)") + flags.IntVar(&pullConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6)") + flags.DurationVar(&pullConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind pull flags to viper: %w", err)) diff --git a/cmd/push.go b/cmd/push.go index 4f5f5961..936dd5a4 100644 --- a/cmd/push.go +++ b/cmd/push.go @@ -54,7 +54,8 @@ func init() { flags.BoolVar(&pushConfig.Nydusify, "nydusify", false, "[EXPERIMENTAL] nydusify the model artifact") flags.MarkHidden("nydusify") flags.BoolVar(&pushConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") - flags.DurationVar(&pushConfig.RetryConfig.MaxRetryTime, "retry-max-time", 0, "Max total retry time per file (0 = dynamic based on file size)") + flags.IntVar(&pushConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6)") + flags.DurationVar(&pushConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind push flags to viper: %w", err)) diff --git a/pkg/retrypolicy/retrypolicy.go b/pkg/retrypolicy/retrypolicy.go index 54709744..ded9b7b1 100644 --- a/pkg/retrypolicy/retrypolicy.go +++ b/pkg/retrypolicy/retrypolicy.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 The CNAI Authors + * 
Copyright 2024 The ModelPack Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,22 @@ * limitations under the License. */ +// Package retrypolicy provides retry behavior for blob transfer operations. +// +// The package decouples two timing concerns that are commonly conflated: +// +// - Per-attempt timeout: how long a single transfer attempt may take. +// This scales with file size, since larger files need longer to transfer. +// +// - Retry policy: how many attempts to make and how long to wait between +// them. This is bounded by attempt count and per-sleep cap; it does not +// scale with file size, because transient-failure recovery time is +// independent of payload size. +// +// Earlier designs used a single wall-clock budget covering both transfer +// time and retry waits, which made retries scarce precisely when networks +// were slow. The current design fixes that by giving each concern its own +// knob. package retrypolicy import ( @@ -30,95 +46,165 @@ import ( ) const ( - oneGB = 1 << 30 // 1 GiB in bytes - tenGB = 10 << 30 - nineGB = tenGB - oneGB - - minMaxRetryTime = 10 * time.Minute - maxMaxRetryTime = 60 * time.Minute - - minMaxBackoff = 1 * time.Minute - maxMaxBackoff = 10 * time.Minute - - initialDelay = 5 * time.Second - maxJitter = 5 * time.Second - - // maxBackoff cap when derived from user-specified MaxRetryTime - absoluteMaxBackoff = 10 * time.Minute + // DefaultMaxAttempts is the total number of attempts (initial + retries). + DefaultMaxAttempts = 6 + + // DefaultInitialDelay is the first sleep between attempts. + DefaultInitialDelay = 5 * time.Second + + // DefaultMaxBackoff caps a single sleep between attempts. It does not + // scale with file size: transient outages have a payload-independent + // duration distribution. 
+ DefaultMaxBackoff = 2 * time.Minute + + // DefaultMaxJitter is the upper bound on randomized jitter added to each + // sleep. + DefaultMaxJitter = 5 * time.Second + + // minThroughput is the assumed worst-case usable throughput when + // computing per-attempt timeouts. Networks slower than this are out of + // scope; users on such links should set Config.PerAttemptTimeout + // explicitly. + minThroughput = 10 * (1 << 20) // 10 MiB/s + + // safetyFactor multiplies the ideal transfer time when sizing + // per-attempt timeouts, leaving headroom for protocol overhead and + // short-lived speed dips. + safetyFactor = 2 + + // minPerAttemptTimeout is the floor for derived per-attempt timeouts — + // small blobs (manifests, configs) still need enough time for TLS, + // auth, and slow-start. + minPerAttemptTimeout = 5 * time.Minute + + // maxPerAttemptTimeout is the ceiling. Very large files (e.g. 140GB+ + // LLM shards) will still hit this cap, after which the user should + // override via Config.PerAttemptTimeout. + maxPerAttemptTimeout = 8 * time.Hour ) // Config holds user-configurable retry parameters from CLI flags. +// +// The zero value is valid and yields production defaults: 6 attempts, with +// per-attempt timeout derived from file size and exponential backoff up to +// DefaultMaxBackoff. type Config struct { - MaxRetryTime time.Duration // 0 = dynamic based on file size - NoRetry bool // disable retry entirely - InitialDelay time.Duration // 0 = use default (5s), for testing - MaxJitter time.Duration // -1 = no jitter, 0 = use default (5s), for testing + // MaxAttempts is the total number of attempts (initial + retries). + // 0 means "use DefaultMaxAttempts". Use NoRetry to call fn exactly + // once. + MaxAttempts int + + // PerAttemptTimeout is the maximum duration for a single attempt. + // 0 → derive from file size (see ComputePerAttemptTimeout). + // <0 → no per-attempt timeout (caller fully controls deadlines). + // >0 → use this value verbatim. 
+ PerAttemptTimeout time.Duration + + // NoRetry disables retry entirely; fn is called once with the parent + // context. + NoRetry bool + + // InitialDelay overrides the first inter-attempt sleep. 0 = default. + // Primarily for tests. + InitialDelay time.Duration + + // MaxBackoff overrides the per-sleep cap. 0 = default. + MaxBackoff time.Duration + + // MaxJitter: -1 = no jitter, 0 = default, >0 = override. For tests. + MaxJitter time.Duration } // DoOpts configures a single Do call. type DoOpts struct { - FileSize int64 // for dynamic parameter calculation - FileName string // for logging - Config *Config - OnRetry func(attempt uint, reason string, backoff time.Duration) + // FileSize sizes the per-attempt timeout when Config.PerAttemptTimeout + // is unset. May be 0 for non-blob operations (manifest, config); the + // timeout will then clamp to minPerAttemptTimeout. + FileSize int64 + + // FileName is logged on each retry. + FileName string + + // Config is the user-supplied policy. nil means defaults. + Config *Config + + // OnRetry is invoked before each sleep (after a failed attempt) so + // callers can update progress UI. attempt is 1-based; reason is a + // short label from ShortReason. + OnRetry func(attempt uint, reason string, backoff time.Duration) } -// Do executes fn with retry. It computes dynamic retry parameters from fileSize, -// creates an internal deadline context (and defers its cancel to prevent leak), -// sets up retry logging, and calls retry.Do. -// The parent ctx is only used for user-initiated cancellation. +// Do executes fn with retry. Each attempt runs under its own deadline +// derived from PerAttemptTimeout (or file size); retries are bounded by +// MaxAttempts and exponential backoff capped at MaxBackoff. The parent ctx +// is honored for user-initiated cancellation only — its expiry is not +// coupled to per-attempt transfer time. 
func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) error { cfg := opts.Config if cfg == nil { cfg = &Config{} } - // NoRetry: call fn once with the parent context, return the result. if cfg.NoRetry { return fn(ctx) } - maxRetryTime, maxBackoff := computeDynamicParams(opts.FileSize) - - // Override with user-specified MaxRetryTime if set. - if cfg.MaxRetryTime > 0 { - maxRetryTime = cfg.MaxRetryTime - maxBackoff = cfg.MaxRetryTime / 6 - if maxBackoff > absoluteMaxBackoff { - maxBackoff = absoluteMaxBackoff - } + maxAttempts := cfg.MaxAttempts + if maxAttempts <= 0 { + maxAttempts = DefaultMaxAttempts } - startTime := time.Now() - deadlineCtx, deadlineCancel := context.WithDeadline(ctx, startTime.Add(maxRetryTime)) - defer deadlineCancel() + perAttemptTimeout := cfg.PerAttemptTimeout + switch { + case perAttemptTimeout == 0: + perAttemptTimeout = ComputePerAttemptTimeout(opts.FileSize) + case perAttemptTimeout < 0: + perAttemptTimeout = 0 // disabled + } - sizeStr := humanizeBytes(opts.FileSize) + initialDelay := cfg.InitialDelay + if initialDelay <= 0 { + initialDelay = DefaultInitialDelay + } - delay := initialDelay - jitter := maxJitter - if cfg.InitialDelay > 0 { - delay = cfg.InitialDelay + maxBackoff := cfg.MaxBackoff + if maxBackoff <= 0 { + maxBackoff = DefaultMaxBackoff } + + jitter := DefaultMaxJitter if cfg.MaxJitter < 0 { jitter = 0 } else if cfg.MaxJitter > 0 { jitter = cfg.MaxJitter } + sizeStr := humanizeBytes(opts.FileSize) + startTime := time.Now() + return retry.Do( func() error { - return fn(deadlineCtx) + attemptCtx := ctx + if perAttemptTimeout > 0 { + var cancel context.CancelFunc + attemptCtx, cancel = context.WithTimeout(ctx, perAttemptTimeout) + defer cancel() + } + return fn(attemptCtx) }, - retry.Attempts(0), - retry.Context(deadlineCtx), + retry.Attempts(uint(maxAttempts)), + retry.Context(ctx), retry.DelayType(retry.BackOffDelay), - retry.Delay(delay), + retry.Delay(initialDelay), retry.MaxDelay(maxBackoff), 
retry.MaxJitter(jitter), retry.LastErrorOnly(true), - retry.WrapContextErrorWithLastError(true), retry.RetryIf(func(err error) bool { + // Per-attempt timeout fired but parent ctx is alive: this is a + // transient transfer timeout, not a user cancellation. Retry. + if errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil { + return true + } retryable := IsRetryable(err) if !retryable { log.WithFields(log.Fields{ @@ -130,52 +216,67 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er return retryable }), retry.OnRetry(func(n uint, err error) { - backoff := computeBackoff(n, initialDelay, maxBackoff) + // retry-go calls OnRetry with n = 0-based retry index. Convert + // to 1-based for both the log and the user callback. + attempt := n + 1 + backoff := computeBackoff(attempt, initialDelay, maxBackoff) elapsed := time.Since(startTime) log.WithFields(log.Fields{ - "file": opts.FileName, - "size": sizeStr, - "error": err.Error(), - "max_retry_time": maxRetryTime.String(), - "max_backoff": maxBackoff.String(), - "next_retry_in": backoff.Truncate(time.Second).String(), - "elapsed": fmt.Sprintf("%s / %s", elapsed.Truncate(time.Second), maxRetryTime), - }).Warnf("[RETRY] attempt %d for %q (%s)", n+1, opts.FileName, sizeStr) + "file": opts.FileName, + "size": sizeStr, + "error": err.Error(), + "max_attempts": maxAttempts, + "max_backoff": maxBackoff.String(), + "per_attempt_to": perAttemptTimeout.String(), + "next_retry_in": backoff.Truncate(time.Second).String(), + "elapsed": elapsed.Truncate(time.Second).String(), + }).Warnf("[RETRY] attempt %d/%d for %q (%s)", attempt, maxAttempts, opts.FileName, sizeStr) if opts.OnRetry != nil { reason := ShortReason(err) - opts.OnRetry(n+1, reason, backoff) + opts.OnRetry(attempt, reason, backoff) } }), ) } -// computeDynamicParams calculates maxRetryTime and maxBackoff based on file size. 
+// ComputePerAttemptTimeout estimates a single-attempt transfer deadline from +// file size, assuming minThroughput as a worst-case-but-usable rate and +// applying safetyFactor for headroom. The result is clamped to +// [minPerAttemptTimeout, maxPerAttemptTimeout]. +// +// Examples (rounded): // -// For files <= 1 GB: maxRetryTime=10min, maxBackoff=1min -// For files >= 10 GB: maxRetryTime=60min, maxBackoff=10min -// Linear interpolation between. -func computeDynamicParams(fileSize int64) (time.Duration, time.Duration) { - ratio := float64(fileSize-oneGB) / float64(nineGB) - if ratio < 0 { - ratio = 0 +// 1 GB → 5 min (floor) +// 10 GB → 34 min +// 70 GB → ~4 h +// 140 GB → ~8 h (ceiling) +// +// fileSize <= 0 returns minPerAttemptTimeout. +func ComputePerAttemptTimeout(fileSize int64) time.Duration { + if fileSize <= 0 { + return minPerAttemptTimeout } - if ratio > 1 { - ratio = 1 + secs := float64(fileSize) / float64(minThroughput) * safetyFactor + t := time.Duration(secs * float64(time.Second)) + if t < minPerAttemptTimeout { + return minPerAttemptTimeout } - - maxRetryTime := minMaxRetryTime + time.Duration(ratio*float64(maxMaxRetryTime-minMaxRetryTime)) - maxBackoff := minMaxBackoff + time.Duration(ratio*float64(maxMaxBackoff-minMaxBackoff)) - - return maxRetryTime, maxBackoff + if t > maxPerAttemptTimeout { + return maxPerAttemptTimeout + } + return t } // computeBackoff estimates the backoff duration for display purposes. -// It mirrors the exponential backoff calculation without jitter. -// The attempt parameter is the 1-based retry number as provided by -// retry-go's OnRetry callback, so it is always >= 1. +// It mirrors retry-go's exponential schedule (without jitter). +// attempt is 1-based: the first sleep (after attempt 1 fails) is +// initialDelay, the second is 2*initialDelay, capped at maxDelay. 
func computeBackoff(attempt uint, initial, maxDelay time.Duration) time.Duration { + if attempt == 0 { + return initial + } backoff := time.Duration(float64(initial) * math.Pow(2, float64(attempt-1))) if backoff > maxDelay { backoff = maxDelay @@ -187,51 +288,41 @@ func computeBackoff(attempt uint, initial, maxDelay time.Duration) time.Duration var httpStatusPattern = regexp.MustCompile(`response status code (\d{3})`) // IsRetryable returns true for transient errors that warrant a retry. +// +// context.Canceled and bare context.DeadlineExceeded are not retryable here: +// the Do loop independently re-classifies a per-attempt timeout while the +// parent context is still alive as retryable, so this function only sees +// genuine cancellation. func IsRetryable(err error) bool { if err == nil { return false } - // context.Canceled and context.DeadlineExceeded are never retryable. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } errMsg := err.Error() - // Check for HTTP status codes embedded in error messages (ORAS style). if matches := httpStatusPattern.FindStringSubmatch(errMsg); len(matches) == 2 { code := matches[1] - // 5xx server errors are retryable. if code[0] == '5' { return true } - // 408 (Request Timeout) and 429 (Too Many Requests) are retryable. if code == "408" || code == "429" { return true } - // Other 4xx are not retryable (401, 403, 404, etc.) return false } - // Network-level transient errors. 
- if strings.Contains(errMsg, "i/o timeout") { - return true - } - if strings.Contains(errMsg, "connection reset by peer") { - return true - } - if strings.Contains(errMsg, "connection refused") { - return true - } - if strings.Contains(errMsg, "broken pipe") { - return true - } - if strings.Contains(errMsg, "EOF") { + if strings.Contains(errMsg, "i/o timeout") || + strings.Contains(errMsg, "connection reset by peer") || + strings.Contains(errMsg, "connection refused") || + strings.Contains(errMsg, "broken pipe") || + strings.Contains(errMsg, "EOF") { return true } - // Local / permanent errors are not retryable. if strings.Contains(errMsg, "permission denied") || strings.Contains(errMsg, "no space left on device") || strings.Contains(errMsg, "file exists") || @@ -242,12 +333,12 @@ func IsRetryable(err error) bool { return false } - // Unknown errors default to retryable with a warning. log.WithField("error", errMsg).Warn("[RETRY] unknown error treated as retryable") return true } -// ShortReason extracts a brief human-readable label from an error for progress bar display. +// ShortReason extracts a brief human-readable label from an error for +// progress bar display. func ShortReason(err error) string { if err == nil { return "" @@ -255,25 +346,23 @@ func ShortReason(err error) string { errMsg := err.Error() - // Check for HTTP status codes. 
if matches := httpStatusPattern.FindStringSubmatch(errMsg); len(matches) == 2 { return "HTTP " + matches[1] } - if strings.Contains(errMsg, "i/o timeout") { + switch { + case strings.Contains(errMsg, "i/o timeout"): return "i/o timeout" - } - if strings.Contains(errMsg, "connection reset by peer") { + case strings.Contains(errMsg, "connection reset by peer"): return "conn reset" - } - if strings.Contains(errMsg, "connection refused") { + case strings.Contains(errMsg, "connection refused"): return "conn refused" - } - if strings.Contains(errMsg, "broken pipe") { + case strings.Contains(errMsg, "broken pipe"): return "broken pipe" - } - if strings.Contains(errMsg, "EOF") { + case strings.Contains(errMsg, "EOF"): return "EOF" + case errors.Is(err, context.DeadlineExceeded): + return "attempt timeout" } return "unknown error" diff --git a/pkg/retrypolicy/retrypolicy_test.go b/pkg/retrypolicy/retrypolicy_test.go index 8fa9eb5c..c940261a 100644 --- a/pkg/retrypolicy/retrypolicy_test.go +++ b/pkg/retrypolicy/retrypolicy_test.go @@ -1,5 +1,5 @@ /* - * Copyright 2024 The CNAI Authors + * Copyright 2024 The ModelPack Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,96 +20,48 @@ import ( "context" "errors" "fmt" - "net" "sync/atomic" "testing" "time" ) -// --- helpers for tests --- +// --- ComputePerAttemptTimeout --- -// timeoutError implements net.Error with Timeout() returning true. 
-type timeoutError struct { - msg string -} - -func (e *timeoutError) Error() string { return e.msg } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - -// --- computeDynamicParams tests --- - -func TestComputeDynamicParams(t *testing.T) { +func TestComputePerAttemptTimeout(t *testing.T) { + const ( + oneMB = int64(1) << 20 + oneGB = int64(1) << 30 + tenGB = int64(10) << 30 + hundGB = int64(100) << 30 + ) tests := []struct { - name string - fileSize int64 - wantRetryTime time.Duration - wantMaxBackoff time.Duration + name string + size int64 + want time.Duration }{ - { - name: "zero bytes - clamped to minimum", - fileSize: 0, - wantRetryTime: 10 * time.Minute, - wantMaxBackoff: 1 * time.Minute, - }, - { - name: "500 MB - below 1 GB, clamped to minimum", - fileSize: 500 * 1024 * 1024, - wantRetryTime: 10 * time.Minute, - wantMaxBackoff: 1 * time.Minute, - }, - { - name: "1 GB - boundary, ratio=0, minimum values", - fileSize: 1 << 30, - wantRetryTime: 10 * time.Minute, - wantMaxBackoff: 1 * time.Minute, - }, - { - name: "5.5 GB - midpoint, interpolated values", - fileSize: int64(5.5 * float64(1<<30)), - wantRetryTime: 35 * time.Minute, - wantMaxBackoff: 5*time.Minute + 30*time.Second, - }, - { - name: "10 GB - boundary, ratio=1, maximum values", - fileSize: 10 << 30, - wantRetryTime: 60 * time.Minute, - wantMaxBackoff: 10 * time.Minute, - }, - { - name: "20 GB - above 10 GB, clamped to maximum", - fileSize: 20 << 30, - wantRetryTime: 60 * time.Minute, - wantMaxBackoff: 10 * time.Minute, - }, + {"zero size clamps to floor", 0, minPerAttemptTimeout}, + {"100 MB clamps to floor", 100 * oneMB, minPerAttemptTimeout}, + {"1 GB clamps to floor", oneGB, minPerAttemptTimeout}, + // 5 GB / 10 MB/s * 2 = 1024s ≈ 17 min, above the 5min floor + {"5 GB scales above floor", 5 * oneGB, time.Duration(5*oneGB/minThroughput*safetyFactor) * time.Second}, + // 10 GB / 10 MB/s * 2 = 2048s ≈ 34 min + {"10 GB scales linearly", 
tenGB, time.Duration(tenGB/minThroughput*safetyFactor) * time.Second}, + // 100 GB / 10 MB/s * 2 = 20480s ≈ 5.7h, still under the 8h ceiling + {"100 GB still under ceiling", hundGB, time.Duration(hundGB/minThroughput*safetyFactor) * time.Second}, + // 200 GB hits ceiling + {"200 GB clamps to ceiling", 2 * hundGB, maxPerAttemptTimeout}, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotRetryTime, gotMaxBackoff := computeDynamicParams(tt.fileSize) - - // Allow 1 second tolerance for floating point interpolation. - retryTimeDiff := absDuration(gotRetryTime - tt.wantRetryTime) - if retryTimeDiff > time.Second { - t.Errorf("maxRetryTime = %v, want %v (diff %v)", gotRetryTime, tt.wantRetryTime, retryTimeDiff) - } - - backoffDiff := absDuration(gotMaxBackoff - tt.wantMaxBackoff) - if backoffDiff > time.Second { - t.Errorf("maxBackoff = %v, want %v (diff %v)", gotMaxBackoff, tt.wantMaxBackoff, backoffDiff) + got := ComputePerAttemptTimeout(tt.size) + if got != tt.want { + t.Errorf("ComputePerAttemptTimeout(%d) = %v, want %v", tt.size, got, tt.want) } }) } } -func absDuration(d time.Duration) time.Duration { - if d < 0 { - return -d - } - return d -} - -// --- IsRetryable tests --- +// --- IsRetryable --- func TestIsRetryable(t *testing.T) { tests := []struct { @@ -117,138 +69,37 @@ func TestIsRetryable(t *testing.T) { err error want bool }{ - { - name: "nil error", - err: nil, - want: false, - }, - { - name: "context.Canceled", - err: context.Canceled, - want: false, - }, - { - name: "wrapped context.Canceled", - err: fmt.Errorf("operation failed: %w", context.Canceled), - want: false, - }, - { - name: "context.DeadlineExceeded", - err: context.DeadlineExceeded, - want: false, - }, - { - name: "wrapped context.DeadlineExceeded", - err: fmt.Errorf("operation timed out: %w", context.DeadlineExceeded), - want: false, - }, - { - name: "HTTP 500 server error", - err: fmt.Errorf("PUT /blobs/uploads: response status code 500: internal server error"), - want: 
true, - }, - { - name: "HTTP 502 bad gateway", - err: fmt.Errorf("response status code 502: bad gateway"), - want: true, - }, - { - name: "HTTP 503 service unavailable", - err: fmt.Errorf("response status code 503: service unavailable"), - want: true, - }, - { - name: "HTTP 408 request timeout", - err: fmt.Errorf("response status code 408: request timeout"), - want: true, - }, - { - name: "HTTP 429 too many requests", - err: fmt.Errorf("response status code 429: too many requests"), - want: true, - }, - { - name: "HTTP 401 unauthorized - not retryable", - err: fmt.Errorf("response status code 401: unauthorized"), - want: false, - }, - { - name: "HTTP 403 forbidden - not retryable", - err: fmt.Errorf("response status code 403: access denied"), - want: false, - }, - { - name: "HTTP 404 not found - not retryable", - err: fmt.Errorf("response status code 404: not found"), - want: false, - }, - { - name: "i/o timeout", - err: &net.OpError{ - Op: "read", - Net: "tcp", - Err: &timeoutError{msg: "i/o timeout"}, - }, - want: true, - }, - { - name: "i/o timeout in wrapped error message", - err: fmt.Errorf("read tcp 10.0.0.1:1234->10.0.0.2:443: i/o timeout"), - want: true, - }, - { - name: "connection reset by peer", - err: fmt.Errorf("read tcp: connection reset by peer"), - want: true, - }, - { - name: "connection refused", - err: fmt.Errorf("dial tcp 10.0.0.1:443: connection refused"), - want: true, - }, - { - name: "broken pipe", - err: fmt.Errorf("write tcp: broken pipe"), - want: true, - }, - { - name: "EOF", - err: fmt.Errorf("unexpected EOF"), - want: true, - }, - { - name: "permission denied - not retryable", - err: fmt.Errorf("open /data/model.bin: permission denied"), - want: false, - }, - { - name: "no space left on device - not retryable", - err: fmt.Errorf("write /data/model.bin: no space left on device"), - want: false, - }, - { - name: "no such file or directory - not retryable", - err: fmt.Errorf("open /data/model.bin: no such file or directory"), - want: 
false, - }, - { - name: "unknown error - defaults to retryable", - err: errors.New("something totally unexpected happened"), - want: true, - }, + {"nil", nil, false}, + {"context.Canceled", context.Canceled, false}, + {"context.DeadlineExceeded", context.DeadlineExceeded, false}, + {"5xx server error", errors.New("response status code 500"), true}, + {"503 service unavailable", errors.New("response status code 503: Service Unavailable"), true}, + {"408 request timeout", errors.New("response status code 408"), true}, + {"429 too many requests", errors.New("response status code 429"), true}, + {"401 unauthorized", errors.New("response status code 401"), false}, + {"403 forbidden", errors.New("response status code 403"), false}, + {"404 not found", errors.New("response status code 404"), false}, + {"i/o timeout", errors.New("dial tcp: i/o timeout"), true}, + {"connection reset", errors.New("read tcp: connection reset by peer"), true}, + {"connection refused", errors.New("dial tcp: connection refused"), true}, + {"broken pipe", errors.New("write tcp: broken pipe"), true}, + {"EOF", errors.New("unexpected EOF"), true}, + {"permission denied", errors.New("open /etc/foo: permission denied"), false}, + {"no space left", errors.New("write /tmp/x: no space left on device"), false}, + {"file exists", errors.New("link /a /b: file exists"), false}, + {"no such file", errors.New("open /no/such: no such file or directory"), false}, + {"unknown defaults retryable", errors.New("some weird error"), true}, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := IsRetryable(tt.err) - if got != tt.want { + if got := IsRetryable(tt.err); got != tt.want { t.Errorf("IsRetryable(%v) = %v, want %v", tt.err, got, tt.want) } }) } } -// --- ShortReason tests --- +// --- ShortReason --- func TestShortReason(t *testing.T) { tests := []struct { @@ -256,293 +107,365 @@ func TestShortReason(t *testing.T) { err error want string }{ - { - name: "nil error", - err: nil, - want: "", 
- }, - { - name: "HTTP 500", - err: fmt.Errorf("response status code 500: internal server error"), - want: "HTTP 500", - }, - { - name: "HTTP 502", - err: fmt.Errorf("response status code 502: bad gateway"), - want: "HTTP 502", - }, - { - name: "HTTP 429", - err: fmt.Errorf("response status code 429: too many requests"), - want: "HTTP 429", - }, - { - name: "HTTP 408", - err: fmt.Errorf("response status code 408: request timeout"), - want: "HTTP 408", - }, - { - name: "i/o timeout", - err: fmt.Errorf("read tcp: i/o timeout"), - want: "i/o timeout", - }, - { - name: "connection reset", - err: fmt.Errorf("read tcp: connection reset by peer"), - want: "conn reset", - }, - { - name: "connection refused", - err: fmt.Errorf("dial tcp: connection refused"), - want: "conn refused", - }, - { - name: "broken pipe", - err: fmt.Errorf("write tcp: broken pipe"), - want: "broken pipe", - }, - { - name: "EOF", - err: fmt.Errorf("unexpected EOF"), - want: "EOF", - }, - { - name: "unknown error", - err: errors.New("some weird error"), - want: "unknown error", - }, + {"nil", nil, ""}, + {"5xx", errors.New("response status code 503"), "HTTP 503"}, + {"i/o timeout", errors.New("dial tcp: i/o timeout"), "i/o timeout"}, + {"conn reset", errors.New("read tcp: connection reset by peer"), "conn reset"}, + {"conn refused", errors.New("dial: connection refused"), "conn refused"}, + {"broken pipe", errors.New("write: broken pipe"), "broken pipe"}, + {"EOF", errors.New("unexpected EOF"), "EOF"}, + {"DeadlineExceeded", context.DeadlineExceeded, "attempt timeout"}, + {"unknown", errors.New("totally unrelated"), "unknown error"}, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := ShortReason(tt.err) - if got != tt.want { + if got := ShortReason(tt.err); got != tt.want { t.Errorf("ShortReason(%v) = %q, want %q", tt.err, got, tt.want) } }) } } -// --- Do tests --- +// --- computeBackoff --- + +func TestComputeBackoff(t *testing.T) { + const initial = 1 * time.Second + const 
cap_ = 10 * time.Second + tests := []struct { + attempt uint + want time.Duration + }{ + {1, 1 * time.Second}, // first sleep + {2, 2 * time.Second}, // doubled + {3, 4 * time.Second}, + {4, 8 * time.Second}, + {5, 10 * time.Second}, // capped + {20, 10 * time.Second}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("attempt=%d", tt.attempt), func(t *testing.T) { + got := computeBackoff(tt.attempt, initial, cap_) + if got != tt.want { + t.Errorf("computeBackoff(%d) = %v, want %v", tt.attempt, got, tt.want) + } + }) + } +} + +// --- Do: defaults & success path --- func TestDo_SuccessFirstAttempt(t *testing.T) { - callCount := int32(0) + calls := 0 err := Do(context.Background(), func(ctx context.Context) error { - atomic.AddInt32(&callCount, 1) + calls++ return nil - }, DoOpts{ - FileSize: 100, - FileName: "test.bin", - }) - + }, DoOpts{FileName: "ok"}) if err != nil { - t.Fatalf("expected nil error, got %v", err) + t.Fatalf("Do returned %v", err) } - if atomic.LoadInt32(&callCount) != 1 { - t.Fatalf("expected fn to be called once, got %d", callCount) + if calls != 1 { + t.Errorf("calls = %d, want 1", calls) } } -func TestDo_RetryOnTransientError(t *testing.T) { - callCount := int32(0) +// --- Do: NoRetry --- + +func TestDo_NoRetry(t *testing.T) { + calls := 0 + transient := errors.New("response status code 503") err := Do(context.Background(), func(ctx context.Context) error { - n := atomic.AddInt32(&callCount, 1) - if n < 3 { - return fmt.Errorf("response status code 500: internal server error") - } - return nil + calls++ + return transient }, DoOpts{ - FileSize: 100, - FileName: "test.bin", - Config: &Config{InitialDelay: 10 * time.Millisecond, MaxJitter: -1}, + FileName: "noretry", + Config: &Config{NoRetry: true}, }) - - if err != nil { - t.Fatalf("expected nil error after retries, got %v", err) + if !errors.Is(err, transient) { + t.Errorf("err = %v, want %v", err, transient) } - if atomic.LoadInt32(&callCount) != 3 { - t.Fatalf("expected fn to be called 3 
times, got %d", callCount) + if calls != 1 { + t.Errorf("calls = %d, want 1 (NoRetry)", calls) } } -func TestDo_NoRetryOnPermanentError(t *testing.T) { - callCount := int32(0) +// --- Do: MaxAttempts caps the number of tries --- + +func TestDo_MaxAttempts(t *testing.T) { + var calls int32 err := Do(context.Background(), func(ctx context.Context) error { - atomic.AddInt32(&callCount, 1) - return fmt.Errorf("response status code 404: not found") + atomic.AddInt32(&calls, 1) + return errors.New("response status code 503") }, DoOpts{ - FileSize: 100, - FileName: "test.bin", + FileName: "always-fails", + Config: &Config{ + MaxAttempts: 3, + InitialDelay: 1 * time.Millisecond, + MaxBackoff: 1 * time.Millisecond, + MaxJitter: -1, + }, }) + if err == nil { + t.Fatal("Do returned nil, want error after attempts exhausted") + } + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("calls = %d, want 3", got) + } +} + +// --- Do: non-retryable error stops immediately --- +func TestDo_NonRetryableStopsImmediately(t *testing.T) { + var calls int32 + permErr := errors.New("permission denied") + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + return permErr + }, DoOpts{ + FileName: "perm", + Config: &Config{ + MaxAttempts: 5, + InitialDelay: 1 * time.Millisecond, + MaxJitter: -1, + }, + }) if err == nil { - t.Fatal("expected error, got nil") + t.Fatal("Do returned nil, want non-retryable error") } - if atomic.LoadInt32(&callCount) != 1 { - t.Fatalf("expected fn to be called once for non-retryable error, got %d", callCount) + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("calls = %d, want 1 (non-retryable)", got) } } -func TestDo_NoRetryConfig(t *testing.T) { - callCount := int32(0) +// --- Do: per-attempt timeout fires & retries continue --- + +func TestDo_PerAttemptTimeoutTriggersRetry(t *testing.T) { + var calls int32 + const succeededOn int32 = 3 err := Do(context.Background(), func(ctx context.Context) error { - 
atomic.AddInt32(&callCount, 1) - return fmt.Errorf("response status code 500: internal server error") + n := atomic.AddInt32(&calls, 1) + if n < succeededOn { + <-ctx.Done() + return ctx.Err() + } + return nil }, DoOpts{ - FileSize: 100, - FileName: "test.bin", - Config: &Config{NoRetry: true}, + FileName: "slow", + Config: &Config{ + MaxAttempts: 5, + PerAttemptTimeout: 30 * time.Millisecond, + InitialDelay: 1 * time.Millisecond, + MaxBackoff: 1 * time.Millisecond, + MaxJitter: -1, + }, }) + if err != nil { + t.Fatalf("Do returned %v, want success after retries", err) + } + if got := atomic.LoadInt32(&calls); got != succeededOn { + t.Errorf("calls = %d, want %d", got, succeededOn) + } +} + +// --- Do: per-attempt timeout exhausts after MaxAttempts --- +func TestDo_PerAttemptTimeoutExhausts(t *testing.T) { + var calls int32 + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + <-ctx.Done() + return ctx.Err() + }, DoOpts{ + FileName: "always-times-out", + Config: &Config{ + MaxAttempts: 3, + PerAttemptTimeout: 20 * time.Millisecond, + InitialDelay: 1 * time.Millisecond, + MaxBackoff: 1 * time.Millisecond, + MaxJitter: -1, + }, + }) if err == nil { - t.Fatal("expected error, got nil") + t.Fatal("Do returned nil, want error after exhausting attempts") } - if atomic.LoadInt32(&callCount) != 1 { - t.Fatalf("expected fn to be called once with NoRetry, got %d", callCount) + if got := atomic.LoadInt32(&calls); got != 3 { + t.Errorf("calls = %d, want 3", got) } } +// --- Do: parent context cancellation aborts retries --- + func TestDo_ParentContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var calls int32 + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() - callCount := int32(0) err := Do(ctx, func(ctx context.Context) error { - n := atomic.AddInt32(&callCount, 1) - if n == 1 { - // Cancel the parent context after the first attempt. 
- cancel() - return fmt.Errorf("response status code 500: internal server error") - } - return nil + atomic.AddInt32(&calls, 1) + return errors.New("response status code 503") }, DoOpts{ - FileSize: 100, - FileName: "test.bin", - Config: &Config{InitialDelay: 10 * time.Millisecond, MaxJitter: -1}, + FileName: "user-cancels", + Config: &Config{ + MaxAttempts: 100, + InitialDelay: 5 * time.Millisecond, + MaxBackoff: 5 * time.Millisecond, + MaxJitter: -1, + }, }) - if err == nil { - t.Fatal("expected error from cancelled context, got nil") + t.Fatal("Do returned nil, want context cancellation error") + } + if got := atomic.LoadInt32(&calls); got > 50 { + t.Errorf("calls = %d, want significantly fewer than MaxAttempts (parent ctx cancelled)", got) } } -func TestDo_OnRetryCallback(t *testing.T) { - var retryAttempts []uint - var retryReasons []string +// --- Do: OnRetry callback invoked with 1-based attempt --- - callCount := int32(0) +func TestDo_OnRetryCallback(t *testing.T) { + var attempts []uint + var reasons []string + var calls int32 err := Do(context.Background(), func(ctx context.Context) error { - n := atomic.AddInt32(&callCount, 1) + n := atomic.AddInt32(&calls, 1) if n < 3 { - return fmt.Errorf("response status code 500: internal server error") + return errors.New("response status code 500") } return nil }, DoOpts{ - FileSize: 100, - FileName: "test.bin", - Config: &Config{InitialDelay: 10 * time.Millisecond, MaxJitter: -1}, + FileName: "cb", + Config: &Config{ + MaxAttempts: 5, + InitialDelay: 1 * time.Millisecond, + MaxBackoff: 1 * time.Millisecond, + MaxJitter: -1, + }, OnRetry: func(attempt uint, reason string, backoff time.Duration) { - retryAttempts = append(retryAttempts, attempt) - retryReasons = append(retryReasons, reason) + attempts = append(attempts, attempt) + reasons = append(reasons, reason) }, }) - if err != nil { - t.Fatalf("expected nil error, got %v", err) + t.Fatalf("Do returned %v", err) } - if len(retryAttempts) != 2 { - 
t.Fatalf("expected 2 OnRetry calls, got %d", len(retryAttempts)) + want := []uint{1, 2} + if fmt.Sprintf("%v", attempts) != fmt.Sprintf("%v", want) { + t.Errorf("attempts = %v, want %v", attempts, want) } - if retryAttempts[0] != 1 || retryAttempts[1] != 2 { - t.Errorf("expected attempts [1, 2], got %v", retryAttempts) + for _, r := range reasons { + if r != "HTTP 500" { + t.Errorf("reason = %q, want \"HTTP 500\"", r) + } } - if retryReasons[0] != "HTTP 500" || retryReasons[1] != "HTTP 500" { - t.Errorf("expected reasons [HTTP 500, HTTP 500], got %v", retryReasons) +} + +// --- Do: retry budget invariance — same total time regardless of file size --- +// +// The whole point of the redesign: with PerAttemptTimeout disabled and a +// fixed backoff schedule, total wall-clock for retries does not depend on +// FileSize. (The old API's MaxRetryTime scaled with size and changed this.) +func TestDo_RetryBudgetIndependentOfFileSize(t *testing.T) { + measure := func(fileSize int64) time.Duration { + var calls int32 + start := time.Now() + _ = Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + return errors.New("response status code 503") + }, DoOpts{ + FileName: "size-invariance", + FileSize: fileSize, + Config: &Config{ + MaxAttempts: 4, + PerAttemptTimeout: -1, + InitialDelay: 10 * time.Millisecond, + MaxBackoff: 10 * time.Millisecond, + MaxJitter: -1, + }, + }) + return time.Since(start) + } + + small := measure(1 << 20) // 1 MB + huge := measure(int64(1) << 40) // 1 TB + diff := small - huge + if diff < 0 { + diff = -diff + } + if diff > 100*time.Millisecond { + t.Errorf("retry budget varied with file size: small=%v huge=%v (diff=%v)", small, huge, diff) } } -func TestDo_ConfigMaxRetryTimeOverride(t *testing.T) { - // Use a very short MaxRetryTime to ensure the retry loop terminates quickly. 
- callCount := int32(0) - start := time.Now() +// --- Do: PerAttemptTimeout < 0 disables the deadline --- + +func TestDo_PerAttemptTimeoutDisabled(t *testing.T) { + var seenDeadline bool err := Do(context.Background(), func(ctx context.Context) error { - atomic.AddInt32(&callCount, 1) - return fmt.Errorf("response status code 500: internal server error") + _, ok := ctx.Deadline() + seenDeadline = ok + return nil }, DoOpts{ - FileSize: 100, - FileName: "test.bin", + FileName: "no-deadline", + FileSize: 1 << 30, Config: &Config{ - MaxRetryTime: 1 * time.Second, - InitialDelay: 50 * time.Millisecond, - MaxJitter: -1, + PerAttemptTimeout: -1, }, }) + if err != nil { + t.Fatalf("Do returned %v", err) + } + if seenDeadline { + t.Error("attempt context had a deadline; expected none when PerAttemptTimeout < 0") + } +} - elapsed := time.Since(start) +// --- Do: PerAttemptTimeout = 0 derives from file size --- - if err == nil { - t.Fatal("expected error after retry timeout, got nil") - } - // Should have run for approximately MaxRetryTime (1s), not the dynamic default (10min). 
- if elapsed > 5*time.Second { - t.Errorf("expected retry to terminate within ~1s, but elapsed %v", elapsed) +func TestDo_PerAttemptTimeoutDerivedFromSize(t *testing.T) { + var seenTimeout time.Duration + const size = int64(50) << 30 // 50 GB → > floor, < ceiling + err := Do(context.Background(), func(ctx context.Context) error { + dl, ok := ctx.Deadline() + if !ok { + t.Fatal("expected deadline") + } + seenTimeout = time.Until(dl) + return nil + }, DoOpts{ + FileName: "derived", + FileSize: size, + }) + if err != nil { + t.Fatalf("Do returned %v", err) } - if atomic.LoadInt32(&callCount) < 2 { - t.Errorf("expected at least 2 attempts, got %d", callCount) + want := ComputePerAttemptTimeout(size) + if seenTimeout > want || seenTimeout < want-time.Second { + t.Errorf("derived timeout = %v, want ~%v", seenTimeout, want) } } -// --- humanizeBytes tests --- +// --- humanizeBytes --- func TestHumanizeBytes(t *testing.T) { tests := []struct { - input int64 - want string + size int64 + want string }{ {0, "0 B"}, - {512, "512 B"}, - {1024, "1.0 KB"}, - {1536, "1.5 KB"}, - {1048576, "1.0 MB"}, - {1073741824, "1.0 GB"}, - {int64(5.5 * float64(1<<30)), "5.5 GB"}, + {500, "500 B"}, + {2048, "2.0 KB"}, + {int64(5) << 20, "5.0 MB"}, + {int64(7) << 30, "7.0 GB"}, } - for _, tt := range tests { t.Run(tt.want, func(t *testing.T) { - got := humanizeBytes(tt.input) - if got != tt.want { - t.Errorf("humanizeBytes(%d) = %q, want %q", tt.input, got, tt.want) + if got := humanizeBytes(tt.size); got != tt.want { + t.Errorf("humanizeBytes(%d) = %q, want %q", tt.size, got, tt.want) } }) } } - -// --- computeBackoff tests --- - -func TestComputeBackoff(t *testing.T) { - initial := 5 * time.Second - maxDelay := 1 * time.Minute - - // retry-go's OnRetry always supplies a 1-based attempt number, so 0 is - // not a value the function is ever called with in production. Start from 1. 
- - // attempt 1 => 5s * 2^0 = 5s - if got := computeBackoff(1, initial, maxDelay); got != 5*time.Second { - t.Errorf("attempt 1: got %v, want %v", got, 5*time.Second) - } - - // attempt 2 => 5s * 2^1 = 10s - if got := computeBackoff(2, initial, maxDelay); got != 10*time.Second { - t.Errorf("attempt 2: got %v, want %v", got, 10*time.Second) - } - - // attempt 4 => 5s * 2^3 = 40s - if got := computeBackoff(4, initial, maxDelay); got != 40*time.Second { - t.Errorf("attempt 4: got %v, want %v", got, 40*time.Second) - } - - // attempt 5 => 5s * 2^4 = 80s, but capped at 60s - if got := computeBackoff(5, initial, maxDelay); got != maxDelay { - t.Errorf("attempt 5: got %v, want %v (capped)", got, maxDelay) - } -} From 4eaa8b43d00d92754b4d125be8d90a2143ecdf6c Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Wed, 6 May 2026 20:31:59 +0800 Subject: [PATCH 13/16] style(retrypolicy): apply gci and golines formatting No behavioral change. golangci-lint v2.5 (gci, golines with line length 122) flagged whitespace alignment in struct literals and long single-line table-driven test entries; auto-fixed via make lint-fix. 
Signed-off-by: Zhao Chen --- pkg/retrypolicy/retrypolicy.go | 16 ++++++------ pkg/retrypolicy/retrypolicy_test.go | 40 +++++++++++++++++++++++------ 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/pkg/retrypolicy/retrypolicy.go b/pkg/retrypolicy/retrypolicy.go index ded9b7b1..2cd0c8c4 100644 --- a/pkg/retrypolicy/retrypolicy.go +++ b/pkg/retrypolicy/retrypolicy.go @@ -223,14 +223,14 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er elapsed := time.Since(startTime) log.WithFields(log.Fields{ - "file": opts.FileName, - "size": sizeStr, - "error": err.Error(), - "max_attempts": maxAttempts, - "max_backoff": maxBackoff.String(), - "per_attempt_to": perAttemptTimeout.String(), - "next_retry_in": backoff.Truncate(time.Second).String(), - "elapsed": elapsed.Truncate(time.Second).String(), + "file": opts.FileName, + "size": sizeStr, + "error": err.Error(), + "max_attempts": maxAttempts, + "max_backoff": maxBackoff.String(), + "per_attempt_to": perAttemptTimeout.String(), + "next_retry_in": backoff.Truncate(time.Second).String(), + "elapsed": elapsed.Truncate(time.Second).String(), }).Warnf("[RETRY] attempt %d/%d for %q (%s)", attempt, maxAttempts, opts.FileName, sizeStr) if opts.OnRetry != nil { diff --git a/pkg/retrypolicy/retrypolicy_test.go b/pkg/retrypolicy/retrypolicy_test.go index c940261a..627fc605 100644 --- a/pkg/retrypolicy/retrypolicy_test.go +++ b/pkg/retrypolicy/retrypolicy_test.go @@ -43,11 +43,23 @@ func TestComputePerAttemptTimeout(t *testing.T) { {"100 MB clamps to floor", 100 * oneMB, minPerAttemptTimeout}, {"1 GB clamps to floor", oneGB, minPerAttemptTimeout}, // 5 GB / 10 MB/s * 2 = 1024s ≈ 17 min, above the 5min floor - {"5 GB scales above floor", 5 * oneGB, time.Duration(5*oneGB/minThroughput*safetyFactor) * time.Second}, + { + "5 GB scales above floor", + 5 * oneGB, + time.Duration(5*oneGB/minThroughput*safetyFactor) * time.Second, + }, // 10 GB / 10 MB/s * 2 = 2048s ≈ 34 min - {"10 GB scales 
linearly", tenGB, time.Duration(tenGB/minThroughput*safetyFactor) * time.Second}, + { + "10 GB scales linearly", + tenGB, + time.Duration(tenGB/minThroughput*safetyFactor) * time.Second, + }, // 100 GB / 10 MB/s * 2 = 20480s ≈ 5.7h, still under the 8h ceiling - {"100 GB still under ceiling", hundGB, time.Duration(hundGB/minThroughput*safetyFactor) * time.Second}, + { + "100 GB still under ceiling", + hundGB, + time.Duration(hundGB/minThroughput*safetyFactor) * time.Second, + }, // 200 GB hits ceiling {"200 GB clamps to ceiling", 2 * hundGB, maxPerAttemptTimeout}, } @@ -73,7 +85,11 @@ func TestIsRetryable(t *testing.T) { {"context.Canceled", context.Canceled, false}, {"context.DeadlineExceeded", context.DeadlineExceeded, false}, {"5xx server error", errors.New("response status code 500"), true}, - {"503 service unavailable", errors.New("response status code 503: Service Unavailable"), true}, + { + "503 service unavailable", + errors.New("response status code 503: Service Unavailable"), + true, + }, {"408 request timeout", errors.New("response status code 408"), true}, {"429 too many requests", errors.New("response status code 429"), true}, {"401 unauthorized", errors.New("response status code 401"), false}, @@ -135,8 +151,8 @@ func TestComputeBackoff(t *testing.T) { attempt uint want time.Duration }{ - {1, 1 * time.Second}, // first sleep - {2, 2 * time.Second}, // doubled + {1, 1 * time.Second}, // first sleep + {2, 2 * time.Second}, // doubled {3, 4 * time.Second}, {4, 8 * time.Second}, {5, 10 * time.Second}, // capped @@ -320,7 +336,10 @@ func TestDo_ParentContextCancel(t *testing.T) { t.Fatal("Do returned nil, want context cancellation error") } if got := atomic.LoadInt32(&calls); got > 50 { - t.Errorf("calls = %d, want significantly fewer than MaxAttempts (parent ctx cancelled)", got) + t.Errorf( + "calls = %d, want significantly fewer than MaxAttempts (parent ctx cancelled)", + got, + ) } } @@ -396,7 +415,12 @@ func TestDo_RetryBudgetIndependentOfFileSize(t 
*testing.T) { diff = -diff } if diff > 100*time.Millisecond { - t.Errorf("retry budget varied with file size: small=%v huge=%v (diff=%v)", small, huge, diff) + t.Errorf( + "retry budget varied with file size: small=%v huge=%v (diff=%v)", + small, + huge, + diff, + ) } } From 4cb2b96a1809a73dced676c5761fa4b2d42cfb15 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Wed, 6 May 2026 20:49:36 +0800 Subject: [PATCH 14/16] fix(retrypolicy): apply per-attempt timeout when NoRetry is set The previous Do() short-circuited NoRetry by calling fn(ctx) directly, bypassing the per-attempt deadline derivation. As Codex flagged, this made --no-retry silently disable the transfer timeout: a hung connection would never terminate, leaving users with a stalled CLI and no failure signal. Per-attempt timeout is a transfer-bound concern, not a retry concern, so it must apply on the single-attempt path too. Refactor the deadline setup into a small runAttempt closure used by both the NoRetry branch and retry.Do. Add TestDo_NoRetryHonorsPerAttemptTimeout to pin this invariant. 
Signed-off-by: Zhao Chen --- pkg/retrypolicy/retrypolicy.go | 42 ++++++++++++++++------------- pkg/retrypolicy/retrypolicy_test.go | 31 +++++++++++++++++++++ 2 files changed, 55 insertions(+), 18 deletions(-) diff --git a/pkg/retrypolicy/retrypolicy.go b/pkg/retrypolicy/retrypolicy.go index 2cd0c8c4..97da8104 100644 --- a/pkg/retrypolicy/retrypolicy.go +++ b/pkg/retrypolicy/retrypolicy.go @@ -145,15 +145,6 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er cfg = &Config{} } - if cfg.NoRetry { - return fn(ctx) - } - - maxAttempts := cfg.MaxAttempts - if maxAttempts <= 0 { - maxAttempts = DefaultMaxAttempts - } - perAttemptTimeout := cfg.PerAttemptTimeout switch { case perAttemptTimeout == 0: @@ -162,6 +153,29 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er perAttemptTimeout = 0 // disabled } + // runAttempt applies the per-attempt deadline regardless of retry policy: + // a single hung transfer must still be terminated even when retries are + // disabled, so users of --no-retry get failure visibility instead of a + // stalled CLI. 
+ runAttempt := func() error { + attemptCtx := ctx + if perAttemptTimeout > 0 { + var cancel context.CancelFunc + attemptCtx, cancel = context.WithTimeout(ctx, perAttemptTimeout) + defer cancel() + } + return fn(attemptCtx) + } + + if cfg.NoRetry { + return runAttempt() + } + + maxAttempts := cfg.MaxAttempts + if maxAttempts <= 0 { + maxAttempts = DefaultMaxAttempts + } + initialDelay := cfg.InitialDelay if initialDelay <= 0 { initialDelay = DefaultInitialDelay @@ -183,15 +197,7 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er startTime := time.Now() return retry.Do( - func() error { - attemptCtx := ctx - if perAttemptTimeout > 0 { - var cancel context.CancelFunc - attemptCtx, cancel = context.WithTimeout(ctx, perAttemptTimeout) - defer cancel() - } - return fn(attemptCtx) - }, + runAttempt, retry.Attempts(uint(maxAttempts)), retry.Context(ctx), retry.DelayType(retry.BackOffDelay), diff --git a/pkg/retrypolicy/retrypolicy_test.go b/pkg/retrypolicy/retrypolicy_test.go index 627fc605..c861eec1 100644 --- a/pkg/retrypolicy/retrypolicy_test.go +++ b/pkg/retrypolicy/retrypolicy_test.go @@ -204,6 +204,37 @@ func TestDo_NoRetry(t *testing.T) { } } +// --- Do: NoRetry still honors per-attempt timeout --- +// +// `--no-retry` disables extra attempts, but a single transfer must still +// terminate on a hung connection — otherwise users get a stalled CLI with +// no failure signal. 
+func TestDo_NoRetryHonorsPerAttemptTimeout(t *testing.T) { + var calls int32 + start := time.Now() + err := Do(context.Background(), func(ctx context.Context) error { + atomic.AddInt32(&calls, 1) + <-ctx.Done() + return ctx.Err() + }, DoOpts{ + FileName: "noretry-but-bounded", + Config: &Config{ + NoRetry: true, + PerAttemptTimeout: 30 * time.Millisecond, + }, + }) + elapsed := time.Since(start) + if err == nil { + t.Fatal("Do returned nil, want context deadline error") + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("calls = %d, want 1 (NoRetry)", got) + } + if elapsed > 500*time.Millisecond { + t.Errorf("Do hung for %v, want quick exit via per-attempt timeout", elapsed) + } +} + // --- Do: MaxAttempts caps the number of tries --- func TestDo_MaxAttempts(t *testing.T) { From a31fbb97e75da960689f06287dc8e54b1b944521 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Wed, 6 May 2026 20:58:07 +0800 Subject: [PATCH 15/16] feat(cli)!: remove --no-retry flag, use --retry-attempts=1 instead MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The --no-retry flag was redundant once the retry budget was decoupled from the per-attempt timeout: it expressed nothing that --retry-attempts=1 cannot. Keeping both invited the class of bug Codex found earlier (NoRetry short-circuited Do() and bypassed the per-attempt deadline) — having two ways to say "don't retry" makes it easy to wire only one of them through correctly. After this commit there are exactly two retry-related CLI flags: --retry-attempts int (default 0 → 6; set to 1 for fail-fast) --per-attempt-timeout duration (0 → derive from size; <0 → disabled) Each controls one orthogonal concern. "No retry" is a value of one of them, not a separate flag. BREAKING CHANGE: --no-retry is removed. Replace with --retry-attempts=1. The flag was introduced in this same PR and never shipped, so external breakage scope is just the PR's own iteration history. 
Signed-off-by: Zhao Chen --- cmd/build.go | 3 +-- cmd/fetch.go | 3 +-- cmd/pull.go | 3 +-- cmd/push.go | 42 ++++++++++++++++++++++------- pkg/retrypolicy/retrypolicy.go | 19 ++++--------- pkg/retrypolicy/retrypolicy_test.go | 29 ++++++++++---------- 6 files changed, 55 insertions(+), 44 deletions(-) diff --git a/cmd/build.go b/cmd/build.go index 0eb8cbc5..43ec3464 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -61,8 +61,7 @@ func init() { flags.BoolVar(&buildConfig.Raw, "raw", true, "turning on this flag will build model artifact layers in raw format") flags.BoolVar(&buildConfig.Reasoning, "reasoning", false, "turning on this flag will mark this model as reasoning model in the config") flags.BoolVar(&buildConfig.NoCreationTime, "no-creation-time", false, "turning on this flag will not set createdAt in the config, which will be helpful for repeated builds") - flags.BoolVar(&buildConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") - flags.IntVar(&buildConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6)") + flags.IntVar(&buildConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6, 1 = no retry)") flags.DurationVar(&buildConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { diff --git a/cmd/fetch.go b/cmd/fetch.go index ce589c48..6448db12 100644 --- a/cmd/fetch.go +++ b/cmd/fetch.go @@ -55,8 +55,7 @@ func init() { flags.StringVar(&fetchConfig.Output, "output", "", "specify the directory for fetching the model artifact") flags.StringSliceVar(&fetchConfig.Patterns, "patterns", []string{}, "specify the patterns for fetching the model artifact") flags.StringVar(&fetchConfig.DragonflyEndpoint, "dragonfly-endpoint", "", "specify the dragonfly endpoint for the pull 
operation, which will download and hardlink the blob by dragonfly GRPC service.") - flags.BoolVar(&fetchConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") - flags.IntVar(&fetchConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6)") + flags.IntVar(&fetchConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6, 1 = no retry)") flags.DurationVar(&fetchConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { diff --git a/cmd/pull.go b/cmd/pull.go index 0d70e43e..4c5ba376 100644 --- a/cmd/pull.go +++ b/cmd/pull.go @@ -55,8 +55,7 @@ func init() { flags.StringVar(&pullConfig.ExtractDir, "extract-dir", "", "specify the extract dir for extracting the model artifact") flags.BoolVar(&pullConfig.ExtractFromRemote, "extract-from-remote", false, "turning on this flag will pull and extract the data from remote registry and no longer store model artifact locally, so user must specify extract-dir as the output directory") flags.StringVar(&pullConfig.DragonflyEndpoint, "dragonfly-endpoint", "", "specify the dragonfly endpoint for the pull operation, which will download and hardlink the blob by dragonfly GRPC service, this mode requires extract-from-remote must be true") - flags.BoolVar(&pullConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") - flags.IntVar(&pullConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6)") + flags.IntVar(&pullConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6, 1 = no retry)") flags.DurationVar(&pullConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single 
transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { diff --git a/cmd/push.go b/cmd/push.go index 936dd5a4..7fc479a0 100644 --- a/cmd/push.go +++ b/cmd/push.go @@ -20,11 +20,11 @@ import ( "context" "fmt" - "github.com/modelpack/modctl/pkg/backend" - "github.com/modelpack/modctl/pkg/config" - "github.com/spf13/cobra" "github.com/spf13/viper" + + "github.com/modelpack/modctl/pkg/backend" + "github.com/modelpack/modctl/pkg/config" ) var pushConfig = config.NewPush() @@ -48,14 +48,38 @@ var pushCmd = &cobra.Command{ // init initializes push command. func init() { flags := pushCmd.Flags() - flags.IntVar(&pushConfig.Concurrency, "concurrency", pushConfig.Concurrency, "specify the number of concurrent push operations") + flags.IntVar( + &pushConfig.Concurrency, + "concurrency", + pushConfig.Concurrency, + "specify the number of concurrent push operations", + ) flags.BoolVar(&pushConfig.PlainHTTP, "plain-http", false, "use plain HTTP instead of HTTPS") - flags.BoolVar(&pushConfig.Insecure, "insecure", false, "turning on this flag will disable TLS verification") - flags.BoolVar(&pushConfig.Nydusify, "nydusify", false, "[EXPERIMENTAL] nydusify the model artifact") + flags.BoolVar( + &pushConfig.Insecure, + "insecure", + false, + "turning on this flag will disable TLS verification", + ) + flags.BoolVar( + &pushConfig.Nydusify, + "nydusify", + false, + "[EXPERIMENTAL] nydusify the model artifact", + ) flags.MarkHidden("nydusify") - flags.BoolVar(&pushConfig.RetryConfig.NoRetry, "no-retry", false, "Disable retry on transient errors") - flags.IntVar(&pushConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6)") - flags.DurationVar(&pushConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") + flags.IntVar( + &pushConfig.RetryConfig.MaxAttempts, + "retry-attempts", 
+ 0, + "Max total attempts per file (initial + retries; 0 = use default of 6, 1 = no retry)", + ) + flags.DurationVar( + &pushConfig.RetryConfig.PerAttemptTimeout, + "per-attempt-timeout", + 0, + "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)", + ) if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind push flags to viper: %w", err)) diff --git a/pkg/retrypolicy/retrypolicy.go b/pkg/retrypolicy/retrypolicy.go index 97da8104..9908795d 100644 --- a/pkg/retrypolicy/retrypolicy.go +++ b/pkg/retrypolicy/retrypolicy.go @@ -90,8 +90,8 @@ const ( // DefaultMaxBackoff. type Config struct { // MaxAttempts is the total number of attempts (initial + retries). - // 0 means "use DefaultMaxAttempts". Use NoRetry to call fn exactly - // once. + // 0 means "use DefaultMaxAttempts". Set to 1 to disable retries + // (single attempt, no retry on failure). MaxAttempts int // PerAttemptTimeout is the maximum duration for a single attempt. @@ -100,10 +100,6 @@ type Config struct { // >0 → use this value verbatim. PerAttemptTimeout time.Duration - // NoRetry disables retry entirely; fn is called once with the parent - // context. - NoRetry bool - // InitialDelay overrides the first inter-attempt sleep. 0 = default. // Primarily for tests. InitialDelay time.Duration @@ -153,10 +149,9 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er perAttemptTimeout = 0 // disabled } - // runAttempt applies the per-attempt deadline regardless of retry policy: - // a single hung transfer must still be terminated even when retries are - // disabled, so users of --no-retry get failure visibility instead of a - // stalled CLI. + // runAttempt applies the per-attempt deadline. retry-go calls this + // for each attempt; if MaxAttempts == 1 the loop exits after one + // invocation (equivalent to "no retry"). 
runAttempt := func() error { attemptCtx := ctx if perAttemptTimeout > 0 { @@ -167,10 +162,6 @@ func Do(ctx context.Context, fn func(ctx context.Context) error, opts DoOpts) er return fn(attemptCtx) } - if cfg.NoRetry { - return runAttempt() - } - maxAttempts := cfg.MaxAttempts if maxAttempts <= 0 { maxAttempts = DefaultMaxAttempts diff --git a/pkg/retrypolicy/retrypolicy_test.go b/pkg/retrypolicy/retrypolicy_test.go index c861eec1..b2a4ba7b 100644 --- a/pkg/retrypolicy/retrypolicy_test.go +++ b/pkg/retrypolicy/retrypolicy_test.go @@ -184,32 +184,31 @@ func TestDo_SuccessFirstAttempt(t *testing.T) { } } -// --- Do: NoRetry --- +// --- Do: MaxAttempts=1 is the canonical "no retry" knob --- -func TestDo_NoRetry(t *testing.T) { +func TestDo_MaxAttempts1(t *testing.T) { calls := 0 transient := errors.New("response status code 503") err := Do(context.Background(), func(ctx context.Context) error { calls++ return transient }, DoOpts{ - FileName: "noretry", - Config: &Config{NoRetry: true}, + FileName: "single-attempt", + Config: &Config{MaxAttempts: 1}, }) - if !errors.Is(err, transient) { - t.Errorf("err = %v, want %v", err, transient) + if err == nil { + t.Fatal("Do returned nil, want transient error returned verbatim") } if calls != 1 { - t.Errorf("calls = %d, want 1 (NoRetry)", calls) + t.Errorf("calls = %d, want 1 (MaxAttempts=1)", calls) } } -// --- Do: NoRetry still honors per-attempt timeout --- +// --- Do: single attempt still honors per-attempt timeout --- // -// `--no-retry` disables extra attempts, but a single transfer must still -// terminate on a hung connection — otherwise users get a stalled CLI with -// no failure signal. -func TestDo_NoRetryHonorsPerAttemptTimeout(t *testing.T) { +// Even when retries are disabled (MaxAttempts=1), a hung transfer must still +// terminate; otherwise users get a stalled CLI with no failure signal. 
+func TestDo_SingleAttemptHonorsPerAttemptTimeout(t *testing.T) { var calls int32 start := time.Now() err := Do(context.Background(), func(ctx context.Context) error { @@ -217,9 +216,9 @@ func TestDo_NoRetryHonorsPerAttemptTimeout(t *testing.T) { <-ctx.Done() return ctx.Err() }, DoOpts{ - FileName: "noretry-but-bounded", + FileName: "single-but-bounded", Config: &Config{ - NoRetry: true, + MaxAttempts: 1, PerAttemptTimeout: 30 * time.Millisecond, }, }) @@ -228,7 +227,7 @@ func TestDo_NoRetryHonorsPerAttemptTimeout(t *testing.T) { t.Fatal("Do returned nil, want context deadline error") } if got := atomic.LoadInt32(&calls); got != 1 { - t.Errorf("calls = %d, want 1 (NoRetry)", got) + t.Errorf("calls = %d, want 1", got) } if elapsed > 500*time.Millisecond { t.Errorf("Do hung for %v, want quick exit via per-attempt timeout", elapsed) From 7e3f3e8d6e800fa57c78864a0e38a187accf1cba Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Fri, 8 May 2026 00:12:02 +0800 Subject: [PATCH 16/16] feat(cli)!: drop retry CLI flags, defaults only MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove --retry-attempts and --per-attempt-timeout. Both controlled operational behavior (retry tolerance, network speed assumption) that practically never varies per invocation, so they didn't earn a slot on the CLI. Keeping them invited misuse and grew help-text noise for no measurable benefit. After this commit, retry behavior is fully determined by package defaults in pkg/retrypolicy: - 6 attempts (DefaultMaxAttempts) - 5s initial delay, exponential backoff capped at 2 min - Per-attempt timeout derived from file size (10 MiB/s minimum throughput, 2x safety, clamped to [5min, 8h]) The retrypolicy.Config struct stays public for programmatic users embedding modctl in another binary; it just isn't wired to any CLI flag. 
If a real user shows up with a network where the defaults break down, we can add a knob then with their case as the documented example — until then, YAGNI wins. Cleanup: - cmd/{push,pull,build,fetch}.go: drop flag registrations. - pkg/config/{push,pull,build,fetch}.go: drop RetryConfig field and the now-unused retrypolicy import. - pkg/backend/{push,pull,fetch,build,*_by_d7y}.go: drop Config: &cfg.RetryConfig from each retrypolicy.Do call. - pkg/backend/processor/options.go: drop WithRetryConfig and the retryConfig field; processor.Process no longer takes a retry option. - pkg/backend/processor/base.go: stop threading processOpts.retryConfig into retrypolicy.DoOpts. BREAKING CHANGE: --retry-attempts and --per-attempt-timeout are removed. Both were introduced earlier in this same PR and never shipped. Signed-off-by: Zhao Chen --- cmd/build.go | 2 -- cmd/fetch.go | 2 -- cmd/pull.go | 2 -- cmd/push.go | 12 ----------- pkg/backend/build.go | 35 +++++++++++++++++++++++++------- pkg/backend/fetch.go | 1 - pkg/backend/fetch_by_d7y.go | 1 - pkg/backend/processor/base.go | 1 - pkg/backend/processor/options.go | 9 -------- pkg/backend/pull.go | 3 --- pkg/backend/pull_by_d7y.go | 1 - pkg/backend/push.go | 3 --- pkg/config/build.go | 3 --- pkg/config/fetch.go | 3 --- pkg/config/pull.go | 3 --- pkg/config/push.go | 3 --- 16 files changed, 28 insertions(+), 56 deletions(-) diff --git a/cmd/build.go b/cmd/build.go index 43ec3464..c41c0628 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -61,8 +61,6 @@ func init() { flags.BoolVar(&buildConfig.Raw, "raw", true, "turning on this flag will build model artifact layers in raw format") flags.BoolVar(&buildConfig.Reasoning, "reasoning", false, "turning on this flag will mark this model as reasoning model in the config") flags.BoolVar(&buildConfig.NoCreationTime, "no-creation-time", false, "turning on this flag will not set createdAt in the config, which will be helpful for repeated builds") - 
flags.IntVar(&buildConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6, 1 = no retry)") - flags.DurationVar(&buildConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind build flags to viper: %w", err)) diff --git a/cmd/fetch.go b/cmd/fetch.go index 6448db12..e13de379 100644 --- a/cmd/fetch.go +++ b/cmd/fetch.go @@ -55,8 +55,6 @@ func init() { flags.StringVar(&fetchConfig.Output, "output", "", "specify the directory for fetching the model artifact") flags.StringSliceVar(&fetchConfig.Patterns, "patterns", []string{}, "specify the patterns for fetching the model artifact") flags.StringVar(&fetchConfig.DragonflyEndpoint, "dragonfly-endpoint", "", "specify the dragonfly endpoint for the pull operation, which will download and hardlink the blob by dragonfly GRPC service.") - flags.IntVar(&fetchConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6, 1 = no retry)") - flags.DurationVar(&fetchConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind fetch flags to viper: %w", err)) diff --git a/cmd/pull.go b/cmd/pull.go index 4c5ba376..f4df1fa9 100644 --- a/cmd/pull.go +++ b/cmd/pull.go @@ -55,8 +55,6 @@ func init() { flags.StringVar(&pullConfig.ExtractDir, "extract-dir", "", "specify the extract dir for extracting the model artifact") flags.BoolVar(&pullConfig.ExtractFromRemote, "extract-from-remote", false, "turning on this flag will pull and extract the data from remote registry and no longer store model artifact locally, so user must specify extract-dir as the output directory") 
flags.StringVar(&pullConfig.DragonflyEndpoint, "dragonfly-endpoint", "", "specify the dragonfly endpoint for the pull operation, which will download and hardlink the blob by dragonfly GRPC service, this mode requires extract-from-remote must be true") - flags.IntVar(&pullConfig.RetryConfig.MaxAttempts, "retry-attempts", 0, "Max total attempts per file (initial + retries; 0 = use default of 6, 1 = no retry)") - flags.DurationVar(&pullConfig.RetryConfig.PerAttemptTimeout, "per-attempt-timeout", 0, "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)") if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind pull flags to viper: %w", err)) diff --git a/cmd/push.go b/cmd/push.go index 7fc479a0..9e4271f6 100644 --- a/cmd/push.go +++ b/cmd/push.go @@ -68,18 +68,6 @@ func init() { "[EXPERIMENTAL] nydusify the model artifact", ) flags.MarkHidden("nydusify") - flags.IntVar( - &pushConfig.RetryConfig.MaxAttempts, - "retry-attempts", - 0, - "Max total attempts per file (initial + retries; 0 = use default of 6, 1 = no retry)", - ) - flags.DurationVar( - &pushConfig.RetryConfig.PerAttemptTimeout, - "per-attempt-timeout", - 0, - "Timeout for a single transfer attempt (0 = derive from file size; <0 = disabled)", - ) if err := viper.BindPFlags(flags); err != nil { panic(fmt.Errorf("bind push flags to viper: %w", err)) diff --git a/pkg/backend/build.go b/pkg/backend/build.go index 346c44e4..3241446d 100644 --- a/pkg/backend/build.go +++ b/pkg/backend/build.go @@ -44,7 +44,11 @@ const ( ) // Build builds the user materials into the model artifact which follows the Model Spec. -func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target string, cfg *config.Build) error { +func (b *backend) Build( + ctx context.Context, + modelfilePath, workDir, target string, + cfg *config.Build, +) error { logrus.Infof("build: building artifact %s", target) // parse the repo name and tag name from target. 
ref, err := ParseReference(target) @@ -139,7 +143,6 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri }, retrypolicy.DoOpts{ FileSize: 0, // config is small FileName: "config", - Config: &cfg.RetryConfig, }); err != nil { return fmt.Errorf("failed to build model config: %w", err) } @@ -161,7 +164,6 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri }, retrypolicy.DoOpts{ FileSize: 0, // manifest is small FileName: "manifest", - Config: &cfg.RetryConfig, }); err != nil { return fmt.Errorf("failed to build model manifest: %w", err) } @@ -170,7 +172,10 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri return nil } -func (b *backend) getProcessors(modelfile modelfile.Modelfile, cfg *config.Build) []processor.Processor { +func (b *backend) getProcessors( + modelfile modelfile.Modelfile, + cfg *config.Build, +) []processor.Processor { processors := []processor.Processor{} if configs := modelfile.GetConfigs(); len(configs) > 0 { @@ -178,7 +183,10 @@ func (b *backend) getProcessors(modelfile modelfile.Modelfile, cfg *config.Build if cfg.Raw { mediaType = modelspec.MediaTypeModelWeightConfigRaw } - processors = append(processors, processor.NewModelConfigProcessor(b.store, mediaType, configs, "")) + processors = append( + processors, + processor.NewModelConfigProcessor(b.store, mediaType, configs, ""), + ) } if models := modelfile.GetModels(); len(models) > 0 { @@ -209,10 +217,23 @@ func (b *backend) getProcessors(modelfile modelfile.Modelfile, cfg *config.Build } // process walks the user work directory and process the identified files. 
-func (b *backend) process(ctx context.Context, builder build.Builder, workDir string, pb *internalpb.ProgressBar, cfg *config.Build, processors ...processor.Processor) ([]ocispec.Descriptor, error) { +func (b *backend) process( + ctx context.Context, + builder build.Builder, + workDir string, + pb *internalpb.ProgressBar, + cfg *config.Build, + processors ...processor.Processor, +) ([]ocispec.Descriptor, error) { descriptors := []ocispec.Descriptor{} for _, p := range processors { - descs, err := p.Process(ctx, builder, workDir, processor.WithConcurrency(cfg.Concurrency), processor.WithProgressTracker(pb), processor.WithRetryConfig(cfg.RetryConfig)) + descs, err := p.Process( + ctx, + builder, + workDir, + processor.WithConcurrency(cfg.Concurrency), + processor.WithProgressTracker(pb), + ) if err != nil { return nil, err } diff --git a/pkg/backend/fetch.go b/pkg/backend/fetch.go index e62f5c67..c6ef90ac 100644 --- a/pkg/backend/fetch.go +++ b/pkg/backend/fetch.go @@ -123,7 +123,6 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e }, retrypolicy.DoOpts{ FileSize: layer.Size, FileName: annoFilepath, - Config: &cfg.RetryConfig, OnRetry: func(attempt uint, reason string, backoff time.Duration) { if bar := pb.Get(layer.Digest.String()); bar != nil { bar.SetRefill(bar.Current()) diff --git a/pkg/backend/fetch_by_d7y.go b/pkg/backend/fetch_by_d7y.go index e6ff9433..a9ff7467 100644 --- a/pkg/backend/fetch_by_d7y.go +++ b/pkg/backend/fetch_by_d7y.go @@ -179,7 +179,6 @@ func fetchLayerByDragonfly(ctx context.Context, pb *internalpb.ProgressBar, clie }, retrypolicy.DoOpts{ FileSize: desc.Size, FileName: annoFilepath, - Config: &cfg.RetryConfig, OnRetry: func(attempt uint, reason string, backoff time.Duration) { if bar := pb.Get(desc.Digest.String()); bar != nil { bar.SetRefill(bar.Current()) diff --git a/pkg/backend/processor/base.go b/pkg/backend/processor/base.go index 87549011..4e002724 100644 --- a/pkg/backend/processor/base.go +++ 
b/pkg/backend/processor/base.go @@ -177,7 +177,6 @@ func (b *base) Process(ctx context.Context, builder build.Builder, workDir strin }, retrypolicy.DoOpts{ FileSize: fileSize, FileName: filepath.Base(path), - Config: processOpts.retryConfig, }); err != nil { logrus.Error(err) mu.Lock() diff --git a/pkg/backend/processor/options.go b/pkg/backend/processor/options.go index 43d5d82e..4558f514 100644 --- a/pkg/backend/processor/options.go +++ b/pkg/backend/processor/options.go @@ -18,7 +18,6 @@ package processor import ( "github.com/modelpack/modctl/internal/pb" - "github.com/modelpack/modctl/pkg/retrypolicy" ) type ProcessOption func(*processOptions) @@ -28,8 +27,6 @@ type processOptions struct { concurrency int // progressTracker is the progress bar to use for tracking progress. progressTracker *pb.ProgressBar - // retryConfig is the retry configuration to use for processing. - retryConfig *retrypolicy.Config } func WithConcurrency(concurrency int) ProcessOption { @@ -43,9 +40,3 @@ func WithProgressTracker(tracker *pb.ProgressBar) ProcessOption { o.progressTracker = tracker } } - -func WithRetryConfig(cfg retrypolicy.Config) ProcessOption { - return func(o *processOptions) { - o.retryConfig = &cfg - } -} diff --git a/pkg/backend/pull.go b/pkg/backend/pull.go index 3666cf27..768e732b 100644 --- a/pkg/backend/pull.go +++ b/pkg/backend/pull.go @@ -134,7 +134,6 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err }, retrypolicy.DoOpts{ FileSize: layer.Size, FileName: layer.Digest.String(), - Config: &cfg.RetryConfig, OnRetry: func(attempt uint, reason string, backoff time.Duration) { pb.Placeholder(layer.Digest.String(), internalpb.NormalizePrompt("Pulling blob"), layer.Size) }, @@ -177,7 +176,6 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err }, retrypolicy.DoOpts{ FileSize: manifest.Config.Size, FileName: "config", - Config: &cfg.RetryConfig, OnRetry: func(attempt uint, reason string, backoff 
time.Duration) { pb.Placeholder(manifest.Config.Digest.String(), internalpb.NormalizePrompt("Pulling config"), manifest.Config.Size) }, @@ -191,7 +189,6 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err }, retrypolicy.DoOpts{ FileSize: manifestDesc.Size, FileName: "manifest", - Config: &cfg.RetryConfig, OnRetry: func(attempt uint, reason string, backoff time.Duration) { pb.Placeholder(manifestDesc.Digest.String(), internalpb.NormalizePrompt("Pulling manifest"), manifestDesc.Size) }, diff --git a/pkg/backend/pull_by_d7y.go b/pkg/backend/pull_by_d7y.go index 138be9ea..170a60c6 100644 --- a/pkg/backend/pull_by_d7y.go +++ b/pkg/backend/pull_by_d7y.go @@ -206,7 +206,6 @@ func processLayer(ctx context.Context, pb *internalpb.ProgressBar, client dfdaem }, retrypolicy.DoOpts{ FileSize: desc.Size, FileName: annoFilepath, - Config: &cfg.RetryConfig, OnRetry: func(attempt uint, reason string, backoff time.Duration) { if bar := pb.Get(desc.Digest.String()); bar != nil { bar.SetRefill(bar.Current()) diff --git a/pkg/backend/push.go b/pkg/backend/push.go index 7ea09adb..c7d1c2fa 100644 --- a/pkg/backend/push.go +++ b/pkg/backend/push.go @@ -104,7 +104,6 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err }, retrypolicy.DoOpts{ FileSize: layer.Size, FileName: layer.Digest.String(), - Config: &cfg.RetryConfig, OnRetry: func(attempt uint, reason string, backoff time.Duration) { prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", internalpb.NormalizePrompt("Copying blob"), attempt, reason, backoff.Truncate(time.Second)) @@ -140,7 +139,6 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err }, retrypolicy.DoOpts{ FileSize: manifest.Config.Size, FileName: "config", - Config: &cfg.RetryConfig, OnRetry: func(attempt uint, reason string, backoff time.Duration) { prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", internalpb.NormalizePrompt("Copying config"), attempt, reason, 
backoff.Truncate(time.Second)) @@ -162,7 +160,6 @@ func (b *backend) Push(ctx context.Context, target string, cfg *config.Push) err }, retrypolicy.DoOpts{ FileSize: manifestDesc.Size, FileName: "manifest", - Config: &cfg.RetryConfig, OnRetry: func(attempt uint, reason string, backoff time.Duration) { prompt := fmt.Sprintf("%s (retry %d, %s, waiting %s)", internalpb.NormalizePrompt("Copying manifest"), attempt, reason, backoff.Truncate(time.Second)) diff --git a/pkg/config/build.go b/pkg/config/build.go index 861cc002..d3ec828c 100644 --- a/pkg/config/build.go +++ b/pkg/config/build.go @@ -18,8 +18,6 @@ package config import ( "fmt" - - "github.com/modelpack/modctl/pkg/retrypolicy" ) const ( @@ -40,7 +38,6 @@ type Build struct { Raw bool Reasoning bool NoCreationTime bool - RetryConfig retrypolicy.Config } func NewBuild() *Build { diff --git a/pkg/config/fetch.go b/pkg/config/fetch.go index 7fb8f899..472bb6a7 100644 --- a/pkg/config/fetch.go +++ b/pkg/config/fetch.go @@ -20,8 +20,6 @@ import ( "fmt" "io" "os" - - "github.com/modelpack/modctl/pkg/retrypolicy" ) const ( @@ -40,7 +38,6 @@ type Fetch struct { ProgressWriter io.Writer DisableProgress bool Hooks PullHooks - RetryConfig retrypolicy.Config } func NewFetch() *Fetch { diff --git a/pkg/config/pull.go b/pkg/config/pull.go index 3c290e2c..6d33715a 100644 --- a/pkg/config/pull.go +++ b/pkg/config/pull.go @@ -22,8 +22,6 @@ import ( "os" ocispec "github.com/opencontainers/image-spec/specs-go/v1" - - "github.com/modelpack/modctl/pkg/retrypolicy" ) const ( @@ -42,7 +40,6 @@ type Pull struct { ProgressWriter io.Writer DisableProgress bool DragonflyEndpoint string - RetryConfig retrypolicy.Config } func NewPull() *Pull { diff --git a/pkg/config/push.go b/pkg/config/push.go index 9ba596bd..dc75b199 100644 --- a/pkg/config/push.go +++ b/pkg/config/push.go @@ -18,8 +18,6 @@ package config import ( "fmt" - - "github.com/modelpack/modctl/pkg/retrypolicy" ) const ( @@ -32,7 +30,6 @@ type Push struct { PlainHTTP bool Insecure 
bool Nydusify bool - RetryConfig retrypolicy.Config } func NewPush() *Push {