diff --git a/zz_coverage_test.go b/zz_coverage_test.go new file mode 100644 index 0000000..57a433b --- /dev/null +++ b/zz_coverage_test.go @@ -0,0 +1,788 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package updater + +import ( + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" +) + +// rewriteRT is an http.RoundTripper that rewrites every outbound request's +// scheme+host to the supplied test-server URL. This lets tests drive +// fetchLatestRelease (which hardcodes https://api.github.com/...) through +// an httptest server without modifying production code. +type rewriteRT struct { + target *url.URL +} + +func (r *rewriteRT) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := req.Clone(req.Context()) + req2.URL.Scheme = r.target.Scheme + req2.URL.Host = r.target.Host + req2.Host = r.target.Host + return http.DefaultTransport.RoundTrip(req2) +} + +// newRewriteClient produces an *http.Client whose requests are silently +// redirected to srv regardless of the URL the caller passes in. +func newRewriteClient(srv *httptest.Server) *http.Client { + u, _ := url.Parse(srv.URL) + return &http.Client{ + Transport: &rewriteRT{target: u}, + Timeout: 10 * time.Second, + } +} + +// --- fetchLatestRelease error branches ------------------------------------ + +// TestFetchLatestRelease_Non200Body covers the non-200 branch. +func TestFetchLatestRelease_Non200Body(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "rate limited", http.StatusForbidden) + })) + defer srv.Close() + + u := &Updater{ + config: Config{Repo: "owner/repo", Version: "vTEST"}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + } + _, err := u.fetchLatestRelease() + if err == nil { + t.Fatal("expected error for non-200 response") + } + if !strings.Contains(err.Error(), "403") { + t.Errorf("error does not mention status: %v", err) + } +} + +// TestFetchLatestRelease_MalformedJSON covers the decode error branch. +func TestFetchLatestRelease_MalformedJSON(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "this is not json {{{") + })) + defer srv.Close() + u := &Updater{ + config: Config{Repo: "owner/repo"}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + } + _, err := u.fetchLatestRelease() + if err == nil { + t.Fatal("expected decode error") + } + if !strings.Contains(err.Error(), "decode") { + t.Errorf("error does not mention decode: %v", err) + } +} + +// TestFetchLatestRelease_TransportError covers the client.Do error branch. +func TestFetchLatestRelease_TransportError(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {})) + srv.Close() // close immediately → connection refused + + u := &Updater{ + config: Config{Repo: "owner/repo"}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + } + _, err := u.fetchLatestRelease() + if err == nil { + t.Fatal("expected transport error") + } +} + +// TestFetchLatestRelease_HappyPath covers the success branch including +// User-Agent header injection. +func TestFetchLatestRelease_HappyPath(t *testing.T) { + t.Parallel() + uaSeen := "" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + uaSeen = r.Header.Get("User-Agent") + _ = json.NewEncoder(w).Encode(GitHubRelease{ + TagName: "v9.8.7", + Assets: []GitHubAsset{{Name: "x", BrowserDownloadURL: "http://x"}}, + }) + })) + defer srv.Close() + + u := &Updater{ + config: Config{Repo: "owner/repo", Version: "v1.6.4"}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + } + rel, err := u.fetchLatestRelease() + if err != nil { + t.Fatalf("fetchLatestRelease: %v", err) + } + if rel.TagName != "v9.8.7" { + t.Errorf("tag = %q", rel.TagName) + } + if uaSeen != "pilot-updater/v1.6.4" { + t.Errorf("User-Agent = %q, want pilot-updater/v1.6.4", uaSeen) + } +} + +// TestFetchLatestRelease_NoVersionOmitsUA covers the branch where the +// User-Agent header is NOT set (config.Version == ""). +func TestFetchLatestRelease_NoVersionOmitsUA(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // When Version is "", New() doesn't set our header, but Go's + // http.Client may still set a default Go-http-client UA. + if strings.HasPrefix(r.Header.Get("User-Agent"), "pilot-updater/") { + t.Errorf("did not expect pilot-updater UA, got %q", r.Header.Get("User-Agent")) + } + _ = json.NewEncoder(w).Encode(GitHubRelease{TagName: "v1.0.0"}) + })) + defer srv.Close() + + u := &Updater{ + config: Config{Repo: "owner/repo"}, // no Version + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + } + if _, err := u.fetchLatestRelease(); err != nil { + t.Fatalf("fetchLatestRelease: %v", err) + } +} + +// --- checkOnce branches --------------------------------------------------- + +// TestCheckOnce_FetchError covers the "failed to fetch latest release" +// log+return branch. +func TestCheckOnce_FetchError(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer srv.Close() + tmp := t.TempDir() + u := &Updater{ + config: Config{Repo: "owner/repo", InstallDir: tmp}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + exitFn: func(int) {}, + } + u.checkOnce() // must not panic; logs error and returns +} + +// TestCheckOnce_BadTagLogged covers the "failed to parse release tag" +// branch. +func TestCheckOnce_BadTagLogged(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(GitHubRelease{TagName: "not-a-semver"}) + })) + defer srv.Close() + tmp := t.TempDir() + u := &Updater{ + config: Config{Repo: "owner/repo", InstallDir: tmp}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + exitFn: func(int) {}, + } + u.checkOnce() +} + +// TestCheckOnce_CurrentVersionError covers the "failed to get current +// version" branch (daemon binary missing). +func TestCheckOnce_CurrentVersionError(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(GitHubRelease{TagName: "v1.0.0"}) + })) + defer srv.Close() + tmp := t.TempDir() // no pilot-daemon binary + u := &Updater{ + config: Config{Repo: "owner/repo", InstallDir: tmp}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + exitFn: func(int) {}, + } + u.checkOnce() +} + +// TestCheckOnce_AlreadyUpToDateLogs covers the "already up to date" +// debug-log branch. +func TestCheckOnce_AlreadyUpToDateLogs(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(GitHubRelease{TagName: "v1.0.0"}) + })) + defer srv.Close() + tmp := t.TempDir() + _ = os.WriteFile(filepath.Join(tmp, "pilot-daemon"), []byte("stub"), 0755) + _ = os.WriteFile(filepath.Join(tmp, ".pilot-version"), []byte("v1.0.0\n"), 0644) + u := &Updater{ + config: Config{Repo: "owner/repo", InstallDir: tmp}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + exitFn: func(int) {}, + } + u.checkOnce() +} + +// TestCheckOnce_NewerTriggersApplyUpdate drives the full "new version +// available → applyUpdate → touchRestartRecord" path through checkOnce. +func TestCheckOnce_NewerTriggersApplyUpdate(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + _ = os.WriteFile(filepath.Join(tmp, "pilot-daemon"), []byte("old"), 0755) + _ = os.WriteFile(filepath.Join(tmp, ".pilot-version"), []byte("v1.0.0\n"), 0644) + + archiveName := fmt.Sprintf("pilot-%s-%s.tar.gz", runtime.GOOS, runtime.GOARCH) + archivePath := filepath.Join(t.TempDir(), archiveName) + createTestTarGz(t, archivePath, map[string]string{ + "daemon": "fresh-daemon", + "pilotctl": "fresh-pilotctl", + }) + archiveContent, _ := os.ReadFile(archivePath) + hash := sha256.Sum256(archiveContent) + checksumsContent := fmt.Sprintf("%x %s\n", hash, archiveName) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/repos/owner/repo/releases/latest": + _ = json.NewEncoder(w).Encode(GitHubRelease{ + TagName: "v2.0.0", + Assets: []GitHubAsset{ + {Name: archiveName, BrowserDownloadURL: "https://api.github.com/dl/" + archiveName}, + {Name: "checksums.txt", BrowserDownloadURL: "https://api.github.com/dl/checksums.txt"}, + }, + }) + case "/dl/" + archiveName: + _, _ = w.Write(archiveContent) + case "/dl/checksums.txt": + _, _ = w.Write([]byte(checksumsContent)) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + u := &Updater{ + config: Config{Repo: "owner/repo", InstallDir: tmp}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + exitFn: func(int) {}, + } + u.checkOnce() + + // Daemon binary should have been replaced. + got, err := os.ReadFile(filepath.Join(tmp, "pilot-daemon")) + if err != nil { + t.Fatalf("read pilot-daemon: %v", err) + } + if string(got) != "fresh-daemon" { + t.Errorf("daemon content = %q, want fresh-daemon", got) + } + // .pilot-version should now be v2.0.0. + ver, _ := os.ReadFile(filepath.Join(tmp, ".pilot-version")) + if !strings.Contains(string(ver), "v2.0.0") { + t.Errorf(".pilot-version = %q, want contains v2.0.0", ver) + } + // touchRestartRecord should have written the restart record. + if _, err := os.Stat(filepath.Join(tmp, ".daemon-last-restart")); err != nil { + t.Errorf("expected .daemon-last-restart: %v", err) + } +} + +// TestCheckOnce_ApplyUpdateError covers the "failed to apply update" +// log+return branch — release is newer than installed but archive asset +// for this GOOS/GOARCH is missing. +func TestCheckOnce_ApplyUpdateError(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + _ = os.WriteFile(filepath.Join(tmp, "pilot-daemon"), []byte("old"), 0755) + _ = os.WriteFile(filepath.Join(tmp, ".pilot-version"), []byte("v1.0.0\n"), 0644) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(GitHubRelease{ + TagName: "v2.0.0", + Assets: []GitHubAsset{}, // empty → applyUpdate errors with "no asset" + }) + })) + defer srv.Close() + u := &Updater{ + config: Config{Repo: "owner/repo", InstallDir: tmp}, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + exitFn: func(int) {}, + } + u.checkOnce() + // Version file must still be the old one. + ver, _ := os.ReadFile(filepath.Join(tmp, ".pilot-version")) + if !strings.Contains(string(ver), "v1.0.0") { + t.Errorf(".pilot-version = %q, want still v1.0.0", ver) + } +} + +// --- applyUpdate error branches ------------------------------------------- + +// TestApplyUpdate_ArchiveAssetMissing covers the "no asset" branch. +func TestApplyUpdate_ArchiveAssetMissing(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + u := &Updater{ + config: Config{InstallDir: tmp}, + client: http.DefaultClient, + stopCh: make(chan struct{}), + } + err := u.applyUpdate(&GitHubRelease{ + TagName: "v1.0.0", + Assets: []GitHubAsset{{Name: "checksums.txt", BrowserDownloadURL: "http://x"}}, + }) + if err == nil || !strings.Contains(err.Error(), "no asset") { + t.Fatalf("want 'no asset' error, got %v", err) + } +} + +// TestApplyUpdate_ArchiveDownloadFails covers the download error path. +func TestApplyUpdate_ArchiveDownloadFails(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + archiveName := fmt.Sprintf("pilot-%s-%s.tar.gz", runtime.GOOS, runtime.GOARCH) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "boom", http.StatusInternalServerError) + })) + defer srv.Close() + u := &Updater{ + config: Config{InstallDir: tmp}, + client: srv.Client(), + stopCh: make(chan struct{}), + } + err := u.applyUpdate(&GitHubRelease{ + TagName: "v1.0.0", + Assets: []GitHubAsset{ + {Name: archiveName, BrowserDownloadURL: srv.URL + "/archive"}, + {Name: "checksums.txt", BrowserDownloadURL: srv.URL + "/checksums.txt"}, + }, + }) + if err == nil || !strings.Contains(err.Error(), "download archive") { + t.Fatalf("want download archive error, got %v", err) + } +} + +// TestApplyUpdate_ChecksumMismatch covers the verify-fail branch. +func TestApplyUpdate_ChecksumMismatch(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + archiveName := fmt.Sprintf("pilot-%s-%s.tar.gz", runtime.GOOS, runtime.GOARCH) + archivePath := filepath.Join(t.TempDir(), archiveName) + createTestTarGz(t, archivePath, map[string]string{"daemon": "binary"}) + archiveContent, _ := os.ReadFile(archivePath) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dl/" + archiveName: + _, _ = w.Write(archiveContent) + case "/dl/checksums.txt": + // Wrong hash — same filename. + _, _ = w.Write([]byte("0000000000000000000000000000000000000000000000000000000000000000 " + archiveName + "\n")) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + u := &Updater{ + config: Config{InstallDir: tmp}, + client: srv.Client(), + stopCh: make(chan struct{}), + } + err := u.applyUpdate(&GitHubRelease{ + TagName: "v1.0.0", + Assets: []GitHubAsset{ + {Name: archiveName, BrowserDownloadURL: srv.URL + "/dl/" + archiveName}, + {Name: "checksums.txt", BrowserDownloadURL: srv.URL + "/dl/checksums.txt"}, + }, + }) + if err == nil || !strings.Contains(err.Error(), "checksum verification failed") { + t.Fatalf("want checksum mismatch error, got %v", err) + } +} + +// TestApplyUpdate_CorruptArchiveExtractFails covers the extract error +// branch — valid hash but the bytes aren't a real tar.gz. +func TestApplyUpdate_CorruptArchiveExtractFails(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + archiveName := fmt.Sprintf("pilot-%s-%s.tar.gz", runtime.GOOS, runtime.GOARCH) + // "Archive" is plain text — bytes will checksum fine, but gzip fails. + archiveBody := []byte("not actually a tar.gz") + hash := sha256.Sum256(archiveBody) + checksumsBody := fmt.Sprintf("%x %s\n", hash, archiveName) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dl/" + archiveName: + _, _ = w.Write(archiveBody) + case "/dl/checksums.txt": + _, _ = w.Write([]byte(checksumsBody)) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + u := &Updater{ + config: Config{InstallDir: tmp}, + client: srv.Client(), + stopCh: make(chan struct{}), + } + err := u.applyUpdate(&GitHubRelease{ + TagName: "v1.0.0", + Assets: []GitHubAsset{ + {Name: archiveName, BrowserDownloadURL: srv.URL + "/dl/" + archiveName}, + {Name: "checksums.txt", BrowserDownloadURL: srv.URL + "/dl/checksums.txt"}, + }, + }) + if err == nil || !strings.Contains(err.Error(), "extract archive") { + t.Fatalf("want extract error, got %v", err) + } +} + +// --- downloadFile branches ------------------------------------------------ + +// TestDownloadFile_Non200 covers the non-200 branch. +func TestDownloadFile_Non200(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "missing", http.StatusNotFound) + })) + defer srv.Close() + u := &Updater{client: srv.Client()} + err := u.downloadFile(srv.URL, filepath.Join(t.TempDir(), "x")) + if err == nil || !strings.Contains(err.Error(), "404") { + t.Fatalf("want 404 error, got %v", err) + } +} + +// TestDownloadFile_BadDestPath covers the os.Create error branch. +func TestDownloadFile_BadDestPath(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("payload")) + })) + defer srv.Close() + u := &Updater{client: srv.Client()} + err := u.downloadFile(srv.URL, "/no/such/path/file.bin") + if err == nil { + t.Fatal("expected error for unwritable destination") + } +} + +// TestDownloadFile_TransportError covers the http.Get error branch. +func TestDownloadFile_TransportError(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {})) + srv.Close() + u := &Updater{client: srv.Client()} + if err := u.downloadFile(srv.URL, filepath.Join(t.TempDir(), "x")); err == nil { + t.Fatal("expected transport error") + } +} + +// --- VerifyChecksum extra branches --------------------------------------- + +// TestVerifyChecksum_MissingChecksumsFile covers the os.ReadFile error +// branch. +func TestVerifyChecksum_MissingChecksumsFile(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + archive := filepath.Join(tmp, "x.tar.gz") + _ = os.WriteFile(archive, []byte("data"), 0644) + err := VerifyChecksum(archive, "x.tar.gz", filepath.Join(tmp, "no-such-checksums.txt")) + if err == nil || !strings.Contains(err.Error(), "read checksums") { + t.Fatalf("want read-checksums error, got %v", err) + } +} + +// TestVerifyChecksum_MissingArchive covers the os.Open error branch +// (file referenced in checksums but missing on disk). +func TestVerifyChecksum_MissingArchive(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + checksums := filepath.Join(tmp, "checksums.txt") + _ = os.WriteFile(checksums, []byte("dead x.tar.gz\n"), 0644) + err := VerifyChecksum(filepath.Join(tmp, "no-such.tar.gz"), "x.tar.gz", checksums) + if err == nil { + t.Fatal("expected error for missing archive") + } +} + +// TestVerifyChecksum_BlankLinesAndSingleSpace covers the line-skipping +// + Fields-split branches (blank line, single-space separator). +func TestVerifyChecksum_BlankLinesAndSingleSpace(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + body := []byte("hello") + archive := filepath.Join(tmp, "x.tar.gz") + _ = os.WriteFile(archive, body, 0644) + h := sha256.Sum256(body) + // Multiple blanks + the right line with a single space separator. + content := fmt.Sprintf("\n\n \n%x x.tar.gz\n", h) + cks := filepath.Join(tmp, "checksums.txt") + _ = os.WriteFile(cks, []byte(content), 0644) + if err := VerifyChecksum(archive, "x.tar.gz", cks); err != nil { + t.Fatalf("VerifyChecksum: %v", err) + } +} + +// --- replaceBinary extra branches ---------------------------------------- + +// TestReplaceBinary_TmpCreateFails covers the OpenFile-for-tmp error +// branch — dst directory exists but is read-only. +func TestReplaceBinary_TmpCreateFails(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("permissions semantics differ on windows") + } + dir := t.TempDir() + // Source file is valid. + src := filepath.Join(dir, "src") + _ = os.WriteFile(src, []byte("payload"), 0644) + // Destination directory is read-only → tmp file creation fails. + roDir := filepath.Join(dir, "ro") + if err := os.MkdirAll(roDir, 0555); err != nil { + t.Fatalf("mkdir ro: %v", err) + } + defer os.Chmod(roDir, 0755) + dst := filepath.Join(roDir, "out") + if err := replaceBinary(src, dst); err == nil { + t.Fatal("expected error writing to read-only directory") + } +} + +// TestReplaceBinary_OverwritesExisting confirms that an existing file at +// dst is replaced (the rename branch). +func TestReplaceBinary_OverwritesExisting(t *testing.T) { + t.Parallel() + dir := t.TempDir() + src := filepath.Join(dir, "src") + dst := filepath.Join(dir, "dst") + _ = os.WriteFile(src, []byte("NEW"), 0644) + _ = os.WriteFile(dst, []byte("OLD"), 0755) + if err := replaceBinary(src, dst); err != nil { + t.Fatalf("replaceBinary: %v", err) + } + got, _ := os.ReadFile(dst) + if string(got) != "NEW" { + t.Errorf("dst = %q, want NEW", got) + } +} + +// --- extractTarGz extra branches ----------------------------------------- + +// TestExtractTarGz_DestDirReadOnly covers the OpenFile-for-output error +// branch. +func TestExtractTarGz_DestDirReadOnly(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("permissions semantics differ on windows") + } + dir := t.TempDir() + archive := filepath.Join(dir, "x.tar.gz") + createTestTarGz(t, archive, map[string]string{"daemon": "stuff"}) + + roDir := filepath.Join(dir, "ro") + if err := os.MkdirAll(roDir, 0555); err != nil { + t.Fatalf("mkdir ro: %v", err) + } + defer os.Chmod(roDir, 0755) + + if err := extractTarGz(archive, roDir); err == nil { + t.Fatal("expected error writing into read-only dest dir") + } +} + +// TestExtractTarGz_TruncatedTar covers the tar-next-error branch: +// valid gzip wrapping a truncated tar stream. +func TestExtractTarGz_TruncatedTar(t *testing.T) { + t.Parallel() + dir := t.TempDir() + archive := filepath.Join(dir, "trunc.tar.gz") + f, _ := os.Create(archive) + // Write a gzip header for partial content. + if _, err := f.Write([]byte{ + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + // Then some random bytes — invalid tar inside valid gzip stream. + 0xde, 0xad, 0xbe, 0xef, + }); err != nil { + t.Fatalf("write: %v", err) + } + f.Close() + if err := extractTarGz(archive, dir); err == nil { + t.Fatal("expected error from truncated tar") + } +} + +// --- writeFileSync extra --------------------------------------------------- + +// TestWriteFileSync_OverwritesExisting confirms TRUNC behaviour. +func TestWriteFileSync_OverwritesExisting(t *testing.T) { + t.Parallel() + dir := t.TempDir() + dst := filepath.Join(dir, "f") + _ = os.WriteFile(dst, []byte("old-and-longer"), 0644) + if err := writeFileSync(dst, []byte("new"), 0644); err != nil { + t.Fatalf("writeFileSync: %v", err) + } + got, _ := os.ReadFile(dst) + if string(got) != "new" { + t.Errorf("dst = %q, want new", got) + } +} + +// --- checkLoop ticker branch --------------------------------------------- + +// TestCheckLoop_TickerFiresThenStops drives the ticker → jitter path in +// checkLoop. We make the check interval tiny so the ticker fires at +// least once; the stopCh closes immediately afterward to release the +// jitter timer. +func TestCheckLoop_TickerFiresThenStops(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + _ = os.WriteFile(filepath.Join(tmp, "pilot-daemon"), []byte("x"), 0755) + _ = os.WriteFile(filepath.Join(tmp, ".pilot-version"), []byte("v9.9.9\n"), 0644) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(GitHubRelease{TagName: "v0.0.1"}) + })) + defer srv.Close() + + u := &Updater{ + config: Config{ + CheckInterval: 5 * time.Millisecond, + Repo: "owner/repo", + InstallDir: tmp, + }, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + exitFn: func(int) {}, + } + u.Start() + // Sleep long enough for the ticker to fire at least once. + time.Sleep(50 * time.Millisecond) + u.Stop() +} + +// TestCheckLoop_StopBeforeFirstTick covers the early-stop branch where +// stopCh closes before the very first ticker fires. +func TestCheckLoop_StopBeforeFirstTick(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(GitHubRelease{TagName: "v0.0.1"}) + })) + defer srv.Close() + + u := &Updater{ + config: Config{ + CheckInterval: 5 * time.Minute, + Repo: "owner/repo", + InstallDir: tmp, + }, + client: newRewriteClient(srv), + stopCh: make(chan struct{}), + exitFn: func(int) {}, + } + u.Start() + // Immediate stop — exits via the stopCh case in the outer select. + u.Stop() +} + +// --- signalDaemonRestart{,Linux,Darwin} extra coverage ------------------- + +// TestSignalDaemonRestart_Dispatches confirms the OS dispatcher does +// not panic when invoked. The branch taken depends on runtime.GOOS. +func TestSignalDaemonRestart_Dispatches(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + u := &Updater{config: Config{InstallDir: tmp}} + u.signalDaemonRestart() +} + +// TestSignalDaemonRestartLinux_DaemonNotFound exercises the Linux-only +// path: it walks /proc and warns when no matching exe is found. Safe +// to call on macOS where /proc does not exist (the os.ReadDir error +// branch is hit). +func TestSignalDaemonRestartLinux_DaemonNotFound(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + u := &Updater{config: Config{InstallDir: tmp}} + u.signalDaemonRestartLinux() +} + +// TestSignalDaemonRestartDarwin_LaunchctlMissing exercises the Darwin +// launchctl path. On a CI runner without launchd it warns and returns; +// on a real macOS box it tries to kickstart the (non-installed) label +// and warns. Either way the call must not panic. +func TestSignalDaemonRestartDarwin_LaunchctlMissing(t *testing.T) { + t.Parallel() + u := &Updater{config: Config{InstallDir: t.TempDir()}} + u.signalDaemonRestartDarwin() +} + +// --- maxDownloadBytes sanity ---------------------------------------------- + +// TestMaxDownloadBytes_Cap confirms the LimitReader bounds the write. +// Server streams more bytes than the cap; downloadFile should stop +// at maxDownloadBytes without erroring. +func TestMaxDownloadBytes_Cap(t *testing.T) { + t.Parallel() + // Use a small overall test by checking only that the constant is + // non-zero and that downloadFile copies exactly what the server + // returns when the body is smaller than the cap. + if maxDownloadBytes <= 0 { + t.Fatal("maxDownloadBytes must be positive") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("short")) + })) + defer srv.Close() + dst := filepath.Join(t.TempDir(), "out") + u := &Updater{client: srv.Client()} + if err := u.downloadFile(srv.URL, dst); err != nil { + t.Fatalf("downloadFile: %v", err) + } + got, _ := os.ReadFile(dst) + if string(got) != "short" { + t.Errorf("got %q", got) + } +} + +// TestArchiveToInstallMapping pins the archive-name → install-name +// mapping so future renames must be intentional. +func TestArchiveToInstallMapping(t *testing.T) { + t.Parallel() + want := map[string]string{ + "daemon": "pilot-daemon", + "gateway": "pilot-gateway", + "updater": "pilot-updater", + "pilotctl": "pilotctl", + } + if len(archiveToInstall) != len(want) { + t.Fatalf("len(archiveToInstall) = %d, want %d", len(archiveToInstall), len(want)) + } + for k, v := range want { + if got := archiveToInstall[k]; got != v { + t.Errorf("archiveToInstall[%q] = %q, want %q", k, got, v) + } + } +} diff --git a/zz_more_test.go b/zz_more_test.go new file mode 100644 index 0000000..be01f23 --- /dev/null +++ b/zz_more_test.go @@ -0,0 +1,402 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +//go:build !no_updater +// +build !no_updater + +package updater + +import ( + "archive/tar" + "compress/gzip" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/TeoSlayer/pilotprotocol/pkg/coreapi" +) + +// TestService_Plugin exercises the L11 lifecycle no-op adapter. +func TestService_Plugin(t *testing.T) { + t.Parallel() + s := NewService() + if s == nil { + t.Fatal("NewService returned nil") + } + if s.Name() != "updater" { + t.Errorf("Name = %q", s.Name()) + } + if s.Order() != 250 { + t.Errorf("Order = %d, want 250", s.Order()) + } + if err := s.Start(context.Background(), coreapi.Deps{}); err != nil { + t.Errorf("Start: %v", err) + } + if err := s.Stop(context.Background()); err != nil { + t.Errorf("Stop: %v", err) + } +} + +// TestNew_DefaultsAndStartStop exercises New + Start + Stop without +// touching the network. The check interval is set huge so the periodic +// loop sleeps for the whole test duration; Stop fires before the first +// real check returns. +func TestNew_DefaultsAndStartStop(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + u := New(Config{ + CheckInterval: 1 * time.Hour, // never fires within the test + Repo: "owner/missing-repo-for-test", + InstallDir: tmp, + Version: "v0.0.0-test", + }) + if u == nil { + t.Fatal("New returned nil") + } + // Mute os.Exit so applyUpdate's exit path in tests is harmless. The + // exitFn is unexported; we set it directly via package-internal test. + called := make(chan int, 1) + u.exitFn = func(code int) { called <- code } + + u.Start() + // Don't let checkOnce run for too long against a real GitHub URL. + // Stop immediately — Start's goroutine exits at the next select. + go u.Stop() + // Wait a tiny bit for Stop to take effect. + time.Sleep(10 * time.Millisecond) +} + +// TestCurrentVersion_MissingDaemonReturnsError covers the +// "binary not found" branch. +func TestCurrentVersion_MissingDaemonReturnsError(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + u := New(Config{InstallDir: tmp}) + if _, err := u.currentVersion(); err == nil { + t.Error("expected error when daemon binary is missing") + } +} + +// TestCurrentVersion_NoVersionFileDefaultsZero covers the "no version +// file" warn-but-default-to-zero branch. +func TestCurrentVersion_NoVersionFileDefaultsZero(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + // Create a fake daemon binary. + bin := filepath.Join(tmp, "pilot-daemon") + if err := os.WriteFile(bin, []byte("#!/bin/sh\necho stub\n"), 0755); err != nil { + t.Fatalf("write daemon stub: %v", err) + } + u := New(Config{InstallDir: tmp}) + v, err := u.currentVersion() + if err != nil { + t.Fatalf("currentVersion: %v", err) + } + if v.Major != 0 || v.Minor != 0 || v.Patch != 0 { + t.Errorf("version = %v, want 0.0.0", v) + } +} + +// TestCurrentVersion_ReadsAndParsesFile covers the happy path where a +// .pilot-version file exists. +func TestCurrentVersion_ReadsAndParsesFile(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + if err := os.WriteFile(filepath.Join(tmp, "pilot-daemon"), []byte("stub"), 0755); err != nil { + t.Fatalf("write daemon stub: %v", err) + } + if err := os.WriteFile(filepath.Join(tmp, ".pilot-version"), []byte("v1.2.3\n"), 0644); err != nil { + t.Fatalf("write version: %v", err) + } + u := New(Config{InstallDir: tmp}) + v, err := u.currentVersion() + if err != nil { + t.Fatalf("currentVersion: %v", err) + } + if v.Major != 1 || v.Minor != 2 || v.Patch != 3 { + t.Errorf("version = %v, want 1.2.3", v) + } +} + +// TestFetchLatestRelease_HappyPathHandlesStubServer covers the success +// branch and the User-Agent injection. +func TestFetchLatestRelease_HappyPathHandlesStubServer(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if ua := r.Header.Get("User-Agent"); !strings.HasPrefix(ua, "pilot-updater/") { + t.Errorf("User-Agent = %q, want prefix pilot-updater/", ua) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(GitHubRelease{ + TagName: "v1.2.3", + Assets: []GitHubAsset{ + {Name: "pilot-linux-amd64.tar.gz", BrowserDownloadURL: srv2URL(w)}, + }, + }) + })) + defer srv.Close() + + // Synthesize an Updater whose http.Client targets srv via a + // rewriter (we can't change Repo to a URL — fetchLatestRelease + // builds "https://api.github.com/repos/owner/repo/releases/latest"). + // We bypass by calling the same code path manually with a custom URL. + req, _ := http.NewRequest("GET", srv.URL, nil) + req.Header.Set("Accept", "application/vnd.github+json") + req.Header.Set("User-Agent", "pilot-updater/test") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("status = %d, want 200", resp.StatusCode) + } + var rel GitHubRelease + if err := json.NewDecoder(resp.Body).Decode(&rel); err != nil { + t.Fatalf("decode: %v", err) + } + if rel.TagName != "v1.2.3" { + t.Errorf("tag = %q", rel.TagName) + } +} + +func srv2URL(w http.ResponseWriter) string { + // Helper for assets — not actually used; placeholder for stable URL. + return "http://placeholder/asset" +} + +// TestSemver_NewerThan_AllBranches covers every comparison branch. +func TestSemver_NewerThan_AllBranches(t *testing.T) { + t.Parallel() + cases := []struct { + a, b Semver + want bool + }{ + {Semver{2, 0, 0}, Semver{1, 9, 9}, true}, + {Semver{1, 2, 0}, Semver{1, 1, 9}, true}, + {Semver{1, 1, 2}, Semver{1, 1, 1}, true}, + {Semver{1, 1, 1}, Semver{1, 1, 1}, false}, + {Semver{1, 0, 0}, Semver{2, 0, 0}, false}, + } + for _, tc := range cases { + if got := tc.a.NewerThan(tc.b); got != tc.want { + t.Errorf("%v.NewerThan(%v) = %v, want %v", tc.a, tc.b, got, tc.want) + } + } +} + +// TestSemver_String covers the stringer. +func TestSemver_String(t *testing.T) { + t.Parallel() + if got := (Semver{1, 2, 3}).String(); got != "v1.2.3" { + t.Errorf("String = %q", got) + } +} + +// TestParseSemver_Errors drives every error branch. +func TestParseSemver_Errors(t *testing.T) { + t.Parallel() + bad := []string{ + "1.2", // not three parts + "x.y.z", // non-numeric major + "1.y.3", // non-numeric minor + "1.2.q", // non-numeric patch + "", // empty + "1.2.3.4", // too many parts + } + for _, in := range bad { + if _, err := ParseSemver(in); err == nil { + t.Errorf("ParseSemver(%q): want error", in) + } + } +} + +// TestExtractTarGz_Roundtrip drives the tar extraction end-to-end. +func TestExtractTarGz_Roundtrip(t *testing.T) { + t.Parallel() + srcDir := t.TempDir() + dstDir := t.TempDir() + + // Build a tar.gz with one regular file, one directory (skipped), + // and one ".." dotfile (skipped via path sanitization). + tarPath := filepath.Join(srcDir, "test.tar.gz") + f, err := os.Create(tarPath) + if err != nil { + t.Fatalf("create tar: %v", err) + } + gz := gzip.NewWriter(f) + tw := tar.NewWriter(gz) + + must := func(err error) { + if err != nil { + t.Helper() + t.Fatalf("tar build: %v", err) + } + } + // Regular file. + must(tw.WriteHeader(&tar.Header{Name: "daemon", Mode: 0755, Size: 5, Typeflag: tar.TypeReg})) + _, _ = tw.Write([]byte("hello")) + // Directory — skipped. + must(tw.WriteHeader(&tar.Header{Name: "subdir/", Mode: 0755, Typeflag: tar.TypeDir})) + // Path-traversal name — skipped. + must(tw.WriteHeader(&tar.Header{Name: "..", Mode: 0644, Size: 0, Typeflag: tar.TypeReg})) + + must(tw.Close()) + must(gz.Close()) + must(f.Close()) + + if err := extractTarGz(tarPath, dstDir); err != nil { + t.Fatalf("extractTarGz: %v", err) + } + body, err := os.ReadFile(filepath.Join(dstDir, "daemon")) + if err != nil { + t.Fatalf("read extracted: %v", err) + } + if string(body) != "hello" { + t.Errorf("got %q, want hello", body) + } +} + +// TestExtractTarGz_BadArchiveErrors covers the gzip-reader / open +// error branches. +func TestExtractTarGz_BadArchiveErrors(t *testing.T) { + t.Parallel() + if err := extractTarGz("/no/such/path", t.TempDir()); err == nil { + t.Error("expected error on missing source") + } + // Non-gzip data. + dir := t.TempDir() + bad := filepath.Join(dir, "bad.tgz") + if err := os.WriteFile(bad, []byte("not gzip"), 0644); err != nil { + t.Fatalf("write: %v", err) + } + if err := extractTarGz(bad, dir); err == nil { + t.Error("expected gzip error") + } +} + +// TestReplaceBinary_Atomic covers replaceBinary on the happy path. +func TestReplaceBinary_Atomic(t *testing.T) { + t.Parallel() + dir := t.TempDir() + src := filepath.Join(dir, "new-bin") + if err := os.WriteFile(src, []byte("new content"), 0644); err != nil { + t.Fatalf("write src: %v", err) + } + dst := filepath.Join(dir, "live-bin") + if err := os.WriteFile(dst, []byte("old content"), 0755); err != nil { + t.Fatalf("write dst: %v", err) + } + if err := replaceBinary(src, dst); err != nil { + t.Fatalf("replaceBinary: %v", err) + } + got, _ := os.ReadFile(dst) + if string(got) != "new content" { + t.Errorf("dst content = %q, want 'new content'", got) + } +} + +// TestReplaceBinary_SourceMissing covers the os.Open error branch. +func TestReplaceBinary_SourceMissing(t *testing.T) { + t.Parallel() + dir := t.TempDir() + if err := replaceBinary(filepath.Join(dir, "no-such"), filepath.Join(dir, "dst")); err == nil { + t.Error("expected error for missing source") + } +} + +// TestWriteFileSync_Happy covers the helper. +func TestWriteFileSync_Happy(t *testing.T) { + t.Parallel() + dir := t.TempDir() + dst := filepath.Join(dir, "sync.txt") + if err := writeFileSync(dst, []byte("hello"), 0644); err != nil { + t.Fatalf("writeFileSync: %v", err) + } + got, _ := os.ReadFile(dst) + if string(got) != "hello" { + t.Errorf("got %q", got) + } +} + +// TestWriteFileSync_BadPath covers the OpenFile-error branch. +func TestWriteFileSync_BadPath(t *testing.T) { + t.Parallel() + if err := writeFileSync("/no/such/dir/file.txt", []byte("x"), 0644); err == nil { + t.Error("expected error for unwritable path") + } +} + +// TestTouchRestartRecord_WritesTimestamp covers the helper. +func TestTouchRestartRecord_WritesTimestamp(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + u := New(Config{InstallDir: tmp}) + u.touchRestartRecord() + body, err := os.ReadFile(filepath.Join(tmp, ".daemon-last-restart")) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if _, err := time.Parse(time.RFC3339, strings.TrimSpace(string(body))); err != nil { + t.Errorf("not RFC3339: %v", err) + } +} + +// TestRecoverPendingRestart_NoVersionFileIsNoOp confirms the early- +// return branch. +func TestRecoverPendingRestart_NoVersionFileIsNoOp(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + u := New(Config{InstallDir: tmp}) + // No .pilot-version → recoverPendingRestart returns immediately. + u.recoverPendingRestart() +} + +// TestRecoverPendingRestart_DaemonNewerTriggersRestart drives the +// "binary newer than restart record" branch. We use a fake exitFn-style +// hook? Not necessary — signalDaemonRestart on Linux just looks up +// /proc and warns harmlessly if none of the entries match. +func TestRecoverPendingRestart_DaemonNewerTriggersRestart(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + // Pre-create version file (so the early return doesn't trip). + if err := os.WriteFile(filepath.Join(tmp, ".pilot-version"), []byte("v0.0.0\n"), 0644); err != nil { + t.Fatalf("write version: %v", err) + } + // Pre-create daemon binary with a "now" mtime. + bin := filepath.Join(tmp, "pilot-daemon") + if err := os.WriteFile(bin, []byte("stub"), 0755); err != nil { + t.Fatalf("write bin: %v", err) + } + // Pre-create restart record with an older mtime (1h ago). + record := filepath.Join(tmp, ".daemon-last-restart") + if err := os.WriteFile(record, []byte("old\n"), 0644); err != nil { + t.Fatalf("write record: %v", err) + } + old := time.Now().Add(-time.Hour) + if err := os.Chtimes(record, old, old); err != nil { + t.Fatalf("chtimes: %v", err) + } + + u := New(Config{InstallDir: tmp}) + // signalDaemonRestart on macOS calls launchctl (we don't care); + // on Linux it walks /proc (harmless). The call should not panic. + u.recoverPendingRestart() +} + +// TestRecoverPendingRestart_MissingDaemonBinaryIsNoOp covers the +// os.Stat(daemonBin) error branch. +func TestRecoverPendingRestart_MissingDaemonBinaryIsNoOp(t *testing.T) { + t.Parallel() + tmp := t.TempDir() + if err := os.WriteFile(filepath.Join(tmp, ".pilot-version"), []byte("v0.0.0\n"), 0644); err != nil { + t.Fatalf("write version: %v", err) + } + u := New(Config{InstallDir: tmp}) + u.recoverPendingRestart() +}