diff --git a/updater.go b/updater.go index 4a7843a..beab11b 100644 --- a/updater.go +++ b/updater.go @@ -342,8 +342,19 @@ func (u *Updater) downloadFile(url, dst string) error { } defer f.Close() - _, err = io.Copy(f, io.LimitReader(resp.Body, maxDownloadBytes)) - return err + // Read one byte past the cap so we can distinguish "exactly at the limit" + // from "exceeded the limit". A plain io.LimitReader(maxDownloadBytes) would + // silently truncate oversize archives — the SHA256 check would then fail + // with a confusing "checksum mismatch" instead of telling the operator + // the archive is too large. + n, err := io.Copy(f, io.LimitReader(resp.Body, maxDownloadBytes+1)) + if err != nil { + return err + } + if n > maxDownloadBytes { + return fmt.Errorf("archive exceeds max download size %d bytes", maxDownloadBytes) + } + return nil } // VerifyChecksum checks the SHA256 of archivePath against the checksums file. @@ -440,6 +451,18 @@ func extractTarGz(archivePath, destDir string) error { } func replaceBinary(src, dst string) error { + // Refuse to swap in a zero-byte staged binary. A 0-byte rename over the + // live daemon binary would brick the daemon on next start; better to fail + // loudly here so the operator sees the update failed and the existing + // binary keeps running. + fi, err := os.Stat(src) + if err != nil { + return err + } + if fi.Size() == 0 { + return fmt.Errorf("refusing to replace binary with empty source: %s", src) + } + // Write to a temp file beside the destination, then atomically rename. // This avoids "text file busy" on Linux (rename unlinks the old inode // while the running process keeps its file descriptor open) and prevents diff --git a/zz_replace_binary_and_download_bug_test.go b/zz_replace_binary_and_download_bug_test.go new file mode 100644 index 0000000..b8b8c48 --- /dev/null +++ b/zz_replace_binary_and_download_bug_test.go @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +//go:build !no_updater +// +build !no_updater + +package updater + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestReplaceBinary_RejectsEmptySource ensures that a zero-byte staged +// binary never atomic-renames over the live destination. Previously the +// rename would succeed and brick the daemon on next start. +func TestReplaceBinary_RejectsEmptySource(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + src := filepath.Join(dir, "empty-src") + if err := os.WriteFile(src, nil, 0755); err != nil { + t.Fatalf("write empty src: %v", err) + } + + dst := filepath.Join(dir, "live-bin") + original := []byte("ORIGINAL DAEMON BINARY") + if err := os.WriteFile(dst, original, 0755); err != nil { + t.Fatalf("write dst: %v", err) + } + + err := replaceBinary(src, dst) + if err == nil { + t.Fatal("replaceBinary returned nil for zero-byte source; expected error") + } + if !strings.Contains(err.Error(), "empty source") { + t.Errorf("error %q does not mention empty source", err) + } + + // Destination must be untouched — both contents and that no .new + // scratch file was left behind beside it. + got, readErr := os.ReadFile(dst) + if readErr != nil { + t.Fatalf("read dst after failed replace: %v", readErr) + } + if string(got) != string(original) { + t.Errorf("dst was modified: got %q, want %q", got, original) + } + if _, err := os.Stat(dst + ".new"); err == nil { + t.Errorf("scratch file %s.new was left behind", dst) + } +} + +// TestDownloadRejectsOversizedArchive ensures that downloadFile returns a +// clear "exceeds max download size" error instead of silently truncating +// the body (which previously surfaced as a confusing "checksum mismatch"). +func TestDownloadRejectsOversizedArchive(t *testing.T) { + t.Parallel() + + // Serve maxDownloadBytes + 1 KiB of zeros. We don't allocate the whole + // body in memory — stream it instead, so the test stays cheap. + const overshoot = 1024 + totalBytes := int64(maxDownloadBytes) + overshoot + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Length", fmt.Sprintf("%d", totalBytes)) + w.WriteHeader(http.StatusOK) + buf := make([]byte, 64*1024) + var sent int64 + for sent < totalBytes { + n := int64(len(buf)) + if totalBytes-sent < n { + n = totalBytes - sent + } + if _, err := w.Write(buf[:n]); err != nil { + return + } + sent += n + } + })) + defer srv.Close() + + u := New(Config{}) + dst := filepath.Join(t.TempDir(), "oversize.tar.gz") + err := u.downloadFile(srv.URL, dst) + if err == nil { + t.Fatal("downloadFile returned nil for oversize archive; expected error") + } + if !strings.Contains(err.Error(), "exceeds max download size") { + t.Errorf("error %q does not mention exceeding max download size", err) + } + if strings.Contains(strings.ToLower(err.Error()), "checksum") { + t.Errorf("error %q surfaces as a checksum problem; should fail explicitly on size", err) + } +}