Skip to content

Commit e2ee6fb

Browse files
committed
refactor download functions
1 parent 603ac7e commit e2ee6fb

File tree

7 files changed

+59
-111
lines changed

7 files changed

+59
-111
lines changed

internal/updater/download_image.go

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,14 @@
1616
package updater
1717

1818
import (
19-
"bytes"
2019
"context"
21-
"crypto/sha256"
22-
"encoding/hex"
2320
"fmt"
24-
"io"
2521
"os"
2622

2723
"github.com/arduino/go-paths-helper"
2824
"github.com/codeclysm/extract/v4"
2925
"github.com/schollz/progressbar/v3"
26+
"go.bug.st/downloader/v2"
3027

3128
"github.com/arduino/arduino-flasher-cli/cmd/feedback"
3229
"github.com/arduino/arduino-flasher-cli/cmd/i18n"
@@ -65,58 +62,25 @@ func DownloadAndExtract(ctx context.Context, targetVersion string, temp *paths.P
6562
}
6663

6764
func DownloadImage(ctx context.Context, targetVersion string, downloadPath *paths.Path) (*paths.Path, string, error) {
68-
var err error
69-
7065
client := NewClient()
71-
manifest, err := client.GetInfoManifest(ctx)
72-
if err != nil {
73-
return nil, "", err
74-
}
75-
76-
var rel *Release
77-
if targetVersion == "latest" || targetVersion == manifest.Latest.Version {
78-
rel = &manifest.Latest
79-
} else {
80-
for _, r := range manifest.Releases {
81-
if targetVersion == r.Version {
82-
rel = &r
83-
break
84-
}
85-
}
86-
}
87-
88-
if rel == nil {
89-
return nil, "", fmt.Errorf("could not find Debian image %s", targetVersion)
90-
}
91-
92-
download, size, err := client.FetchZip(ctx, rel.Url)
66+
rel, err := client.GetReleaseByVersion(ctx, targetVersion)
9367
if err != nil {
94-
return nil, "", fmt.Errorf("could not fetch Debian image: %w", err)
68+
return nil, "", fmt.Errorf("could not get release info: %w", err)
9569
}
96-
defer download.Close()
9770

9871
tmpZip := downloadPath.Join("arduino-unoq-debian-image-" + rel.Version + ".tar.zst")
99-
tmpZipFile, err := tmpZip.Create()
100-
if err != nil {
101-
return nil, "", err
102-
}
103-
defer tmpZipFile.Close()
10472

105-
// Download and keep track of the progress
10673
bar := progressbar.DefaultBytes(
107-
size,
74+
0,
10875
i18n.Tr("Downloading Debian image version %s", rel.Version),
10976
)
110-
checksum := sha256.New()
111-
if _, err := io.Copy(io.MultiWriter(checksum, tmpZipFile, bar), download); err != nil {
112-
return nil, "", err
77+
callback := func(current, total int64) {
78+
bar.AddMax64(total)
79+
_ = bar.Set64(current)
11380
}
11481

115-
// Check the hash
116-
if sha256Byte, err := hex.DecodeString(rel.Sha256); err != nil {
117-
return nil, "", fmt.Errorf("could not convert sha256 from hex to bytes: %w", err)
118-
} else if s := checksum.Sum(nil); !bytes.Equal(s, sha256Byte) {
119-
return nil, "", fmt.Errorf("bad hash: %x (expected %x)", s, sha256Byte)
82+
if err := client.DownloadFile(ctx, tmpZip, rel, callback, downloader.Config{}); err != nil {
83+
return nil, "", fmt.Errorf("could not download Debian image: %w", err)
12084
}
12185

12286
return tmpZip, rel.Version, nil

internal/updater/http_client.go

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ import (
2828
"time"
2929

3030
"github.com/arduino/go-paths-helper"
31+
"github.com/shirou/gopsutil/v4/disk"
3132
"go.bug.st/downloader/v2"
3233
"go.bug.st/f"
3334

3435
"github.com/arduino/arduino-flasher-cli/cmd/i18n"
35-
"github.com/arduino/arduino-flasher-cli/rpc/cc/arduino/flasher/v1"
3636
)
3737

3838
var baseURL = f.Must(url.Parse("https://downloads.arduino.cc"))
@@ -41,15 +41,10 @@ const pathRelease = "debian-im/Stable"
4141

4242
// Client holds the base URL, command name, allows custom HTTP client, and optional headers.
4343
type Client struct {
44-
HTTPClient HTTPDoer
44+
HTTPClient *http.Client
4545
Headers map[string]string // Optional headers to add to each request
4646
}
4747

48-
// HTTPDoer is an interface for http.Client or mocks.
49-
type HTTPDoer interface {
50-
Do(req *http.Request) (*http.Response, error)
51-
}
52-
5348
// Option is a functional option for configuring Client.
5449
type Option func(*Client)
5550

@@ -61,7 +56,7 @@ func WithHeaders(headers map[string]string) Option {
6156
}
6257

6358
// WithHTTPClient sets a custom HTTP client for the Client.
64-
func WithHTTPClient(client HTTPDoer) Option {
59+
func WithHTTPClient(client *http.Client) Option {
6560
return func(c *Client) {
6661
c.HTTPClient = client
6762
}
@@ -116,45 +111,51 @@ func (c *Client) GetInfoManifest(ctx context.Context) (Manifest, error) {
116111
return res, nil
117112
}
118113

119-
// FetchZip fetches the Debian image archive.
120-
func (c *Client) FetchZip(ctx context.Context, zipURL string) (io.ReadCloser, int64, error) {
121-
req, err := http.NewRequestWithContext(ctx, "GET", zipURL, nil)
114+
func (c *Client) GetReleaseByVersion(ctx context.Context, version string) (Release, error) {
115+
manifest, err := c.GetInfoManifest(ctx)
122116
if err != nil {
123-
return nil, 0, fmt.Errorf("failed to create request: %w", err)
117+
return Release{}, err
124118
}
125-
c.addHeaders(req)
126-
// #nosec G107 -- zipURL is constructed from trusted config and parameters
127-
resp, err := c.HTTPClient.Do(req)
128-
if err != nil {
129-
return nil, 0, fmt.Errorf("failed to GET zip: %w", err)
130-
}
131-
if resp.StatusCode != http.StatusOK {
132-
resp.Body.Close()
133-
return nil, 0, fmt.Errorf("bad http status from %s: %v", zipURL, resp.Status)
119+
120+
if version == "latest" || version == manifest.Latest.Version {
121+
return manifest.Latest, nil
122+
} else {
123+
for _, r := range manifest.Releases {
124+
if version == r.Version {
125+
return r, nil
126+
}
127+
}
134128
}
135-
return resp.Body, resp.ContentLength, nil
129+
130+
return Release{}, fmt.Errorf("could not find Debian image %s", version)
136131
}
137132

133+
type downloadCallback func(current, total int64)
134+
138135
// DownloadFile downloads a file from a URL into the specified path. An optional config and options may be passed (or nil to use the defaults).
139136
// A DownloadProgressCB callback function must be passed to monitor download progress.
140137
// If a not empty queryParameter is passed, it is appended to the URL for analysis purposes.
141-
func DownloadFile(ctx context.Context, path *paths.Path, rel *Release, downloadCB flasher.DownloadProgressCB, config downloader.Config, options ...downloader.DownloadOptions) (returnedError error) {
142-
downloadCB.Start(rel.Url, rel.Version)
143-
defer func() {
144-
if returnedError == nil {
145-
downloadCB.End(true, "")
146-
} else {
147-
downloadCB.End(false, returnedError.Error())
148-
}
149-
}()
138+
func (c *Client) DownloadFile(ctx context.Context, path *paths.Path, rel Release, cb downloadCallback, config downloader.Config, options ...downloader.DownloadOptions) (returnedError error) {
139+
140+
// Check if there is enough free disk space before downloading and extracting an image
141+
dk, err := disk.Usage(path.String())
142+
if err != nil {
143+
return err
144+
}
145+
// TODO: improve disk space check with Content-Length header
146+
if dk.Free/GiB < DownloadDiskSpace {
147+
return fmt.Errorf("download and extraction requires up to %d GiB of free space", DownloadDiskSpace)
148+
}
150149

150+
config.HttpClient = *c.HTTPClient
151+
// TODO: add headers to downloader's http client
151152
d, err := downloader.DownloadWithConfigAndContext(ctx, path.String(), rel.Url, config, options...)
152153
if err != nil {
153154
return err
154155
}
155156

156157
err = d.RunAndPoll(func(downloaded int64) {
157-
downloadCB.Update(downloaded, d.Size())
158+
cb(downloaded, d.Size())
158159
}, 250*time.Millisecond)
159160
if err != nil {
160161
return err

rpc/cc/arduino/flasher/v1/commands.pb.go

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rpc/cc/arduino/flasher/v1/commands.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ message FlashRequest {
4141
// The version of the Debian image to download.
4242
string version = 1;
4343
// The path in which the image will be downloaded and extracted.
44+
// If an image file is already present, it will not be downloaded again.
4445
string temp_path = 2;
4546
// Preserve user partition if possible.
4647
bool preserve_user = 3;

rpc/cc/arduino/flasher/v1/commands_grpc.pb.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rpc/cc/arduino/flasher/v1/common.pb.go

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

service/service_flash.go

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020

2121
"github.com/arduino/go-paths-helper"
2222
"github.com/codeclysm/extract/v4"
23-
"github.com/shirou/gopsutil/v4/disk"
2423
"go.bug.st/downloader/v2"
2524

2625
"github.com/arduino/arduino-flasher-cli/internal/updater"
@@ -37,7 +36,7 @@ func (s *flasherServerImpl) Flash(req *flasher.FlashRequest, stream flasher.Flas
3736
responseCallback = func(*flasher.FlashResponse) error { return nil }
3837
}
3938
ctx := stream.Context()
40-
downloadCB := func(msg *flasher.DownloadProgress) {
39+
var downloadCB flasher.DownloadProgressCB = func(msg *flasher.DownloadProgress) {
4140
_ = responseCallback(&flasher.FlashResponse{
4241
Message: &flasher.FlashResponse_DownloadProgress{
4342
DownloadProgress: msg,
@@ -59,43 +58,23 @@ func (s *flasherServerImpl) Flash(req *flasher.FlashRequest, stream flasher.Flas
5958
})
6059
}
6160

62-
// Check if there is enough free disk space before downloading and extracting an image
63-
d, err := disk.Usage(req.TempPath)
64-
if err != nil {
65-
return err
66-
}
67-
if d.Free/updater.GiB < updater.DownloadDiskSpace {
68-
return fmt.Errorf("download and extraction requires up to %d GiB of free space", updater.DownloadDiskSpace)
69-
}
70-
7161
client := updater.NewClient()
72-
manifest, err := client.GetInfoManifest(ctx)
73-
if err != nil {
74-
return err
75-
}
76-
77-
var rel *updater.Release
78-
if req.Version == "latest" || req.Version == manifest.Latest.Version {
79-
rel = &manifest.Latest
80-
} else {
81-
for _, r := range manifest.Releases {
82-
if req.Version == r.Version {
83-
rel = &r
84-
break
85-
}
86-
}
87-
}
8862

89-
if rel == nil {
90-
return fmt.Errorf("could not find Debian image %s", req.Version)
63+
rel, err := client.GetReleaseByVersion(ctx, req.GetVersion())
64+
if err != nil {
65+
return fmt.Errorf("could not get release info: %w", err)
9166
}
9267

9368
tmpZip := paths.New(req.GetTempPath(), "arduino-unoq-debian-image-"+rel.Version+".tar.zst")
9469
defer func() { _ = tmpZip.RemoveAll() }()
9570

96-
if err := updater.DownloadFile(ctx, tmpZip, rel, downloadCB, downloader.Config{}); err != nil {
71+
downloadCB.Start(rel.Url, rel.Version)
72+
if err := client.DownloadFile(ctx, tmpZip, rel, downloadCB.Update, downloader.Config{}); err != nil {
73+
// FIXME: Maybe this is redundant?
74+
downloadCB.End(false, err.Error())
9775
return err
9876
}
77+
downloadCB.End(true, "")
9978

10079
tmpZipFile, err := tmpZip.Open()
10180
if err != nil {

0 commit comments

Comments
 (0)