Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions cmd/wfctl/download_progress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package main

import (
"bytes"
"fmt"
"io"
"os"
"time"
)

const downloadProgressInterval = 2 * time.Second

func readDownloadBodyWithProgress(r io.Reader, total int64) ([]byte, error) {
var buf bytes.Buffer
tracker := newDownloadProgress(os.Stderr, total)
if _, err := io.Copy(io.MultiWriter(&buf, tracker), r); err != nil {
return nil, err
}
tracker.finish()
return buf.Bytes(), nil
}

type downloadProgress struct {
w io.Writer
total int64
read int64
last time.Time
}

func newDownloadProgress(w io.Writer, total int64) *downloadProgress {
p := &downloadProgress{w: w, total: total}
p.emit("Download progress")
return p
}

func (p *downloadProgress) Write(data []byte) (int, error) {
n := len(data)
p.read += int64(n)
now := time.Now()
if p.last.IsZero() || now.Sub(p.last) >= downloadProgressInterval || (p.total > 0 && p.read >= p.total) {
p.emit("Download progress")
}
return n, nil
}

func (p *downloadProgress) finish() {
p.emit("Download complete")
}

func (p *downloadProgress) emit(prefix string) {
if p.w == nil {
return
}
p.last = time.Now()
if p.total > 0 {
percent := float64(p.read) / float64(p.total) * 100
fmt.Fprintf(p.w, "%s: %s/%s (%.0f%%)\n", prefix, formatDownloadBytes(p.read), formatDownloadBytes(p.total), percent)
return
}
fmt.Fprintf(p.w, "%s: %s\n", prefix, formatDownloadBytes(p.read))
}

func formatDownloadBytes(n int64) string {
const unit = 1024
if n < unit {
return fmt.Sprintf("%d B", n)
}
div, exp := int64(unit), 0
for next := n / unit; next >= unit && exp < 4; next /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(n)/float64(div), "KMGTPE"[exp])
Comment on lines +68 to +73
}
20 changes: 15 additions & 5 deletions cmd/wfctl/plugin_install.go
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,8 @@ var gitHubAPIBaseURL = "https://api.github.com"
// independently. A generous timeout covers large binary asset downloads.
var gitHubAPIClient = &http.Client{Timeout: 10 * time.Minute}

const gitHubReleaseMetadataTimeout = 30 * time.Second

// gitHubToken returns the first non-empty GitHub token from the environment,
// checking RELEASES_TOKEN, GH_TOKEN, and GITHUB_TOKEN in order.
func gitHubToken() string {
Expand Down Expand Up @@ -1119,7 +1121,9 @@ func downloadGitHubReleaseAsset(owner, repo, tag, filename, token string) ([]byt
neturl.PathEscape(repo),
neturl.PathEscape(tag),
)
req, err := http.NewRequest(http.MethodGet, releaseURL, nil)
metadataCtx, metadataCancel := context.WithTimeout(context.Background(), gitHubReleaseMetadataTimeout)
defer metadataCancel()
req, err := http.NewRequestWithContext(metadataCtx, http.MethodGet, releaseURL, nil)
if err != nil {
return nil, err
}
Expand All @@ -1146,6 +1150,7 @@ func downloadGitHubReleaseAsset(owner, repo, tag, filename, token string) ([]byt
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, fmt.Errorf("decode GitHub release response: %w", err)
}
metadataCancel()

var assetID int64
for _, a := range release.Assets {
Expand All @@ -1165,7 +1170,9 @@ func downloadGitHubReleaseAsset(owner, repo, tag, filename, token string) ([]byt
neturl.PathEscape(repo),
assetID,
)
req2, err := http.NewRequest(http.MethodGet, assetURL, nil)
assetCtx, assetCancel := context.WithTimeout(context.Background(), downloadTimeout)
defer assetCancel()
req2, err := http.NewRequestWithContext(assetCtx, http.MethodGet, assetURL, nil)
if err != nil {
return nil, err
}
Expand All @@ -1182,7 +1189,7 @@ func downloadGitHubReleaseAsset(owner, repo, tag, filename, token string) ([]byt
if resp2.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub asset download API: HTTP %d for asset %d", resp2.StatusCode, assetID)
}
return io.ReadAll(resp2.Body)
return readDownloadBodyWithProgress(resp2.Body, resp2.ContentLength)
}

// downloadURL fetches a URL and returns the body bytes.
Expand All @@ -1203,7 +1210,10 @@ func downloadURL(rawURL string) ([]byte, error) {
}

// Public repos and non-release GitHub URLs: direct GET with optional Bearer.
req, err := http.NewRequest(http.MethodGet, rawURL, nil) //nolint:gosec // G107: URL comes from registry manifest
ctx, cancel := context.WithTimeout(context.Background(), downloadTimeout)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) //nolint:gosec // G107: URL comes from registry manifest
if err != nil {
return nil, err
}
Expand All @@ -1220,7 +1230,7 @@ func downloadURL(rawURL string) ([]byte, error) {
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, rawURL)
}
return io.ReadAll(resp.Body)
return readDownloadBodyWithProgress(resp.Body, resp.ContentLength)
}

// verifyChecksum checks that data matches the expected SHA256 hex string.
Expand Down
124 changes: 124 additions & 0 deletions cmd/wfctl/plugin_install_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -18,6 +19,7 @@ import (
"strings"
"sync"
"testing"
"time"
)

// captureTransport is a test http.RoundTripper that:
Expand Down Expand Up @@ -115,6 +117,59 @@ func installTestClient(t *testing.T, ct *captureTransport) {
t.Cleanup(func() { http.DefaultClient = orig })
}

