Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,19 @@ func (u *Updater) downloadFile(url, dst string) error {
}
defer f.Close()

_, err = io.Copy(f, io.LimitReader(resp.Body, maxDownloadBytes))
return err
// Read one byte past the cap so we can distinguish "exactly at the limit"
// from "exceeded the limit". A plain io.LimitReader(maxDownloadBytes) would
// silently truncate oversize archives — the SHA256 check would then fail
// with a confusing "checksum mismatch" instead of telling the operator
// the archive is too large.
n, err := io.Copy(f, io.LimitReader(resp.Body, maxDownloadBytes+1))
if err != nil {
return err
}
if n > maxDownloadBytes {
return fmt.Errorf("archive exceeds max download size %d bytes", maxDownloadBytes)
}
return nil
}

// VerifyChecksum checks the SHA256 of archivePath against the checksums file.
Expand Down Expand Up @@ -440,6 +451,18 @@ func extractTarGz(archivePath, destDir string) error {
}

func replaceBinary(src, dst string) error {
// Refuse to swap in a zero-byte staged binary. A 0-byte rename over the
// live daemon binary would brick the daemon on next start; better to fail
// loudly here so the operator sees the update failed and the existing
// binary keeps running.
fi, err := os.Stat(src)
if err != nil {
return err
}
if fi.Size() == 0 {
return fmt.Errorf("refusing to replace binary with empty source: %s", src)
}

// Write to a temp file beside the destination, then atomically rename.
// This avoids "text file busy" on Linux (rename unlinks the old inode
// while the running process keeps its file descriptor open) and prevents
Expand Down
99 changes: 99 additions & 0 deletions zz_replace_binary_and_download_bug_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// SPDX-License-Identifier: AGPL-3.0-or-later

//go:build !no_updater
// +build !no_updater

package updater

import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)

// TestReplaceBinary_RejectsEmptySource ensures that a zero-byte staged
// binary never atomic-renames over the live destination. Previously the
// rename would succeed and brick the daemon on next start.
func TestReplaceBinary_RejectsEmptySource(t *testing.T) {
t.Parallel()
dir := t.TempDir()

src := filepath.Join(dir, "empty-src")
if err := os.WriteFile(src, nil, 0755); err != nil {
t.Fatalf("write empty src: %v", err)
}

dst := filepath.Join(dir, "live-bin")
original := []byte("ORIGINAL DAEMON BINARY")
if err := os.WriteFile(dst, original, 0755); err != nil {
t.Fatalf("write dst: %v", err)
}

err := replaceBinary(src, dst)
if err == nil {
t.Fatal("replaceBinary returned nil for zero-byte source; expected error")
}
if !strings.Contains(err.Error(), "empty source") {
t.Errorf("error %q does not mention empty source", err)
}

// Destination must be untouched — both contents and that no .new
// scratch file was left behind beside it.
got, readErr := os.ReadFile(dst)
if readErr != nil {
t.Fatalf("read dst after failed replace: %v", readErr)
}
if string(got) != string(original) {
t.Errorf("dst was modified: got %q, want %q", got, original)
}
if _, err := os.Stat(dst + ".new"); err == nil {
t.Errorf("scratch file %s.new was left behind", dst)
}
}

// TestDownloadRejectsOversizedArchive ensures that downloadFile returns a
// clear "exceeds max download size" error instead of silently truncating
// the body (which previously surfaced as a confusing "checksum mismatch").
func TestDownloadRejectsOversizedArchive(t *testing.T) {
t.Parallel()

// Serve maxDownloadBytes + 1 KiB of zeros. We don't allocate the whole
// body in memory — stream it instead, so the test stays cheap.
const overshoot = 1024
totalBytes := int64(maxDownloadBytes) + overshoot
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Length", fmt.Sprintf("%d", totalBytes))
w.WriteHeader(http.StatusOK)
buf := make([]byte, 64*1024)
var sent int64
for sent < totalBytes {
n := int64(len(buf))
if totalBytes-sent < n {
n = totalBytes - sent
}
if _, err := w.Write(buf[:n]); err != nil {
return
}
sent += n
}
}))
defer srv.Close()

u := New(Config{})
dst := filepath.Join(t.TempDir(), "oversize.tar.gz")
err := u.downloadFile(srv.URL, dst)
if err == nil {
t.Fatal("downloadFile returned nil for oversize archive; expected error")
}
if !strings.Contains(err.Error(), "exceeds max download size") {
t.Errorf("error %q does not mention exceeding max download size", err)
}
if strings.Contains(strings.ToLower(err.Error()), "checksum") {
t.Errorf("error %q surfaces as a checksum problem; should fail explicitly on size", err)
}
}
Loading