func TestDownloadURL_DirectGetUsesBoundedRequestContext(t *testing.T) {
orig := http.DefaultClient
http.DefaultClient = &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
if _, ok := req.Context().Deadline(); !ok {
return nil, fmt.Errorf("request has no deadline")
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("ok")),
Header: make(http.Header),
Request: req,
}, nil
}),
}
t.Cleanup(func() { http.DefaultClient = orig })

got, err := downloadURL("https://example.com/plugin.tar.gz")
if err != nil {
t.Fatalf("downloadURL: %v", err)
}
if string(got) != "ok" {
t.Fatalf("downloadURL body = %q, want ok", got)
}
}
Comment on lines +137 to +144

func TestDownloadURL_LargeDirectDownloadEmitsProgress(t *testing.T) {
orig := http.DefaultClient
http.DefaultClient = &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
ContentLength: int64(len("fake tarball bytes")),
Body: io.NopCloser(strings.NewReader("fake tarball bytes")),
Header: make(http.Header),
Request: req,
}, nil
}),
}
t.Cleanup(func() { http.DefaultClient = orig })

stderr, err := captureStderr(t, func() error {
_, err := downloadURL("https://example.com/plugin.tar.gz")
return err
})
if err != nil {
t.Fatalf("downloadURL: %v", err)
}
if !strings.Contains(stderr, "Download progress") || !strings.Contains(stderr, "Download complete") {
t.Fatalf("stderr = %q, want progress and completion indicators", stderr)
}
}

// TestDownloadURL_GitHubAuthHeader verifies that downloadURL injects a Bearer
// Authorization header for non-release github.com URLs (direct-download path)
// using the first non-empty token env var (RELEASES_TOKEN > GH_TOKEN >
Expand Down Expand Up @@ -390,6 +445,75 @@ func TestDownloadURL_PrivateReleaseAsset(t *testing.T) {
}
}

func TestDownloadURL_PrivateReleaseAssetUsesFreshAssetDownloadDeadline(t *testing.T) {
const (
wantAssetID = int64(99)
wantFilename = "plugin-linux-amd64.tar.gz"
wantTag = "v1.0.0"
wantOwner = "GoCodeAlone"
wantRepo = "test-plugin"
wantToken = "test-secret-token"
)

var metadataDeadline, assetDeadline time.Time
rt := roundTripFunc(func(req *http.Request) (*http.Response, error) {
deadline, ok := req.Context().Deadline()
if !ok {
return nil, fmt.Errorf("%s has no request deadline", req.URL.Path)
}

switch req.URL.Path {
case fmt.Sprintf("/repos/%s/%s/releases/tags/%s", wantOwner, wantRepo, wantTag):
metadataDeadline = deadline
time.Sleep(10 * time.Millisecond)
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(
fmt.Sprintf(`{"assets":[{"id":%d,"name":%q}]}`, wantAssetID, wantFilename),
)),
Header: make(http.Header),
Request: req,
}, nil
case fmt.Sprintf("/repos/%s/%s/releases/assets/%d", wantOwner, wantRepo, wantAssetID):
assetDeadline = deadline
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("fake tarball bytes")),
Header: make(http.Header),
Request: req,
}, nil
default:
return nil, fmt.Errorf("unexpected path %s", req.URL.Path)
}
})

origAPIBase := gitHubAPIBaseURL
origAPIClient := gitHubAPIClient
gitHubAPIBaseURL = "https://api.github.test"
gitHubAPIClient = &http.Client{Transport: rt}
t.Cleanup(func() {
gitHubAPIBaseURL = origAPIBase
gitHubAPIClient = origAPIClient
})

t.Setenv("RELEASES_TOKEN", wantToken)
for _, k := range []string{"GH_TOKEN", "GITHUB_TOKEN"} {
t.Setenv(k, "")
}

rawURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
wantOwner, wantRepo, wantTag, wantFilename)
if _, err := downloadURL(rawURL); err != nil {
t.Fatalf("downloadURL: %v", err)
}
if metadataDeadline.IsZero() || assetDeadline.IsZero() {
t.Fatalf("missing recorded deadlines: metadata=%v asset=%v", metadataDeadline, assetDeadline)
}
if !assetDeadline.After(metadataDeadline) {
t.Fatalf("asset deadline = %v, want after metadata deadline %v", assetDeadline, metadataDeadline)
}
}

// TestDownloadURL_PublicReleaseNoToken verifies that when no token is set,
// downloadURL falls back to a plain GET for release download URLs (public repos).
func TestDownloadURL_PublicReleaseNoToken(t *testing.T) {
Expand Down
8 changes: 6 additions & 2 deletions cmd/wfctl/update.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"encoding/json"
"flag"
"fmt"
Expand Down Expand Up @@ -353,8 +354,11 @@ func verifyAssetChecksum(checksumAsset *githubAsset, assetName string, data []by

// downloadWithTimeout fetches a URL using an HTTP client with the given timeout.
func downloadWithTimeout(url string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

client := &http.Client{Timeout: timeout}
req, err := http.NewRequest(http.MethodGet, url, nil) //nolint:noctx // timeout is set on the client
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
Expand All @@ -367,7 +371,7 @@ func downloadWithTimeout(url string, timeout time.Duration) ([]byte, error) {
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, url)
}
return io.ReadAll(resp.Body)
return readDownloadBodyWithProgress(resp.Body, resp.ContentLength)
}

// replaceBinary writes newData to execPath atomically by writing to a temp file
Expand Down
Loading