From d79c0ed0add80eb27861a806ac1e2b5db3d4457e Mon Sep 17 00:00:00 2001 From: Damiano <71268257+tis24dev@users.noreply.github.com> Date: Fri, 27 Feb 2026 15:30:18 +0100 Subject: [PATCH 01/18] test(orchestrator): maximize coverage for network_apply - Add network_apply_additional_test.go with extensive tests for rollback (arm/disarm), NIC repair CLI (overrides/conflicts), snapshot IP parsing, command selection, and error paths. - Use FakeFS/FakeCommandRunner plus PATH stubs to deterministically exercise both success and failure branches. - Bring network_apply.go coverage to ~99% with no production code changes. --- .../network_apply_additional_test.go | 1791 +++++++++++++++++ 1 file changed, 1791 insertions(+) create mode 100644 internal/orchestrator/network_apply_additional_test.go diff --git a/internal/orchestrator/network_apply_additional_test.go b/internal/orchestrator/network_apply_additional_test.go new file mode 100644 index 0000000..b3e2780 --- /dev/null +++ b/internal/orchestrator/network_apply_additional_test.go @@ -0,0 +1,1791 @@ +package orchestrator + +import ( + "archive/tar" + "bufio" + "context" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +type writeFailFS struct { + FS + failPath string + err error +} + +func (f writeFailFS) WriteFile(path string, data []byte, perm fs.FileMode) error { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return f.err + } + return f.FS.WriteFile(path, data, perm) +} + +func newDiscardLogger() *logging.Logger { + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(io.Discard) + return logger +} + +func TestNetworkApplyNotCommittedError_ErrorAndUnwrap(t *testing.T) { + var err *NetworkApplyNotCommittedError + if got := err.Error(); got != ErrNetworkApplyNotCommitted.Error() { + t.Fatalf("Error()=%q want %q", got, ErrNetworkApplyNotCommitted.Error()) + } + if got := err.Unwrap(); !errors.Is(got, ErrNetworkApplyNotCommitted) { + t.Fatalf("Unwrap()=%v want %v", got, ErrNetworkApplyNotCommitted) + } + + err = &NetworkApplyNotCommittedError{RollbackLog: "/tmp/log"} + if got := err.Error(); got != ErrNetworkApplyNotCommitted.Error() { + t.Fatalf("Error()=%q want %q", got, ErrNetworkApplyNotCommitted.Error()) + } + if got := err.Unwrap(); got != ErrNetworkApplyNotCommitted { + t.Fatalf("Unwrap()=%v want %v", got, ErrNetworkApplyNotCommitted) + } +} + +func TestNetworkRollbackHandleRemaining(t *testing.T) { + var h *networkRollbackHandle + if got := h.remaining(time.Now()); got != 0 { + t.Fatalf("nil remaining=%s want 0", got) + } + + h = &networkRollbackHandle{ + armedAt: time.Date(2026, 2, 1, 1, 2, 3, 0, time.UTC), + timeout: 10 * time.Second, + } + if got := h.remaining(h.armedAt); got != 10*time.Second { + t.Fatalf("remaining=%s want %s", got, 10*time.Second) + } + if got := h.remaining(h.armedAt.Add(4 * time.Second)); got != 6*time.Second { + t.Fatalf("remaining=%s want %s", got, 6*time.Second) + } + if got := h.remaining(h.armedAt.Add(20 * time.Second)); got != 0 { + t.Fatalf("remaining=%s want 0", got) + } +} + +func TestShouldAttemptNetworkApply(t *testing.T) { + if shouldAttemptNetworkApply(nil) { + t.Fatalf("expected false for nil plan") + } + if shouldAttemptNetworkApply(&RestorePlan{NormalCategories: []Category{{ID: "storage_pve"}}}) { + t.Fatalf("expected false when network category not present") + } + if !shouldAttemptNetworkApply(&RestorePlan{NormalCategories: []Category{{ID: "network"}}}) { + t.Fatalf("expected true when network category present") + } +} + +func TestExtractIPFromSnapshot_EmptyArgsReturnUnknown(t *testing.T) { + if got := extractIPFromSnapshot("", "vmbr0"); got != "unknown" { + t.Fatalf("got %q want unknown", got) + } + if got := extractIPFromSnapshot("/snap.txt", ""); got != "unknown" { + t.Fatalf("got %q want unknown", got) + } +} + +func TestExtractIPFromSnapshot_IgnoresLinesOutsideAddrSection(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip -br addr", + "lo UNKNOWN 127.0.0.1/8", + "$ ip route show", + "vmbr0 UP 192.0.2.10/24", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "unknown" { + t.Fatalf("got %q want unknown", got) + } +} + +func TestExtractIPFromSnapshot_SkipsInvalidTokensButReturnsFirstValid(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip -br addr", + "vmbr0 UP not-an-ip 2a01:db8::2/64", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "2a01:db8::2/64" { + t.Fatalf("got %q want %q", got, "2a01:db8::2/64") + } +} + +func TestExtractIPFromSnapshot_SkipsErrorLinesAndParsesNext(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip -br addr", + "ERROR: ip failed", + "vmbr0 UP 192.0.2.55/24", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "192.0.2.55/24" { + t.Fatalf("got %q want %q", got, "192.0.2.55/24") + } +} + +func TestExtractIPFromSnapshot_ReturnsUnknownWhenNoValidAddressTokens(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip -br addr", + "vmbr0 UP not-an-ip also-bad", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "unknown" { + t.Fatalf("got %q want unknown", got) + } +} + +func TestExtractIPFromSnapshot_IgnoresCommandsBeforeAddrSection(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip route show", + "default via 192.0.2.1 dev vmbr0", + "$ ip -br addr", + "vmbr0 UP 192.0.2.99/24", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "192.0.2.99/24" { + t.Fatalf("got %q want %q", got, "192.0.2.99/24") + } +} + +func TestBuildNetworkApplyNotCommittedError_HandleNilIfaceEmpty(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + restoreFS = NewFakeFS() + restoreCmd = &FakeCommandRunner{} + + logger := newDiscardLogger() + + got := buildNetworkApplyNotCommittedError(context.Background(), logger, "", nil) + if got.RollbackArmed { + t.Fatalf("RollbackArmed=true want false") + } + if got.RollbackLog != "" || got.RollbackMarker != "" { + t.Fatalf("expected empty rollback paths, got log=%q marker=%q", got.RollbackLog, got.RollbackMarker) + } + if got.RestoredIP != "unknown" || got.OriginalIP != "unknown" { + t.Fatalf("restored=%q original=%q want unknown/unknown", got.RestoredIP, got.OriginalIP) + } + if !got.RollbackDeadline.IsZero() { + t.Fatalf("RollbackDeadline=%s want zero", got.RollbackDeadline) + } +} + +func TestBuildNetworkApplyNotCommittedError_ArmedWithMarkerAndSnapshots(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeCmd := &FakeCommandRunner{ + Outputs: map[string][]byte{ + "ip -o addr show dev vmbr0 scope global": []byte(strings.Join([]string{ + "2: vmbr0 inet 192.0.2.10/24 brd 192.0.2.255 scope global vmbr0", + "2: vmbr0 inet6 2001:db8::1/64 scope global", + }, "\n")), + "ip route show default": []byte("default via 192.0.2.1 dev vmbr0\n"), + }, + } + restoreCmd = fakeCmd + + logger := newDiscardLogger() + + handle := &networkRollbackHandle{ + workDir: "/work", + markerPath: "/work/marker", + logPath: "/work/log", + armedAt: time.Date(2026, 2, 1, 12, 0, 0, 0, time.UTC), + timeout: 3 * time.Minute, + } + + if err := fakeFS.WriteFile(handle.markerPath, []byte("pending\n"), 0o600); err != nil { + t.Fatalf("write marker: %v", err) + } + beforeSnapshot := strings.Join([]string{ + "$ ip -br addr", + "vmbr0 UP 10.0.0.2/24", + "", + }, "\n") + if err := fakeFS.WriteFile("/work/before.txt", []byte(beforeSnapshot), 0o600); err != nil { + t.Fatalf("write before snapshot: %v", err) + } + + got := buildNetworkApplyNotCommittedError(context.Background(), logger, "vmbr0", handle) + if !got.RollbackArmed { + t.Fatalf("RollbackArmed=false want true") + } + if got.RollbackLog != "/work/log" || got.RollbackMarker != "/work/marker" { + t.Fatalf("paths log=%q marker=%q want /work/log /work/marker", got.RollbackLog, got.RollbackMarker) + } + if got.RestoredIP != "192.0.2.10/24, 2001:db8::1/64" { + t.Fatalf("RestoredIP=%q", got.RestoredIP) + } + if got.OriginalIP != "10.0.0.2/24" { + t.Fatalf("OriginalIP=%q want %q", got.OriginalIP, "10.0.0.2/24") + } + wantDeadline := handle.armedAt.Add(handle.timeout) + if !got.RollbackDeadline.Equal(wantDeadline) { + t.Fatalf("RollbackDeadline=%s want %s", got.RollbackDeadline, wantDeadline) + } +} + +func TestBuildNetworkApplyNotCommittedError_MarkerMissingAndIPQueryFails(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "ip -o addr show dev vmbr0 scope global": errors.New("boom"), + }, + } + + logger := newDiscardLogger() + handle := &networkRollbackHandle{ + workDir: "/work", + markerPath: "/work/missing", + armedAt: time.Date(2026, 2, 1, 12, 0, 0, 0, time.UTC), + timeout: 1 * time.Minute, + } + + got := buildNetworkApplyNotCommittedError(context.Background(), logger, "vmbr0", handle) + if got.RollbackArmed { + t.Fatalf("RollbackArmed=true want false when marker missing") + } + if got.RestoredIP != "unknown" { + t.Fatalf("RestoredIP=%q want unknown", got.RestoredIP) + } + if got.OriginalIP != "unknown" { + t.Fatalf("OriginalIP=%q want unknown", got.OriginalIP) + } +} + +func TestRollbackAlreadyRunning(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + logger := newDiscardLogger() + pathDir := t.TempDir() + + t.Run("skip when handle nil", func(t *testing.T) { + t.Setenv("PATH", pathDir) + restoreCmd = &FakeCommandRunner{} + if rollbackAlreadyRunning(context.Background(), logger, nil) { + t.Fatalf("expected false for nil handle") + } + }) + + t.Run("skip when unit name empty", func(t *testing.T) { + t.Setenv("PATH", pathDir) + restoreCmd = &FakeCommandRunner{} + if rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: ""}) { + t.Fatalf("expected false when unitName empty") + } + }) + + t.Run("skip when systemctl missing", func(t *testing.T) { + t.Setenv("PATH", pathDir) + restoreCmd = &FakeCommandRunner{} + if rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: "x"}) { + t.Fatalf("expected false when systemctl not available") + } + }) + + t.Run("running when active or activating", func(t *testing.T) { + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + + for _, tc := range []struct { + state string + want bool + }{ + {state: "active\n", want: true}, + {state: "activating\n", want: true}, + {state: "inactive\n", want: false}, + } { + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "systemctl is-active unit.service": []byte(tc.state), + }, + } + if got := rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: "unit"}); got != tc.want { + t.Fatalf("state=%q got=%v want=%v", tc.state, got, tc.want) + } + } + }) + + t.Run("not running when systemctl errors", func(t *testing.T) { + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "systemctl is-active unit.service": errors.New("boom"), + }, + } + if rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: "unit"}) { + t.Fatalf("expected false when systemctl is-active errors") + } + }) +} + +func TestArmNetworkRollback_ValidationErrors(t *testing.T) { + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "", 10*time.Second, ""); err == nil { + t.Fatalf("expected error for empty backup path") + } + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 0, ""); err == nil { + t.Fatalf("expected error for invalid timeout") + } +} + +func TestArmNetworkRollback_CreateRollbackDirFailureReturnsError(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.MkdirAllErr = errors.New("boom") + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 30*time.Second, "/secure"); err == nil || !strings.Contains(err.Error(), "create rollback directory") { + t.Fatalf("err=%v want create rollback directory error", err) + } +} + +func TestArmNetworkRollback_MarkerWriteFailureReturnsError(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.WriteErr = errors.New("boom") + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 30*time.Second, "/secure"); err == nil || !strings.Contains(err.Error(), "write rollback marker") { + t.Fatalf("err=%v want write rollback marker error", err) + } +} + +func TestArmNetworkRollback_ScriptWriteFailureReturnsError(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + scriptPath := "/secure/network_rollback_20260201_123456.sh" + restoreFS = writeFailFS{FS: fakeFS, failPath: scriptPath, err: errors.New("boom")} + + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 30*time.Second, "/secure"); err == nil || !strings.Contains(err.Error(), "write rollback script") { + t.Fatalf("err=%v want write rollback script error", err) + } +} + +func TestArmNetworkRollback_UsesSystemdRunWhenAvailable(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemd-run") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{ + Outputs: map[string][]byte{}, + } + restoreCmd = fakeCmd + + logger := newDiscardLogger() + handle, err := armNetworkRollback(context.Background(), logger, "/backup.tar", 30*time.Second, "") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil || handle.unitName == "" { + t.Fatalf("expected handle with unitName set") + } + if !strings.Contains(handle.unitName, "20260201_123456") { + t.Fatalf("unitName=%q want timestamp", handle.unitName) + } + + data, err := fakeFS.ReadFile(handle.scriptPath) + if err != nil { + t.Fatalf("read script: %v", err) + } + if !strings.Contains(string(data), "Restart networking after rollback") { + t.Fatalf("expected restartNetworking block in script") + } + + foundSystemdRun := false + for _, call := range fakeCmd.CallsList() { + if strings.HasPrefix(call, "systemd-run --unit=") { + foundSystemdRun = true + } + if strings.HasPrefix(call, "sh -c ") { + t.Fatalf("unexpected fallback sh -c call: %s", call) + } + } + if !foundSystemdRun { + t.Fatalf("expected systemd-run to be called; calls=%v", fakeCmd.CallsList()) + } +} + +func TestArmNetworkRollback_CustomWorkDirAndSystemdRunOutput(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemd-run") + t.Setenv("PATH", pathDir) + + expectedSystemdRun := "systemd-run --unit=proxsave-network-rollback-20260201_123456 --on-active=2s /bin/sh /secure/network_rollback_20260201_123456.sh" + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + expectedSystemdRun: []byte("unit started\n"), + }, + } + + handle, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 2*time.Second, "/secure") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil || handle.workDir != "/secure" { + t.Fatalf("handle=%#v want workDir=/secure", handle) + } +} + +func TestArmNetworkRollback_SystemdRunFailureFallsBackToNohup(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemd-run") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + logger := newDiscardLogger() + expectedSystemdRun := "systemd-run --unit=proxsave-network-rollback-20260201_123456 --on-active=30s /bin/sh /tmp/proxsave/network_rollback_20260201_123456.sh" + fakeCmd.Errors[expectedSystemdRun] = errors.New("boom") + + handle, err := armNetworkRollback(context.Background(), logger, "/backup.tar", 30*time.Second, "") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + if handle.unitName != "" { + t.Fatalf("unitName=%q want empty after systemd-run failure", handle.unitName) + } + + foundSystemdRun := false + foundFallback := false + for _, call := range fakeCmd.CallsList() { + if strings.HasPrefix(call, "systemd-run ") { + foundSystemdRun = true + } + if strings.HasPrefix(call, "sh -c nohup sh -c 'sleep ") { + foundFallback = true + } + } + if !foundSystemdRun || !foundFallback { + t.Fatalf("expected both systemd-run and fallback; calls=%v", fakeCmd.CallsList()) + } +} + +func TestArmNetworkRollback_WithoutSystemdRunUsesNohup(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + logger := newDiscardLogger() + handle, err := armNetworkRollback(context.Background(), logger, "/backup.tar", 1*time.Second, "") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + if handle.unitName != "" { + t.Fatalf("unitName=%q want empty in nohup mode", handle.unitName) + } + + foundFallback := false + for _, call := range fakeCmd.CallsList() { + if strings.HasPrefix(call, "sh -c nohup sh -c 'sleep ") { + foundFallback = true + } + } + if !foundFallback { + t.Fatalf("expected fallback sh -c call; calls=%v", fakeCmd.CallsList()) + } +} + +func TestArmNetworkRollback_SubSecondTimeoutArmsAtLeastOneSecond(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 500*time.Millisecond, "") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + + foundSleep1 := false + for _, call := range fakeCmd.CallsList() { + if strings.Contains(call, "sleep 1;") { + foundSleep1 = true + } + } + if !foundSleep1 { + t.Fatalf("expected sleep 1 in nohup command; calls=%v", fakeCmd.CallsList()) + } +} + +func TestArmNetworkRollback_FallbackCommandFailureReturnsError(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "sh -c nohup sh -c 'sleep 1; /bin/sh /tmp/proxsave/network_rollback_20260201_123456.sh' >/dev/null 2>&1 &": errors.New("boom"), + }, + } + + _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 1*time.Second, "") + if err == nil || !strings.Contains(err.Error(), "failed to arm rollback timer") { + t.Fatalf("err=%v want failed to arm rollback timer", err) + } +} + +func TestDisarmNetworkRollback_RemovesMarkerAndStopsTimer(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle := &networkRollbackHandle{ + markerPath: "/tmp/marker", + unitName: "proxsave-network-rollback-test", + } + if err := fakeFS.WriteFile(handle.markerPath, []byte("pending\n"), 0o600); err != nil { + t.Fatalf("write marker: %v", err) + } + + disarmNetworkRollback(context.Background(), newDiscardLogger(), handle) + if _, err := fakeFS.Stat(handle.markerPath); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected marker to be removed, stat err=%v", err) + } + + calls := strings.Join(fakeCmd.CallsList(), "\n") + if !strings.Contains(calls, "systemctl stop "+handle.unitName+".timer") { + t.Fatalf("expected systemctl stop call; calls=%v", fakeCmd.CallsList()) + } + if !strings.Contains(calls, "systemctl reset-failed "+handle.unitName+".service "+handle.unitName+".timer") { + t.Fatalf("expected systemctl reset-failed call; calls=%v", fakeCmd.CallsList()) + } +} + +func TestDisarmNetworkRollback_NilHandleNoop(t *testing.T) { + disarmNetworkRollback(context.Background(), newDiscardLogger(), nil) +} + +func TestDisarmNetworkRollback_MarkerRemoveFailureAndStopError(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{ + "systemctl stop unit.timer": errors.New("boom"), + }, + } + restoreCmd = fakeCmd + + handle := &networkRollbackHandle{ + markerPath: "/tmp/marker", + unitName: "unit", + } + if err := fakeFS.WriteFile(handle.markerPath, []byte("pending\n"), 0o600); err != nil { + t.Fatalf("write marker: %v", err) + } + + restoreFS = removeFailFS{FS: fakeFS, failPath: handle.markerPath, err: errors.New("remove boom")} + disarmNetworkRollback(context.Background(), newDiscardLogger(), handle) + + calls := strings.Join(fakeCmd.CallsList(), "\n") + if !strings.Contains(calls, "systemctl stop unit.timer") { + t.Fatalf("expected systemctl stop call; calls=%v", fakeCmd.CallsList()) + } + if !strings.Contains(calls, "systemctl reset-failed unit.service unit.timer") { + t.Fatalf("expected systemctl reset-failed call; calls=%v", fakeCmd.CallsList()) + } +} + +func TestMaybeRepairNICNamesCLI_SkippedWhenArchiveMissing(t *testing.T) { + origTime := restoreTime + t.Cleanup(func() { restoreTime = origTime }) + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "") + if result == nil { + t.Fatalf("expected result") + } + if !strings.Contains(result.SkippedReason, "backup archive not available") { + t.Fatalf("SkippedReason=%q", result.SkippedReason) + } + if !result.AppliedAt.Equal(nowRestore()) { + t.Fatalf("AppliedAt=%s want %s", result.AppliedAt, nowRestore()) + } +} + +func TestMaybeRepairNICNamesCLI_ReturnsNilOnPlanError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/backup.zip", []byte("not a tar"), 0o600); err != nil { + t.Fatalf("write archive: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("")) + if got := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.zip"); got != nil { + t.Fatalf("expected nil on plan error, got %#v", got) + } +} + +func TestMaybeRepairNICNamesCLI_AppliesMappingWithoutConflicts(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result, got %#v", result) + } + + updated, err := fakeFS.ReadFile("/etc/network/interfaces") + if err != nil { + t.Fatalf("read updated interfaces: %v", err) + } + if string(updated) != "auto enp3s0\niface enp3s0 inet manual\n" { + t.Fatalf("updated=%q", string(updated)) + } +} + +func TestMaybeRepairNICNamesCLI_SkipsWhenNamingOverridesDetectedAndUserConfirms(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.MkdirAll("/etc/udev/rules.d", 0o755); err != nil { + t.Fatalf("mkdir udev: %v", err) + } + rule := `SUBSYSTEM=="net", ATTR{address}=="aa:bb:cc:dd:ee:ff", NAME="eth0"` + if err := fakeFS.WriteFile("/etc/udev/rules.d/70-persistent-net.rules", []byte(rule+"\n"), 0o644); err != nil { + t.Fatalf("write rule: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("y\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil { + t.Fatalf("expected result") + } + if !strings.Contains(result.SkippedReason, "persistent NIC naming rules") { + t.Fatalf("SkippedReason=%q", result.SkippedReason) + } +} + +func TestMaybeRepairNICNamesCLI_OverridesDetectedUserChoosesProceed(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.MkdirAll("/etc/udev/rules.d", 0o755); err != nil { + t.Fatalf("mkdir udev: %v", err) + } + rule := `SUBSYSTEM=="net", ATTR{address}=="aa:bb:cc:dd:ee:ff", NAME="eth0"` + if err := fakeFS.WriteFile("/etc/udev/rules.d/70-persistent-net.rules", []byte(rule+"\n"), 0o644); err != nil { + t.Fatalf("write rule: %v", err) + } + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("n\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_OverridesDetectionErrorStillApplies(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/etc/udev/rules.d", err: errors.New("boom")} + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_ConflictsPromptAndSkip(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.MkdirAll(filepath.Join(sysDir, "eno1"), 0o755); err != nil { + t.Fatalf("mkdir eno1: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write enp3s0 address: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "eno1", "address"), []byte("11:22:33:44:55:66\n"), 0o644); err != nil { + t.Fatalf("write eno1 address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + reader := bufio.NewReader(strings.NewReader("n\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil { + t.Fatalf("expected result") + } + if result.Applied() { + t.Fatalf("expected no changes when conflicts skipped, got %#v", result) + } + if !strings.Contains(result.SkippedReason, "conflicting NIC mappings") { + t.Fatalf("SkippedReason=%q", result.SkippedReason) + } +} + +func TestMaybeRepairNICNamesCLI_ConflictsPromptAndApply(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.MkdirAll(filepath.Join(sysDir, "eno1"), 0o755); err != nil { + t.Fatalf("mkdir eno1: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write enp3s0 address: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "eno1", "address"), []byte("11:22:33:44:55:66\n"), 0o644); err != nil { + t.Fatalf("write eno1 address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("y\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_OverridesPromptErrorContinues(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.MkdirAll("/etc/udev/rules.d", 0o755); err != nil { + t.Fatalf("mkdir udev: %v", err) + } + rule := `SUBSYSTEM=="net", ATTR{address}=="aa:bb:cc:dd:ee:ff", NAME="eth0"` + if err := fakeFS.WriteFile("/etc/udev/rules.d/70-persistent-net.rules", []byte(rule+"\n"), 0o644); err != nil { + t.Fatalf("write rule: %v", err) + } + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result despite prompt error, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_ConflictPromptErrorLeavesConflictsExcluded(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.MkdirAll(filepath.Join(sysDir, "eno1"), 0o755); err != nil { + t.Fatalf("mkdir eno1: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write enp3s0 address: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "eno1", "address"), []byte("11:22:33:44:55:66\n"), 0o644); err != nil { + t.Fatalf("write eno1 address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil { + t.Fatalf("expected result") + } + if result.Applied() { + t.Fatalf("expected conflicts excluded due to prompt error, got %#v", result) + } + if !strings.Contains(result.SkippedReason, "conflicting NIC mappings") { + t.Fatalf("SkippedReason=%q", result.SkippedReason) + } +} + +func TestMaybeRepairNICNamesCLI_MoreThan32ConflictsTriggersTruncationBranch(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + + var ifaceJSON []string + for i := 0; i < 33; i++ { + oldName := fmt.Sprintf("eno%d", i) + newName := fmt.Sprintf("enp%d", i) + mac := fmt.Sprintf("aa:bb:cc:dd:ee:%02x", i) + ifaceJSON = append(ifaceJSON, fmt.Sprintf(`{"name":"%s","mac":"%s"}`, oldName, mac)) + + if err := os.MkdirAll(filepath.Join(sysDir, newName), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", newName, err) + } + if err := os.WriteFile(filepath.Join(sysDir, newName, "address"), []byte(mac+"\n"), 0o644); err != nil { + t.Fatalf("write %s address: %v", newName, err) + } + + if err := os.MkdirAll(filepath.Join(sysDir, oldName), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", oldName, err) + } + conflictMAC := fmt.Sprintf("11:22:33:44:55:%02x", i) + if err := os.WriteFile(filepath.Join(sysDir, oldName, "address"), []byte(conflictMAC+"\n"), 0o644); err != nil { + t.Fatalf("write %s address: %v", oldName, err) + } + } + + invJSON := []byte(`{"interfaces":[` + strings.Join(ifaceJSON, ",") + `]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + reader := bufio.NewReader(strings.NewReader("n\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil { + t.Fatalf("expected result") + } + if result.Applied() { + t.Fatalf("expected no changes when conflicts skipped, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_ReturnsNilOnApplyError(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + restoreFS = mkdirAllFailFS{FS: fakeFS, failPath: "/tmp/proxsave", err: errors.New("boom")} + + reader := bufio.NewReader(strings.NewReader("")) + if got := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar"); got != nil { + t.Fatalf("expected nil on apply error, got %#v", got) + } +} + +func TestApplyNetworkConfig_SelectsAvailableCommand(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + logger := newDiscardLogger() + + t.Run("ifreload", func(t *testing.T) { + pathDir := t.TempDir() + writeExecutable(t, pathDir, "ifreload") + writeExecutable(t, pathDir, "systemctl") + writeExecutable(t, pathDir, "ifup") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + if err := applyNetworkConfig(context.Background(), logger); err != nil { + t.Fatalf("applyNetworkConfig error: %v", err) + } + if calls := fakeCmd.CallsList(); len(calls) != 1 || calls[0] != "ifreload -a" { + t.Fatalf("calls=%v want [ifreload -a]", calls) + } + }) + + t.Run("systemctl", func(t *testing.T) { + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + if err := applyNetworkConfig(context.Background(), logger); err != nil { + t.Fatalf("applyNetworkConfig error: %v", err) + } + if calls := fakeCmd.CallsList(); len(calls) != 1 || calls[0] != "systemctl restart networking" { + t.Fatalf("calls=%v want [systemctl restart networking]", calls) + } + }) + + t.Run("ifup", func(t *testing.T) { + pathDir := t.TempDir() + writeExecutable(t, pathDir, "ifup") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + if err := applyNetworkConfig(context.Background(), logger); err != nil { + t.Fatalf("applyNetworkConfig error: %v", err) + } + if calls := fakeCmd.CallsList(); len(calls) != 1 || calls[0] != "ifup -a" { + t.Fatalf("calls=%v want [ifup -a]", calls) + } + }) + + t.Run("none", func(t *testing.T) { + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + restoreCmd = &FakeCommandRunner{} + if err := applyNetworkConfig(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "no supported network reload command") { + t.Fatalf("err=%v want supported network reload command error", err) + } + }) +} + +func TestParseSSHClientIP(t *testing.T) { + t.Run("SSH_CONNECTION", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "203.0.113.9 1234 10.0.0.1 22") + t.Setenv("SSH_CLIENT", "") + if got := parseSSHClientIP(); got != "203.0.113.9" { + t.Fatalf("got %q want %q", got, "203.0.113.9") + } + }) + + t.Run("SSH_CLIENT", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "") + t.Setenv("SSH_CLIENT", "203.0.113.8 2222 22") + if got := parseSSHClientIP(); got != "203.0.113.8" { + t.Fatalf("got %q want %q", got, "203.0.113.8") + } + }) + + t.Run("none", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "") + t.Setenv("SSH_CLIENT", "") + if got := parseSSHClientIP(); got != "" { + t.Fatalf("got %q want empty", got) + } + }) +} + +func TestDetectManagementInterface(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + logger := newDiscardLogger() + + t.Run("ssh route", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "203.0.113.9 1234 10.0.0.1 22") + t.Setenv("SSH_CLIENT", "") + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "ip route get 203.0.113.9": []byte("203.0.113.9 dev vmbr0 src 192.0.2.2\n"), + }, + } + iface, src := detectManagementInterface(context.Background(), logger) + if iface != "vmbr0" || src != "ssh" { + t.Fatalf("iface=%q src=%q want vmbr0 ssh", iface, src) + } + }) + + t.Run("default route fallback", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "203.0.113.9 1234 10.0.0.1 22") + t.Setenv("SSH_CLIENT", "") + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "ip route get 203.0.113.9": errors.New("boom"), + }, + Outputs: map[string][]byte{ + "ip route show default": []byte("default via 192.0.2.1 dev nic1\n"), + }, + } + iface, src := detectManagementInterface(context.Background(), logger) + if iface != "nic1" || src != "default-route" { + t.Fatalf("iface=%q src=%q want nic1 default-route", iface, src) + } + }) + + t.Run("none", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "") + t.Setenv("SSH_CLIENT", "") + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "ip route show default": errors.New("boom"), + }, + } + iface, src := detectManagementInterface(context.Background(), logger) + if iface != "" || src != "" { + t.Fatalf("iface=%q src=%q want empty", iface, src) + } + }) +} + +func TestRouteInterfaceForIPAndDefaultRouteInterface_ErrorCases(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "ip route get 203.0.113.9": errors.New("boom"), + "ip route show default": errors.New("boom"), + "ip route show default -x": errors.New("boom"), + "ip route show default --y": errors.New("boom"), + }, + } + if got := routeInterfaceForIP(context.Background(), "203.0.113.9"); got != "" { + t.Fatalf("routeInterfaceForIP=%q want empty on error", got) + } + if got := defaultRouteInterface(context.Background()); got != "" { + t.Fatalf("defaultRouteInterface=%q want empty on error", got) + } +} + +func TestDefaultRouteInterface_EmptyOutputReturnsEmpty(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "ip route show default": []byte(""), + }, + } + if got := defaultRouteInterface(context.Background()); got != "" { + t.Fatalf("defaultRouteInterface=%q want empty", got) + } +} + +func TestDefaultNetworkPortChecks(t *testing.T) { + if got := defaultNetworkPortChecks(SystemTypePVE); len(got) != 1 || got[0].Port != 8006 { + t.Fatalf("PVE checks=%v", got) + } + if got := defaultNetworkPortChecks(SystemTypePBS); len(got) != 1 || got[0].Port != 8007 { + t.Fatalf("PBS checks=%v", got) + } + if got := defaultNetworkPortChecks(SystemTypeUnknown); got != nil { + t.Fatalf("unknown checks=%v want nil", got) + } +} + +func TestPromptNetworkCommitWithCountdown_InputAborted(t *testing.T) { + reader := bufio.NewReader(strings.NewReader("")) + logger := newDiscardLogger() + + committed, err := promptNetworkCommitWithCountdown(context.Background(), reader, logger, 2*time.Second) + if committed { + t.Fatalf("expected committed=false") + } + if err == nil { + t.Fatalf("expected error") + } +} + +func TestRollbackNetworkFilesNow_ErrorCasesAndScriptFailure(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + logger := newDiscardLogger() + + if _, err := rollbackNetworkFilesNow(context.Background(), logger, "", ""); err == nil { + t.Fatalf("expected error for empty backup path") + } + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + t.Run("mkdir error", func(t *testing.T) { + restoreFS = &FakeFS{Root: fakeFS.Root, MkdirAllErr: errors.New("boom"), StatErr: make(map[string]error), StatErrors: make(map[string]error), OpenFileErr: make(map[string]error)} + restoreCmd = &FakeCommandRunner{} + if _, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work"); err == nil { + t.Fatalf("expected error for mkdir failure") + } + }) + + t.Run("marker write error", func(t *testing.T) { + failFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(failFS.Root) }) + failFS.WriteErr = errors.New("boom") + restoreFS = failFS + restoreCmd = &FakeCommandRunner{} + + if _, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work"); err == nil || !strings.Contains(err.Error(), "write rollback marker") { + t.Fatalf("err=%v want write rollback marker error", err) + } + }) + + t.Run("script write error removes marker", func(t *testing.T) { + restoreFS = fakeFS + restoreCmd = &FakeCommandRunner{} + + scriptPath := "/work/network_rollback_now_20260201_123456.sh" + restoreFS = writeFailFS{FS: fakeFS, failPath: scriptPath, err: errors.New("boom")} + + _, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work") + if err == nil || !strings.Contains(err.Error(), "write rollback script") { + t.Fatalf("err=%v want write rollback script error", err) + } + + markerPath := "/work/network_rollback_now_pending_20260201_123456" + if _, statErr := fakeFS.Stat(markerPath); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("expected marker to be removed on script write failure, stat err=%v", statErr) + } + }) + + t.Run("marker remove error is non-fatal", func(t *testing.T) { + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "sh /work/network_rollback_now_20260201_123456.sh": []byte("ok\n"), + }, + } + + markerPath := "/work/network_rollback_now_pending_20260201_123456" + restoreFS = removeFailFS{FS: fakeFS, failPath: markerPath, err: errors.New("remove boom")} + + logPath, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if logPath != "/work/network_rollback_now_20260201_123456.log" { + t.Fatalf("logPath=%q", logPath) + } + }) + + t.Run("script run error returns log path", func(t *testing.T) { + restoreFS = fakeFS + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "sh /work/network_rollback_now_20260201_123456.sh": errors.New("boom"), + }, + } + + logPath, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work") + if err == nil || !strings.Contains(err.Error(), "rollback script failed") { + t.Fatalf("err=%v want rollback script failed", err) + } + if logPath != "/work/network_rollback_now_20260201_123456.log" { + t.Fatalf("logPath=%q", logPath) + } + }) +} + +func TestRollbackNetworkFilesNow_DefaultWorkDirUsesTmpProxsave(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "sh /tmp/proxsave/network_rollback_now_20260201_123456.sh": []byte("ok\n"), + }, + } + + logPath, err := rollbackNetworkFilesNow(context.Background(), newDiscardLogger(), "/backup.tar", "") + if err != nil { + t.Fatalf("rollbackNetworkFilesNow error: %v", err) + } + if logPath != "/tmp/proxsave/network_rollback_now_20260201_123456.log" { + t.Fatalf("logPath=%q", logPath) + } +} + +func TestBuildRollbackScript_IncludesRestartBlockWhenEnabled(t *testing.T) { + script := buildRollbackScript("/marker", "/backup with spaces.tar", "/tmp/log file.log", true) + if !strings.Contains(script, "Restart networking after rollback") { + t.Fatalf("expected restartNetworking block") + } + if !strings.Contains(script, "LOG='/tmp/log file.log'") { + t.Fatalf("expected quoted LOG, got:\n%s", script) + } + if !strings.Contains(script, "BACKUP='/backup with spaces.tar'") { + t.Fatalf("expected quoted BACKUP, got:\n%s", script) + } +} + +func TestShellQuote(t *testing.T) { + if got := shellQuote(""); got != "''" { + t.Fatalf("shellQuote empty=%q", got) + } + if got := shellQuote("simple"); got != "simple" { + t.Fatalf("shellQuote simple=%q", got) + } + want := `'a b '\''c'\'''` + if got := shellQuote("a b 'c'"); got != want { + t.Fatalf("shellQuote=%q want %q", got, want) + } +} + +func TestRunCommandLogged_SuccessAndFailure(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + logger := newDiscardLogger() + + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "echo hi": []byte("hi\n"), + }, + } + if err := runCommandLogged(context.Background(), logger, "echo", "hi"); err != nil { + t.Fatalf("runCommandLogged error: %v", err) + } + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "false x": errors.New("boom"), + }, + } + err := runCommandLogged(context.Background(), logger, "false", "x") + if err == nil || !strings.Contains(err.Error(), "false") || !strings.Contains(err.Error(), "failed") { + t.Fatalf("err=%v want wrapped failure", err) + } +} From 6e378623afc265fb8fa3e32aa4921502aec78e25 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 12:05:26 +0100 Subject: [PATCH 02/18] ci: bump the actions-updates group across 1 directory with 2 updates (#164) Bumps the actions-updates group with 2 updates in the / directory: [goreleaser/goreleaser-action](https://github.com/goreleaser/goreleaser-action) and [actions/attest-build-provenance](https://github.com/actions/attest-build-provenance). Updates `goreleaser/goreleaser-action` from 6 to 7 - [Release notes](https://github.com/goreleaser/goreleaser-action/releases) - [Commits](https://github.com/goreleaser/goreleaser-action/compare/v6...v7) Updates `actions/attest-build-provenance` from 3 to 4 - [Release notes](https://github.com/actions/attest-build-provenance/releases) - [Changelog](https://github.com/actions/attest-build-provenance/blob/main/RELEASE.md) - [Commits](https://github.com/actions/attest-build-provenance/compare/v3...v4) --- updated-dependencies: - dependency-name: goreleaser/goreleaser-action dependency-version: '7' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions-updates - dependency-name: actions/attest-build-provenance dependency-version: '4' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions-updates ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 73f72f2..d4c1577 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -70,7 +70,7 @@ jobs: # GORELEASER ######################################## - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v6 + uses: goreleaser/goreleaser-action@v7 with: version: latest workdir: ${{ github.workspace }} @@ -82,6 +82,6 @@ jobs: # ATTESTAZIONE PROVENIENZA BUILD ######################################## - name: Attest Build Provenance - uses: actions/attest-build-provenance@v3 + uses: actions/attest-build-provenance@v4 with: subject-path: build/proxsave_* From e2ec610138db3b1f9b85f3b8150dd126763fdc75 Mon Sep 17 00:00:00 2001 From: tis24dev Date: Thu, 5 Mar 2026 20:38:39 +0100 Subject: [PATCH 03/18] backup/pve: clarify datastore skip logs Make PVE datastore skip messages user-friendly: log disabled storages as SKIP, offline storages as actionable WARNING, and always emit DEBUG skip details with active/enabled/status and an explicit reason. Fix detected datastores report header to match output column order. Extend unit tests to cover runtime flag parsing and the new skip logging, and make Makefile coverage targets honor the go.mod toolchain via GOTOOLCHAIN. --- Makefile | 14 +-- internal/backup/collector_pve.go | 77 +++++++++++++++-- internal/backup/collector_pve_parse_test.go | 36 ++++++++ internal/backup/collector_pve_test.go | 95 ++++++++++++++++++++- 4 files changed, 207 insertions(+), 15 deletions(-) diff --git a/Makefile b/Makefile index e3041d4..6d17a9d 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,8 @@ .PHONY: build test clean run build-release test-coverage lint fmt deps help coverage coverage-check COVERAGE_THRESHOLD ?= 50.0 +TOOLCHAIN_FROM_MOD := $(shell awk '/^toolchain /{print $$2}' go.mod 2>/dev/null) +COVER_GOTOOLCHAIN := $(if $(TOOLCHAIN_FROM_MOD),$(TOOLCHAIN_FROM_MOD)+auto,auto) # Build del progetto build: @@ -60,20 +62,20 @@ test: # Test con coverage test-coverage: - go test -coverprofile=coverage.out ./... - go tool cover -html=coverage.out + GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go test -coverprofile=coverage.out ./... + GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go tool cover -html=coverage.out # Full coverage report (all packages) coverage: @echo "Running coverage across all packages..." - @go test -coverpkg=./... -coverprofile=coverage.out ./... - @go tool cover -func=coverage.out | tail -n 1 + @GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go test -coverpkg=./... -coverprofile=coverage.out ./... + @GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go tool cover -func=coverage.out | tail -n 1 # Enforce minimum coverage threshold coverage-check: @echo "Running coverage check (threshold $(COVERAGE_THRESHOLD)% )..." - @go test -coverpkg=./... -coverprofile=coverage.out ./... - @total=$$(go tool cover -func=coverage.out | grep total: | awk '{print $$3}' | sed 's/%//'); \ + @GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go test -coverpkg=./... -coverprofile=coverage.out ./... + @total=$$(GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go tool cover -func=coverage.out | grep total: | awk '{print $$3}' | sed 's/%//'); \ echo "Total coverage: $$total%"; \ if awk -v total="$$total" -v threshold="$(COVERAGE_THRESHOLD)" 'BEGIN { exit !(total+0 >= threshold+0) }'; then \ echo "Coverage check passed."; \ diff --git a/internal/backup/collector_pve.go b/internal/backup/collector_pve.go index ed0a8a6..b8f9c7a 100644 --- a/internal/backup/collector_pve.go +++ b/internal/backup/collector_pve.go @@ -1029,7 +1029,7 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv var summary strings.Builder summary.WriteString("# PVE datastores detected on ") summary.WriteString(time.Now().Format(time.RFC3339)) - summary.WriteString("\n# Format: NAME|PATH|TYPE|CONTENT\n\n") + summary.WriteString("\n# Format: TYPE|NAME|PATH|CONTENT\n\n") ioTimeout := time.Duration(0) if c.config != nil && c.config.FsIoTimeoutSeconds > 0 { @@ -1073,6 +1073,45 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv return "" } + describeOptionalBool := func(v *bool) string { + if v == nil { + return "unknown" + } + if *v { + return "true" + } + return "false" + } + + logSkipDetails := func(storage pveStorageEntry, reason string, err error) { + if err != nil { + c.logger.Debug( + "PVE datastore skip details: name=%s path=%s type=%s content=%s active=%s enabled=%s status=%s reason=%s err=%v", + storage.Name, + storage.Path, + storage.Type, + storage.Content, + describeOptionalBool(storage.Active), + describeOptionalBool(storage.Enabled), + strings.TrimSpace(storage.Status), + reason, + err, + ) + return + } + c.logger.Debug( + "PVE datastore skip details: name=%s path=%s type=%s content=%s active=%s enabled=%s status=%s reason=%s", + storage.Name, + storage.Path, + storage.Type, + storage.Content, + describeOptionalBool(storage.Active), + describeOptionalBool(storage.Enabled), + strings.TrimSpace(storage.Status), + reason, + ) + } + processed := 0 for _, storage := range storages { if storage.Path == "" { @@ -1083,7 +1122,32 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv } if reason := unavailableReason(storage); reason != "" { - c.logger.Warning("Skipping datastore %s (path=%s)%s: not available (%s)", storage.Name, storage.Path, formatRuntime(storage), reason) + status := strings.TrimPrefix(reason, "status=") + switch { + case reason == "enabled=false" || status == "disabled": + c.logger.Skip( + "PVE datastore %s ignored: disabled in Proxmox. Not scanning %s for vzdump backup files (PVE config backup continues).", + storage.Name, storage.Path, + ) + case reason == "active=false" || status == "inactive" || status == "unavailable": + c.logger.Warning( + "PVE datastore %s skipped: storage is offline. Not scanning %s for vzdump backup files. Mount/activate it, or disable it in Proxmox if unused.", + storage.Name, storage.Path, + ) + default: + if status != reason { + c.logger.Warning( + "PVE datastore %s skipped: Proxmox reports storage status %q. Not scanning %s for vzdump backup files. Fix storage status, or disable it in Proxmox if unused.", + storage.Name, status, storage.Path, + ) + } else { + c.logger.Warning( + "PVE datastore %s skipped: Proxmox reports the storage is not available. Not scanning %s for vzdump backup files. Fix storage status, or disable it in Proxmox if unused.", + storage.Name, storage.Path, + ) + } + } + logSkipDetails(storage, reason, nil) continue } @@ -1091,14 +1155,17 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv stat, err := safefs.Stat(ctx, storage.Path, ioTimeout) if err != nil { if errors.Is(err, safefs.ErrTimeout) { - c.logger.Warning("Skipping datastore %s (path=%s)%s: filesystem probe timed out (%v)", storage.Name, storage.Path, formatRuntime(storage), err) + c.logger.Warning("PVE datastore %s skipped: filesystem probe timed out for %s (%v). Not scanning for vzdump backup files.", storage.Name, storage.Path, err) + logSkipDetails(storage, "filesystem_probe_timeout", err) } else { - c.logger.Debug("Skipping datastore %s (path not accessible: %s): %v", storage.Name, storage.Path, err) + c.logger.Warning("PVE datastore %s skipped: path %s not accessible (%v). Not scanning for vzdump backup files.", storage.Name, storage.Path, err) + logSkipDetails(storage, "path_not_accessible", err) } continue } if !stat.IsDir() { - c.logger.Debug("Skipping datastore %s (path not a directory: %s)", storage.Name, storage.Path) + c.logger.Warning("PVE datastore %s skipped: path %s is not a directory. Not scanning for vzdump backup files.", storage.Name, storage.Path) + logSkipDetails(storage, "path_not_directory", nil) continue } diff --git a/internal/backup/collector_pve_parse_test.go b/internal/backup/collector_pve_parse_test.go index aaea755..5f957e3 100644 --- a/internal/backup/collector_pve_parse_test.go +++ b/internal/backup/collector_pve_parse_test.go @@ -8,6 +8,11 @@ import ( // TestParseNodeStorageList tests parsing PVE storage entries from JSON func TestParseNodeStorageList(t *testing.T) { + boolPtr := func(v bool) *bool { + b := v + return &b + } + tests := []struct { name string input string @@ -26,6 +31,24 @@ func TestParseNodeStorageList(t *testing.T) { {Name: "local-lvm", Path: "", Type: "lvmthin", Content: "images,rootdir"}, }, }, + { + name: "storage with runtime flags", + input: `[ + {"storage": "HDD1", "path": "/mnt/pve/HDD1", "type": "dir", "content": "backup", "active": 0, "enabled": 0, "status": "disabled"}, + {"storage": "HDD2", "path": "/mnt/pve/HDD2", "type": "nfs", "content": "backup", "active": 0, "enabled": 1, "status": "unavailable"}, + {"storage": "HDD3", "path": "/mnt/pve/HDD3", "type": "dir", "content": "backup", "active": true, "enabled": true, "status": "ok"}, + {"storage": "HDD4", "path": "/mnt/pve/HDD4", "type": "dir", "content": "backup", "active": "0", "enabled": "1", "status": "unknown"}, + {"storage": "HDD5", "path": "/mnt/pve/HDD5", "type": "dir", "content": "backup", "active": null, "enabled": null, "status": ""} + ]`, + expectError: false, + expected: []pveStorageEntry{ + {Name: "HDD1", Path: "/mnt/pve/HDD1", Type: "dir", Content: "backup", Active: boolPtr(false), Enabled: boolPtr(false), Status: "disabled"}, + {Name: "HDD2", Path: "/mnt/pve/HDD2", Type: "nfs", Content: "backup", Active: boolPtr(false), Enabled: boolPtr(true), Status: "unavailable"}, + {Name: "HDD3", Path: "/mnt/pve/HDD3", Type: "dir", Content: "backup", Active: boolPtr(true), Enabled: boolPtr(true), Status: "ok"}, + {Name: "HDD4", Path: "/mnt/pve/HDD4", Type: "dir", Content: "backup", Active: boolPtr(false), Enabled: boolPtr(true), Status: "unknown"}, + {Name: "HDD5", Path: "/mnt/pve/HDD5", Type: "dir", Content: "backup", Active: nil, Enabled: nil, Status: ""}, + }, + }, { name: "storage with name field instead of storage", input: `[ @@ -145,6 +168,19 @@ func TestParseNodeStorageList(t *testing.T) { if entry.Content != tt.expected[i].Content { t.Errorf("entry[%d].Content = %q, want %q", i, entry.Content, tt.expected[i].Content) } + if (entry.Active == nil) != (tt.expected[i].Active == nil) { + t.Errorf("entry[%d].Active nilness = %v, want %v", i, entry.Active == nil, tt.expected[i].Active == nil) + } else if entry.Active != nil && tt.expected[i].Active != nil && *entry.Active != *tt.expected[i].Active { + t.Errorf("entry[%d].Active = %v, want %v", i, *entry.Active, *tt.expected[i].Active) + } + if (entry.Enabled == nil) != (tt.expected[i].Enabled == nil) { + t.Errorf("entry[%d].Enabled nilness = %v, want %v", i, entry.Enabled == nil, tt.expected[i].Enabled == nil) + } else if entry.Enabled != nil && tt.expected[i].Enabled != nil && *entry.Enabled != *tt.expected[i].Enabled { + t.Errorf("entry[%d].Enabled = %v, want %v", i, *entry.Enabled, *tt.expected[i].Enabled) + } + if entry.Status != tt.expected[i].Status { + t.Errorf("entry[%d].Status = %q, want %q", i, entry.Status, tt.expected[i].Status) + } } }) } diff --git a/internal/backup/collector_pve_test.go b/internal/backup/collector_pve_test.go index 6d4a808..2171b92 100644 --- a/internal/backup/collector_pve_test.go +++ b/internal/backup/collector_pve_test.go @@ -1,11 +1,13 @@ package backup import ( + "bytes" "context" "errors" "fmt" "os" "path/filepath" + "strings" "testing" "github.com/tis24dev/proxsave/internal/logging" @@ -814,15 +816,100 @@ func TestCollectPVEStorageMetadata(t *testing.T) { collector := newPVECollector(t) ctx := context.Background() + storageDir := t.TempDir() storages := []pveStorageEntry{ - {Name: "local", Path: "/var/lib/vz", Type: "dir"}, + {Name: "local", Path: storageDir, Type: "dir"}, } - err := collector.collectPVEStorageMetadata(ctx, storages) - if err != nil { - t.Logf("collectPVEStorageMetadata returned error: %v", err) + if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + metaPath := filepath.Join(collector.tempDir, "var/lib/pve-cluster/info/datastores", "local", "metadata.json") + if _, err := os.Stat(metaPath); err != nil { + t.Fatalf("expected %s created, got %v", metaPath, err) } } +func TestCollectPVEStorageMetadata_SkipReasonsAreUserFriendly(t *testing.T) { + boolPtr := func(v bool) *bool { + b := v + return &b + } + ctx := context.Background() + + t.Run("disabled storage uses SKIP with debug details", func(t *testing.T) { + collector := newPVECollector(t) + var out bytes.Buffer + collector.logger.SetOutput(&out) + + storages := []pveStorageEntry{ + {Name: "HDD1", Path: "/mnt/pve/HDD1", Active: boolPtr(false), Enabled: boolPtr(false)}, + } + if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + logText := out.String() + if collector.logger.WarningCount() != 0 { + t.Fatalf("expected 0 warnings, got %d", collector.logger.WarningCount()) + } + if !strings.Contains(logText, "SKIP") || !strings.Contains(logText, "disabled in Proxmox") { + t.Fatalf("expected user-friendly SKIP message for disabled storage, got logs:\n%s", logText) + } + if !strings.Contains(logText, "PVE datastore skip details:") { + t.Fatalf("expected debug details for skipped storage, got logs:\n%s", logText) + } + }) + + t.Run("offline storage uses WARNING with debug details", func(t *testing.T) { + collector := newPVECollector(t) + var out bytes.Buffer + collector.logger.SetOutput(&out) + + storages := []pveStorageEntry{ + {Name: "HDD2", Path: "/mnt/pve/HDD2", Active: boolPtr(false), Enabled: boolPtr(true)}, + } + if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + logText := out.String() + if collector.logger.WarningCount() != 1 { + t.Fatalf("expected 1 warning, got %d", collector.logger.WarningCount()) + } + if !strings.Contains(logText, "WARNING") || !strings.Contains(logText, "storage is offline") { + t.Fatalf("expected warning message for offline storage, got logs:\n%s", logText) + } + if !strings.Contains(logText, "PVE datastore skip details:") { + t.Fatalf("expected debug details for skipped storage, got logs:\n%s", logText) + } + }) + + t.Run("non-existent path uses WARNING with debug details", func(t *testing.T) { + collector := newPVECollector(t) + var out bytes.Buffer + collector.logger.SetOutput(&out) + + storages := []pveStorageEntry{ + {Name: "MISSING", Path: filepath.Join(t.TempDir(), "does-not-exist"), Type: "dir"}, + } + if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + logText := out.String() + if collector.logger.WarningCount() != 1 { + t.Fatalf("expected 1 warning, got %d", collector.logger.WarningCount()) + } + if !strings.Contains(logText, "not accessible") { + t.Fatalf("expected warning for missing path, got logs:\n%s", logText) + } + if !strings.Contains(logText, "reason=path_not_accessible") { + t.Fatalf("expected debug details including reason, got logs:\n%s", logText) + } + }) +} + // Test collectPVECephInfo function func TestCollectPVECephInfo(t *testing.T) { t.Run("no ceph configured", func(t *testing.T) { From fc0aed400d11384f38c014b50621520ea5669a3b Mon Sep 17 00:00:00 2001 From: tis24dev Date: Thu, 5 Mar 2026 20:40:35 +0100 Subject: [PATCH 04/18] feat(webhook): optional Discord content fallback for embed-only messages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds per-endpoint opt-in Discord content text (WEBHOOK__DISCORD_CONTENT_ENABLED/..._DISCORD_CONTENT) to avoid “empty” messages when embeds aren’t rendered. Default remains unchanged (embed-only). Includes 2000-char truncation, docs/template updates, and tests. --- docs/CONFIGURATION.md | 4 + docs/EXAMPLES.md | 3 + internal/config/config.go | 19 ++++ internal/config/config_test.go | 15 +++ internal/config/templates/backup.env | 5 + internal/notify/webhook.go | 22 ++++ internal/notify/webhook_payloads.go | 50 +++++++++ internal/notify/webhook_test.go | 148 +++++++++++++++++++++++++++ 8 files changed, 266 insertions(+) diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index a3ef784..1e21da1 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -912,6 +912,10 @@ WEBHOOK_DISCORD_ALERTS_AUTH_TOKEN= # Bearer token WEBHOOK_DISCORD_ALERTS_AUTH_USER= # Basic auth username WEBHOOK_DISCORD_ALERTS_AUTH_PASS= # Basic auth password WEBHOOK_DISCORD_ALERTS_AUTH_SECRET= # HMAC secret key + +# Discord-only: optional text fallback (helps if embeds are hidden on some clients) +WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT_ENABLED=false # true | false (default false) +WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT= # Optional; empty = auto-summary (max 2000 chars) ``` **Supported formats**: diff --git a/docs/EXAMPLES.md b/docs/EXAMPLES.md index 3f8a6eb..886e97b 100644 --- a/docs/EXAMPLES.md +++ b/docs/EXAMPLES.md @@ -535,6 +535,9 @@ WEBHOOK_DISCORD_ALERTS_URL=https://discord.com/api/webhooks/XXXX/YYYY WEBHOOK_DISCORD_ALERTS_FORMAT=discord WEBHOOK_DISCORD_ALERTS_METHOD=POST +# Optional (Discord): uncomment if embeds look empty on some clients +# WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT_ENABLED=true + # Run backup ./build/proxsave # Result: Notifications sent to Telegram, Email, and Discord diff --git a/internal/config/config.go b/internal/config/config.go index bb35388..16e9da2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1262,6 +1262,11 @@ func (c *Config) BuildWebhookConfig() *WebhookConfig { } } + discordOpts := WebhookDiscordOptions{ + ContentEnabled: c.getBool(prefix+"DISCORD_CONTENT_ENABLED", false), + Content: c.getString(prefix+"DISCORD_CONTENT", ""), + } + endpoints = append(endpoints, WebhookEndpoint{ Name: name, URL: url, @@ -1269,6 +1274,7 @@ func (c *Config) BuildWebhookConfig() *WebhookConfig { Method: method, Headers: headers, Auth: auth, + Discord: discordOpts, }) } @@ -1363,9 +1369,22 @@ type WebhookEndpoint struct { Method string Headers map[string]string Auth WebhookAuth + Discord WebhookDiscordOptions CustomFields map[string]interface{} } +// WebhookDiscordOptions holds Discord-specific webhook options. +// These options are ignored for non-Discord formats. +type WebhookDiscordOptions struct { + // ContentEnabled adds a short "content" text alongside embeds. + // This helps clients that don't render embeds reliably. + ContentEnabled bool + + // Content overrides the fallback text when ContentEnabled is true. + // If empty, ProxSave generates an automatic summary. + Content string +} + // WebhookAuth holds authentication configuration for a webhook type WebhookAuth struct { Type string diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 3fb7ab0..f1363e9 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -790,6 +790,8 @@ func TestBuildWebhookConfigParsesConfiguredEndpoints(t *testing.T) { "WEBHOOK_ALERT_URL": "https://example.com/alert", "WEBHOOK_ALERT_FORMAT": "lines", "WEBHOOK_ALERT_METHOD": "PUT", + "WEBHOOK_ALERT_DISCORD_CONTENT_ENABLED": "true", + "WEBHOOK_ALERT_DISCORD_CONTENT": "Hello from ProxSave", "WEBHOOK_ALERT_AUTH_TYPE": "bearer", "WEBHOOK_ALERT_AUTH_TOKEN": "tok-123", "WEBHOOK_ALERT_AUTH_USER": "admin", @@ -837,6 +839,13 @@ func TestBuildWebhookConfigParsesConfiguredEndpoints(t *testing.T) { if alert.Auth.Type != "bearer" || alert.Auth.Token != "tok-123" || alert.Auth.User != "admin" || alert.Auth.Pass != "pass-123" || alert.Auth.Secret != "secret-xyz" { t.Fatalf("alert auth = %+v; want bearer token-user-pass-secret", alert.Auth) } + if !alert.Discord.ContentEnabled { + t.Fatalf("expected alert Discord.ContentEnabled to be true") + } + if alert.Discord.Content != "Hello from ProxSave" { + t.Fatalf("alert Discord.Content = %q; want %q", alert.Discord.Content, "Hello from ProxSave") + } + if alert.Headers["X-Trace"] != "12345" || alert.Headers["Authorization"] != "Bearer token123" { t.Fatalf("alert headers = %+v; want parsed header values", alert.Headers) } @@ -854,6 +863,12 @@ func TestBuildWebhookConfigParsesConfiguredEndpoints(t *testing.T) { if backup.Auth.Type != "none" { t.Fatalf("backup auth type = %q; want none", backup.Auth.Type) } + if backup.Discord.ContentEnabled { + t.Fatalf("expected backup Discord.ContentEnabled to be false by default") + } + if backup.Discord.Content != "" { + t.Fatalf("expected backup Discord.Content to be empty by default, got %q", backup.Discord.Content) + } if len(backup.Headers) != 0 { t.Fatalf("expected backup headers to be empty, got %+v", backup.Headers) } diff --git a/internal/config/templates/backup.env b/internal/config/templates/backup.env index 2dbf810..835e3af 100644 --- a/internal/config/templates/backup.env +++ b/internal/config/templates/backup.env @@ -255,6 +255,11 @@ WEBHOOK_RETRY_DELAY=2 # seconds # WEBHOOK_DISCORD_ALERTS_AUTH_USER= # WEBHOOK_DISCORD_ALERTS_AUTH_PASS= # WEBHOOK_DISCORD_ALERTS_AUTH_SECRET= +# +# Optional Discord-only: send a short `content` text fallback alongside the embed. +# Useful when embeds are hidden on some clients (e.g., Discord mobile settings). +# WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT_ENABLED=false +# WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT= # Optional; empty = auto-summary (max 2000 chars) # ---------------------------------------------------------------------- # Metriche / Prometheus diff --git a/internal/notify/webhook.go b/internal/notify/webhook.go index d317a0f..e5e5d30 100644 --- a/internal/notify/webhook.go +++ b/internal/notify/webhook.go @@ -184,6 +184,28 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We return fmt.Errorf("failed to build payload: %w", err) } + if strings.EqualFold(format, "discord") && endpoint.Discord.ContentEnabled { + payloadMap, ok := payload.(map[string]interface{}) + if !ok { + w.logger.Warning("Discord content enabled for '%s' but payload type is %T; skipping content", endpoint.Name, payload) + } else { + content := endpoint.Discord.Content + if strings.TrimSpace(content) == "" { + content = buildDiscordContentSummary(data) + } + + content, truncated := truncateDiscordContent(content) + if truncated { + w.logger.Warning("Discord content for '%s' exceeded 2000 characters; truncated", endpoint.Name) + } + + if strings.TrimSpace(content) != "" { + payloadMap["content"] = content + payload = payloadMap + } + } + } + payloadDuration := time.Since(payloadStart) w.logger.Debug("Payload built successfully in %dms", payloadDuration.Milliseconds()) diff --git a/internal/notify/webhook_payloads.go b/internal/notify/webhook_payloads.go index 1c4530f..b0c4ce1 100644 --- a/internal/notify/webhook_payloads.go +++ b/internal/notify/webhook_payloads.go @@ -3,6 +3,7 @@ package notify import ( "fmt" "strings" + "unicode/utf8" "github.com/tis24dev/proxsave/internal/logging" ) @@ -168,6 +169,55 @@ func buildDiscordPayload(data *NotificationData, logger *logging.Logger) (map[st return payload, nil } +const discordMaxContentLength = 2000 + +func buildDiscordContentSummary(data *NotificationData) string { + statusEmoji := GetStatusEmoji(data.Status) + proxmoxType := strings.ToUpper(data.ProxmoxType.String()) + status := strings.ToUpper(data.Status.String()) + + headline := fmt.Sprintf("%s %s backup %s on %s", statusEmoji, proxmoxType, status, data.Hostname) + + details := []string{} + if strings.TrimSpace(data.BackupSizeHR) != "" { + details = append(details, fmt.Sprintf("size %s", strings.TrimSpace(data.BackupSizeHR))) + } + if data.BackupDuration > 0 { + details = append(details, fmt.Sprintf("duration %s", FormatDuration(data.BackupDuration))) + } + if data.ErrorCount > 0 || data.WarningCount > 0 { + details = append(details, fmt.Sprintf("errors %d, warnings %d", data.ErrorCount, data.WarningCount)) + } + + if len(details) == 0 { + return headline + } + return headline + " • " + strings.Join(details, " • ") +} + +func truncateDiscordContent(content string) (string, bool) { + if utf8.RuneCountInString(content) <= discordMaxContentLength { + return content, false + } + + if discordMaxContentLength <= 1 { + return "…", true + } + + var b strings.Builder + b.Grow(len(content)) + count := 0 + for _, r := range content { + if count >= discordMaxContentLength-1 { + break + } + b.WriteRune(r) + count++ + } + b.WriteRune('…') + return b.String(), true +} + // buildSlackPayload builds a Slack-formatted webhook payload with blocks func buildSlackPayload(data *NotificationData, logger *logging.Logger) (map[string]interface{}, error) { logger.Debug("buildSlackPayload() starting...") diff --git a/internal/notify/webhook_test.go b/internal/notify/webhook_test.go index 78926cb..82b2bf7 100644 --- a/internal/notify/webhook_test.go +++ b/internal/notify/webhook_test.go @@ -10,6 +10,7 @@ import ( "strings" "testing" "time" + "unicode/utf8" "github.com/tis24dev/proxsave/internal/config" "github.com/tis24dev/proxsave/internal/logging" @@ -811,6 +812,153 @@ func TestWebhookNotifier_sendToEndpoint_CoversErrorBranches(t *testing.T) { }) } +func TestWebhookNotifier_sendToEndpoint_DiscordContent(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + data := createTestNotificationData() + + notifier, err := NewWebhookNotifier(&config.WebhookConfig{ + Enabled: true, + DefaultFormat: "generic", + MaxRetries: 0, + Endpoints: []config.WebhookEndpoint{ + {Name: "x", URL: "https://example.com"}, + }, + }, logger) + if err != nil { + t.Fatalf("NewWebhookNotifier() error = %v", err) + } + + t.Run("custom content", func(t *testing.T) { + notifier.client = &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + body, readErr := io.ReadAll(req.Body) + if readErr != nil { + t.Fatalf("failed to read request body: %v", readErr) + } + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to decode JSON: %v", err) + } + content, ok := payload["content"].(string) + if !ok || content != "Hello" { + t.Fatalf("content = %q; want %q", payload["content"], "Hello") + } + if _, ok := payload["embeds"]; !ok { + t.Fatalf("expected embeds in payload") + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: make(http.Header), + Request: req, + }, nil + }), + } + + endpoint := config.WebhookEndpoint{ + Name: "discord", + URL: "https://example.com", + Method: "POST", + Format: "discord", + Discord: config.WebhookDiscordOptions{ + ContentEnabled: true, + Content: "Hello", + }, + } + if err := notifier.sendToEndpoint(context.Background(), endpoint, data); err != nil { + t.Fatalf("expected success, got %v", err) + } + }) + + t.Run("auto summary when content empty", func(t *testing.T) { + notifier.client = &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + body, readErr := io.ReadAll(req.Body) + if readErr != nil { + t.Fatalf("failed to read request body: %v", readErr) + } + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to decode JSON: %v", err) + } + content, ok := payload["content"].(string) + if !ok || strings.TrimSpace(content) == "" { + t.Fatalf("expected non-empty content summary, got %v", payload["content"]) + } + if !strings.Contains(content, data.Hostname) { + t.Fatalf("expected content to contain hostname %q, got %q", data.Hostname, content) + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: make(http.Header), + Request: req, + }, nil + }), + } + + endpoint := config.WebhookEndpoint{ + Name: "discord", + URL: "https://example.com", + Method: "POST", + Format: "discord", + Discord: config.WebhookDiscordOptions{ + ContentEnabled: true, + Content: "", + }, + } + if err := notifier.sendToEndpoint(context.Background(), endpoint, data); err != nil { + t.Fatalf("expected success, got %v", err) + } + }) + + t.Run("truncates content over 2000 chars", func(t *testing.T) { + long := strings.Repeat("a", 2100) + notifier.client = &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + body, readErr := io.ReadAll(req.Body) + if readErr != nil { + t.Fatalf("failed to read request body: %v", readErr) + } + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to decode JSON: %v", err) + } + content, ok := payload["content"].(string) + if !ok { + t.Fatalf("expected content to be a string, got %T", payload["content"]) + } + if utf8.RuneCountInString(content) != 2000 { + t.Fatalf("content rune length = %d; want 2000", utf8.RuneCountInString(content)) + } + if !strings.HasSuffix(content, "…") { + t.Fatalf("expected content to end with ellipsis, got %q", content[len(content)-10:]) + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: make(http.Header), + Request: req, + }, nil + }), + } + + endpoint := config.WebhookEndpoint{ + Name: "discord", + URL: "https://example.com", + Method: "POST", + Format: "discord", + Discord: config.WebhookDiscordOptions{ + ContentEnabled: true, + Content: long, + }, + } + if err := notifier.sendToEndpoint(context.Background(), endpoint, data); err != nil { + t.Fatalf("expected success, got %v", err) + } + }) +} + func TestBuildDiscordPayload(t *testing.T) { logger := logging.New(types.LogLevelDebug, false) data := createTestNotificationData() From 55cad476b618eff004098b40a79b34be8057b38e Mon Sep 17 00:00:00 2001 From: tis24dev Date: Fri, 6 Mar 2026 13:51:31 +0100 Subject: [PATCH 05/18] Revert "feat(webhook): optional Discord content fallback for embed-only messages" This reverts commit fc0aed400d11384f38c014b50621520ea5669a3b. --- docs/CONFIGURATION.md | 4 - docs/EXAMPLES.md | 3 - internal/config/config.go | 19 ---- internal/config/config_test.go | 15 --- internal/config/templates/backup.env | 5 - internal/notify/webhook.go | 22 ---- internal/notify/webhook_payloads.go | 50 --------- internal/notify/webhook_test.go | 148 --------------------------- 8 files changed, 266 deletions(-) diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 1e21da1..a3ef784 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -912,10 +912,6 @@ WEBHOOK_DISCORD_ALERTS_AUTH_TOKEN= # Bearer token WEBHOOK_DISCORD_ALERTS_AUTH_USER= # Basic auth username WEBHOOK_DISCORD_ALERTS_AUTH_PASS= # Basic auth password WEBHOOK_DISCORD_ALERTS_AUTH_SECRET= # HMAC secret key - -# Discord-only: optional text fallback (helps if embeds are hidden on some clients) -WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT_ENABLED=false # true | false (default false) -WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT= # Optional; empty = auto-summary (max 2000 chars) ``` **Supported formats**: diff --git a/docs/EXAMPLES.md b/docs/EXAMPLES.md index 886e97b..3f8a6eb 100644 --- a/docs/EXAMPLES.md +++ b/docs/EXAMPLES.md @@ -535,9 +535,6 @@ WEBHOOK_DISCORD_ALERTS_URL=https://discord.com/api/webhooks/XXXX/YYYY WEBHOOK_DISCORD_ALERTS_FORMAT=discord WEBHOOK_DISCORD_ALERTS_METHOD=POST -# Optional (Discord): uncomment if embeds look empty on some clients -# WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT_ENABLED=true - # Run backup ./build/proxsave # Result: Notifications sent to Telegram, Email, and Discord diff --git a/internal/config/config.go b/internal/config/config.go index 16e9da2..bb35388 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1262,11 +1262,6 @@ func (c *Config) BuildWebhookConfig() *WebhookConfig { } } - discordOpts := WebhookDiscordOptions{ - ContentEnabled: c.getBool(prefix+"DISCORD_CONTENT_ENABLED", false), - Content: c.getString(prefix+"DISCORD_CONTENT", ""), - } - endpoints = append(endpoints, WebhookEndpoint{ Name: name, URL: url, @@ -1274,7 +1269,6 @@ func (c *Config) BuildWebhookConfig() *WebhookConfig { Method: method, Headers: headers, Auth: auth, - Discord: discordOpts, }) } @@ -1369,22 +1363,9 @@ type WebhookEndpoint struct { Method string Headers map[string]string Auth WebhookAuth - Discord WebhookDiscordOptions CustomFields map[string]interface{} } -// WebhookDiscordOptions holds Discord-specific webhook options. -// These options are ignored for non-Discord formats. -type WebhookDiscordOptions struct { - // ContentEnabled adds a short "content" text alongside embeds. - // This helps clients that don't render embeds reliably. - ContentEnabled bool - - // Content overrides the fallback text when ContentEnabled is true. - // If empty, ProxSave generates an automatic summary. - Content string -} - // WebhookAuth holds authentication configuration for a webhook type WebhookAuth struct { Type string diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f1363e9..3fb7ab0 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -790,8 +790,6 @@ func TestBuildWebhookConfigParsesConfiguredEndpoints(t *testing.T) { "WEBHOOK_ALERT_URL": "https://example.com/alert", "WEBHOOK_ALERT_FORMAT": "lines", "WEBHOOK_ALERT_METHOD": "PUT", - "WEBHOOK_ALERT_DISCORD_CONTENT_ENABLED": "true", - "WEBHOOK_ALERT_DISCORD_CONTENT": "Hello from ProxSave", "WEBHOOK_ALERT_AUTH_TYPE": "bearer", "WEBHOOK_ALERT_AUTH_TOKEN": "tok-123", "WEBHOOK_ALERT_AUTH_USER": "admin", @@ -839,13 +837,6 @@ func TestBuildWebhookConfigParsesConfiguredEndpoints(t *testing.T) { if alert.Auth.Type != "bearer" || alert.Auth.Token != "tok-123" || alert.Auth.User != "admin" || alert.Auth.Pass != "pass-123" || alert.Auth.Secret != "secret-xyz" { t.Fatalf("alert auth = %+v; want bearer token-user-pass-secret", alert.Auth) } - if !alert.Discord.ContentEnabled { - t.Fatalf("expected alert Discord.ContentEnabled to be true") - } - if alert.Discord.Content != "Hello from ProxSave" { - t.Fatalf("alert Discord.Content = %q; want %q", alert.Discord.Content, "Hello from ProxSave") - } - if alert.Headers["X-Trace"] != "12345" || alert.Headers["Authorization"] != "Bearer token123" { t.Fatalf("alert headers = %+v; want parsed header values", alert.Headers) } @@ -863,12 +854,6 @@ func TestBuildWebhookConfigParsesConfiguredEndpoints(t *testing.T) { if backup.Auth.Type != "none" { t.Fatalf("backup auth type = %q; want none", backup.Auth.Type) } - if backup.Discord.ContentEnabled { - t.Fatalf("expected backup Discord.ContentEnabled to be false by default") - } - if backup.Discord.Content != "" { - t.Fatalf("expected backup Discord.Content to be empty by default, got %q", backup.Discord.Content) - } if len(backup.Headers) != 0 { t.Fatalf("expected backup headers to be empty, got %+v", backup.Headers) } diff --git a/internal/config/templates/backup.env b/internal/config/templates/backup.env index 835e3af..2dbf810 100644 --- a/internal/config/templates/backup.env +++ b/internal/config/templates/backup.env @@ -255,11 +255,6 @@ WEBHOOK_RETRY_DELAY=2 # seconds # WEBHOOK_DISCORD_ALERTS_AUTH_USER= # WEBHOOK_DISCORD_ALERTS_AUTH_PASS= # WEBHOOK_DISCORD_ALERTS_AUTH_SECRET= -# -# Optional Discord-only: send a short `content` text fallback alongside the embed. -# Useful when embeds are hidden on some clients (e.g., Discord mobile settings). -# WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT_ENABLED=false -# WEBHOOK_DISCORD_ALERTS_DISCORD_CONTENT= # Optional; empty = auto-summary (max 2000 chars) # ---------------------------------------------------------------------- # Metriche / Prometheus diff --git a/internal/notify/webhook.go b/internal/notify/webhook.go index e5e5d30..d317a0f 100644 --- a/internal/notify/webhook.go +++ b/internal/notify/webhook.go @@ -184,28 +184,6 @@ func (w *WebhookNotifier) sendToEndpoint(ctx context.Context, endpoint config.We return fmt.Errorf("failed to build payload: %w", err) } - if strings.EqualFold(format, "discord") && endpoint.Discord.ContentEnabled { - payloadMap, ok := payload.(map[string]interface{}) - if !ok { - w.logger.Warning("Discord content enabled for '%s' but payload type is %T; skipping content", endpoint.Name, payload) - } else { - content := endpoint.Discord.Content - if strings.TrimSpace(content) == "" { - content = buildDiscordContentSummary(data) - } - - content, truncated := truncateDiscordContent(content) - if truncated { - w.logger.Warning("Discord content for '%s' exceeded 2000 characters; truncated", endpoint.Name) - } - - if strings.TrimSpace(content) != "" { - payloadMap["content"] = content - payload = payloadMap - } - } - } - payloadDuration := time.Since(payloadStart) w.logger.Debug("Payload built successfully in %dms", payloadDuration.Milliseconds()) diff --git a/internal/notify/webhook_payloads.go b/internal/notify/webhook_payloads.go index b0c4ce1..1c4530f 100644 --- a/internal/notify/webhook_payloads.go +++ b/internal/notify/webhook_payloads.go @@ -3,7 +3,6 @@ package notify import ( "fmt" "strings" - "unicode/utf8" "github.com/tis24dev/proxsave/internal/logging" ) @@ -169,55 +168,6 @@ func buildDiscordPayload(data *NotificationData, logger *logging.Logger) (map[st return payload, nil } -const discordMaxContentLength = 2000 - -func buildDiscordContentSummary(data *NotificationData) string { - statusEmoji := GetStatusEmoji(data.Status) - proxmoxType := strings.ToUpper(data.ProxmoxType.String()) - status := strings.ToUpper(data.Status.String()) - - headline := fmt.Sprintf("%s %s backup %s on %s", statusEmoji, proxmoxType, status, data.Hostname) - - details := []string{} - if strings.TrimSpace(data.BackupSizeHR) != "" { - details = append(details, fmt.Sprintf("size %s", strings.TrimSpace(data.BackupSizeHR))) - } - if data.BackupDuration > 0 { - details = append(details, fmt.Sprintf("duration %s", FormatDuration(data.BackupDuration))) - } - if data.ErrorCount > 0 || data.WarningCount > 0 { - details = append(details, fmt.Sprintf("errors %d, warnings %d", data.ErrorCount, data.WarningCount)) - } - - if len(details) == 0 { - return headline - } - return headline + " • " + strings.Join(details, " • ") -} - -func truncateDiscordContent(content string) (string, bool) { - if utf8.RuneCountInString(content) <= discordMaxContentLength { - return content, false - } - - if discordMaxContentLength <= 1 { - return "…", true - } - - var b strings.Builder - b.Grow(len(content)) - count := 0 - for _, r := range content { - if count >= discordMaxContentLength-1 { - break - } - b.WriteRune(r) - count++ - } - b.WriteRune('…') - return b.String(), true -} - // buildSlackPayload builds a Slack-formatted webhook payload with blocks func buildSlackPayload(data *NotificationData, logger *logging.Logger) (map[string]interface{}, error) { logger.Debug("buildSlackPayload() starting...") diff --git a/internal/notify/webhook_test.go b/internal/notify/webhook_test.go index 82b2bf7..78926cb 100644 --- a/internal/notify/webhook_test.go +++ b/internal/notify/webhook_test.go @@ -10,7 +10,6 @@ import ( "strings" "testing" "time" - "unicode/utf8" "github.com/tis24dev/proxsave/internal/config" "github.com/tis24dev/proxsave/internal/logging" @@ -812,153 +811,6 @@ func TestWebhookNotifier_sendToEndpoint_CoversErrorBranches(t *testing.T) { }) } -func TestWebhookNotifier_sendToEndpoint_DiscordContent(t *testing.T) { - logger := logging.New(types.LogLevelDebug, false) - data := createTestNotificationData() - - notifier, err := NewWebhookNotifier(&config.WebhookConfig{ - Enabled: true, - DefaultFormat: "generic", - MaxRetries: 0, - Endpoints: []config.WebhookEndpoint{ - {Name: "x", URL: "https://example.com"}, - }, - }, logger) - if err != nil { - t.Fatalf("NewWebhookNotifier() error = %v", err) - } - - t.Run("custom content", func(t *testing.T) { - notifier.client = &http.Client{ - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - body, readErr := io.ReadAll(req.Body) - if readErr != nil { - t.Fatalf("failed to read request body: %v", readErr) - } - var payload map[string]interface{} - if err := json.Unmarshal(body, &payload); err != nil { - t.Fatalf("failed to decode JSON: %v", err) - } - content, ok := payload["content"].(string) - if !ok || content != "Hello" { - t.Fatalf("content = %q; want %q", payload["content"], "Hello") - } - if _, ok := payload["embeds"]; !ok { - t.Fatalf("expected embeds in payload") - } - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("ok")), - Header: make(http.Header), - Request: req, - }, nil - }), - } - - endpoint := config.WebhookEndpoint{ - Name: "discord", - URL: "https://example.com", - Method: "POST", - Format: "discord", - Discord: config.WebhookDiscordOptions{ - ContentEnabled: true, - Content: "Hello", - }, - } - if err := notifier.sendToEndpoint(context.Background(), endpoint, data); err != nil { - t.Fatalf("expected success, got %v", err) - } - }) - - t.Run("auto summary when content empty", func(t *testing.T) { - notifier.client = &http.Client{ - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - body, readErr := io.ReadAll(req.Body) - if readErr != nil { - t.Fatalf("failed to read request body: %v", readErr) - } - var payload map[string]interface{} - if err := json.Unmarshal(body, &payload); err != nil { - t.Fatalf("failed to decode JSON: %v", err) - } - content, ok := payload["content"].(string) - if !ok || strings.TrimSpace(content) == "" { - t.Fatalf("expected non-empty content summary, got %v", payload["content"]) - } - if !strings.Contains(content, data.Hostname) { - t.Fatalf("expected content to contain hostname %q, got %q", data.Hostname, content) - } - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("ok")), - Header: make(http.Header), - Request: req, - }, nil - }), - } - - endpoint := config.WebhookEndpoint{ - Name: "discord", - URL: "https://example.com", - Method: "POST", - Format: "discord", - Discord: config.WebhookDiscordOptions{ - ContentEnabled: true, - Content: "", - }, - } - if err := notifier.sendToEndpoint(context.Background(), endpoint, data); err != nil { - t.Fatalf("expected success, got %v", err) - } - }) - - t.Run("truncates content over 2000 chars", func(t *testing.T) { - long := strings.Repeat("a", 2100) - notifier.client = &http.Client{ - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - body, readErr := io.ReadAll(req.Body) - if readErr != nil { - t.Fatalf("failed to read request body: %v", readErr) - } - var payload map[string]interface{} - if err := json.Unmarshal(body, &payload); err != nil { - t.Fatalf("failed to decode JSON: %v", err) - } - content, ok := payload["content"].(string) - if !ok { - t.Fatalf("expected content to be a string, got %T", payload["content"]) - } - if utf8.RuneCountInString(content) != 2000 { - t.Fatalf("content rune length = %d; want 2000", utf8.RuneCountInString(content)) - } - if !strings.HasSuffix(content, "…") { - t.Fatalf("expected content to end with ellipsis, got %q", content[len(content)-10:]) - } - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("ok")), - Header: make(http.Header), - Request: req, - }, nil - }), - } - - endpoint := config.WebhookEndpoint{ - Name: "discord", - URL: "https://example.com", - Method: "POST", - Format: "discord", - Discord: config.WebhookDiscordOptions{ - ContentEnabled: true, - Content: long, - }, - } - if err := notifier.sendToEndpoint(context.Background(), endpoint, data); err != nil { - t.Fatalf("expected success, got %v", err) - } - }) -} - func TestBuildDiscordPayload(t *testing.T) { logger := logging.New(types.LogLevelDebug, false) data := createTestNotificationData() From d46db6b783ee9e9592c988b6d66c78e5e0662671 Mon Sep 17 00:00:00 2001 From: tis24dev Date: Mon, 9 Mar 2026 12:36:16 +0100 Subject: [PATCH 06/18] Update go.mod --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 4a36bb8..f99518e 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/tis24dev/proxsave go 1.25 -toolchain go1.25.7 +toolchain go1.25.8 require ( filippo.io/age v1.3.1 From c6bfbb1069273a0065e5954f1f3842f1a4b955cf Mon Sep 17 00:00:00 2001 From: tis24dev Date: Mon, 9 Mar 2026 14:47:00 +0100 Subject: [PATCH 07/18] Harden restore decisions against untrusted backup metadata Derive restore compatibility, cluster gating, and hostname warnings from archive-backed facts instead of candidate manifest metadata. Add restore archive analysis for internal backup metadata and category-based fallback, update restore planning and compatibility checks to consume trusted decision inputs, and add regression tests covering manifest spoofing for backup type and cluster mode. --- internal/orchestrator/compatibility.go | 26 +- internal/orchestrator/compatibility_test.go | 5 +- internal/orchestrator/restore_decision.go | 273 ++++++++++++++++++ .../orchestrator/restore_decision_test.go | 75 +++++ internal/orchestrator/restore_plan.go | 10 +- internal/orchestrator/restore_plan_test.go | 9 +- .../restore_workflow_decision_test.go | 237 +++++++++++++++ .../restore_workflow_more_test.go | 3 +- internal/orchestrator/restore_workflow_ui.go | 31 +- .../restore_workflow_ui_helpers_test.go | 7 + internal/orchestrator/selective.go | 36 +-- 11 files changed, 633 insertions(+), 79 deletions(-) create mode 100644 internal/orchestrator/restore_decision.go create mode 100644 internal/orchestrator/restore_decision_test.go create mode 100644 internal/orchestrator/restore_workflow_decision_test.go diff --git a/internal/orchestrator/compatibility.go b/internal/orchestrator/compatibility.go index cc2eb3f..5aba6e6 100644 --- a/internal/orchestrator/compatibility.go +++ b/internal/orchestrator/compatibility.go @@ -41,12 +41,8 @@ func DetectBackupType(manifest *backup.Manifest) SystemType { // Check ProxmoxType field if present if manifest.ProxmoxType != "" { - proxmoxType := strings.ToLower(manifest.ProxmoxType) - if strings.Contains(proxmoxType, "pve") || strings.Contains(proxmoxType, "proxmox-ve") { - return SystemTypePVE - } - if strings.Contains(proxmoxType, "pbs") || strings.Contains(proxmoxType, "proxmox-backup") { - return SystemTypePBS + if backupType := parseSystemTypeString(manifest.ProxmoxType); backupType != SystemTypeUnknown { + return backupType } } @@ -65,12 +61,20 @@ func DetectBackupType(manifest *backup.Manifest) SystemType { return SystemTypeUnknown } -// ValidateCompatibility checks if a backup is compatible with the current system -func ValidateCompatibility(manifest *backup.Manifest) error { - currentSystem := DetectCurrentSystem() - backupType := DetectBackupType(manifest) +func parseSystemTypeString(value string) SystemType { + normalized := strings.ToLower(strings.TrimSpace(value)) + switch { + case strings.Contains(normalized, "pve"), strings.Contains(normalized, "proxmox-ve"): + return SystemTypePVE + case strings.Contains(normalized, "pbs"), strings.Contains(normalized, "proxmox-backup"): + return SystemTypePBS + default: + return SystemTypeUnknown + } +} - // If we can't detect either, issue a warning but allow +// ValidateCompatibility checks if a backup is compatible with the current system. +func ValidateCompatibility(currentSystem, backupType SystemType) error { if currentSystem == SystemTypeUnknown { return fmt.Errorf("warning: cannot detect current system type - restoration may fail") } diff --git a/internal/orchestrator/compatibility_test.go b/internal/orchestrator/compatibility_test.go index 5c7ef98..dcdcae5 100644 --- a/internal/orchestrator/compatibility_test.go +++ b/internal/orchestrator/compatibility_test.go @@ -3,8 +3,6 @@ package orchestrator import ( "os" "testing" - - "github.com/tis24dev/proxsave/internal/backup" ) func TestValidateCompatibility_Mismatch(t *testing.T) { @@ -18,8 +16,7 @@ func TestValidateCompatibility_Mismatch(t *testing.T) { t.Fatalf("mkdir: %v", err) } - manifest := &backup.Manifest{ProxmoxType: "pbs"} - if err := ValidateCompatibility(manifest); err == nil { + if err := ValidateCompatibility(SystemTypePVE, SystemTypePBS); err == nil { t.Fatalf("expected incompatibility error") } } diff --git a/internal/orchestrator/restore_decision.go b/internal/orchestrator/restore_decision.go new file mode 100644 index 0000000..6e9d0d8 --- /dev/null +++ b/internal/orchestrator/restore_decision.go @@ -0,0 +1,273 @@ +package orchestrator + +import ( + "archive/tar" + "bufio" + "context" + "fmt" + "io" + "path" + "strings" + + "github.com/tis24dev/proxsave/internal/logging" +) + +// RestoreDecisionSource describes where trusted restore facts were derived from. +type RestoreDecisionSource string + +const ( + RestoreDecisionSourceUnknown RestoreDecisionSource = "unknown" + RestoreDecisionSourceInternalMetadata RestoreDecisionSource = "internal_metadata" + RestoreDecisionSourceCategories RestoreDecisionSource = "categories" + RestoreDecisionSourceAmbiguous RestoreDecisionSource = "ambiguous" +) + +// RestoreDecisionInfo contains the archive-derived facts used for restore decisions. +type RestoreDecisionInfo struct { + BackupType SystemType + ClusterPayload bool + BackupHostname string + Source RestoreDecisionSource +} + +type restoreDecisionMetadata struct { + BackupType SystemType + ClusterMode string + Hostname string +} + +type restoreArchiveInspection struct { + AvailableCategories []Category + Decision *RestoreDecisionInfo +} + +const restoreDecisionMetadataPath = "var/lib/proxsave-info/backup_metadata.txt" + +// AnalyzeRestoreArchive inspects the archive once and derives trusted restore facts +// from archive contents plus internal backup metadata when present. +func AnalyzeRestoreArchive(archivePath string, logger *logging.Logger) (categories []Category, decision *RestoreDecisionInfo, err error) { + if logger == nil { + logger = logging.GetDefaultLogger() + } + + done := logging.DebugStart(logger, "analyze restore archive", "archive=%s", archivePath) + defer func() { done(err) }() + logger.Info("Analyzing backup contents...") + + inspection, err := inspectRestoreArchiveContents(archivePath, logger) + if err != nil { + return nil, nil, err + } + + for _, cat := range inspection.AvailableCategories { + logger.Debug("Category available: %s (%s)", cat.ID, cat.Name) + } + logger.Info("Detected %d available categories", len(inspection.AvailableCategories)) + if inspection.Decision != nil { + logger.Debug( + "Restore decision facts: backup_type=%s cluster_payload=%v hostname=%q source=%s", + inspection.Decision.BackupType, + inspection.Decision.ClusterPayload, + inspection.Decision.BackupHostname, + inspection.Decision.Source, + ) + } + + return inspection.AvailableCategories, inspection.Decision, nil +} + +func inspectRestoreArchiveContents(archivePath string, logger *logging.Logger) (*restoreArchiveInspection, error) { + file, err := restoreFS.Open(archivePath) + if err != nil { + return nil, fmt.Errorf("open archive: %w", err) + } + defer file.Close() + + reader, err := createDecompressionReader(context.Background(), file, archivePath) + if err != nil { + return nil, err + } + defer func() { + if closer, ok := reader.(interface{ Close() error }); ok { + closer.Close() + } + }() + + tarReader := tar.NewReader(reader) + archivePaths, metadata, metadataErr, err := collectRestoreArchiveFacts(tarReader) + if err != nil { + return nil, fmt.Errorf("inspect archive: %w", err) + } + if metadataErr != nil { + logger.Warning("Could not parse internal backup metadata: %v", metadataErr) + } + + logger.Debug("Found %d entries in archive", len(archivePaths)) + availableCategories := AnalyzeArchivePaths(archivePaths, GetAllCategories()) + + decision := buildRestoreDecisionInfo(metadata, availableCategories, logger) + return &restoreArchiveInspection{ + AvailableCategories: availableCategories, + Decision: decision, + }, nil +} + +func collectRestoreArchiveFacts(tarReader *tar.Reader) ([]string, *restoreDecisionMetadata, error, error) { + var ( + archivePaths []string + metadata *restoreDecisionMetadata + metadataErr error + ) + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return nil, nil, nil, err + } + + archivePaths = append(archivePaths, header.Name) + if metadata != nil || header.FileInfo().IsDir() { + continue + } + if !isRestoreDecisionMetadataEntry(header.Name) { + continue + } + + data, err := io.ReadAll(tarReader) + if err != nil { + return nil, nil, nil, err + } + parsed, parseErr := parseRestoreDecisionMetadata(data) + if parseErr != nil { + metadataErr = parseErr + continue + } + metadata = parsed + } + + return archivePaths, metadata, metadataErr, nil +} + +func isRestoreDecisionMetadataEntry(entryName string) bool { + return normalizeRestoreEntryPath(entryName) == restoreDecisionMetadataPath +} + +func normalizeRestoreEntryPath(entryName string) string { + clean := strings.TrimSpace(strings.ReplaceAll(entryName, "\\", "/")) + if clean == "" { + return "" + } + + clean = path.Clean(clean) + clean = strings.TrimPrefix(clean, "./") + clean = strings.TrimPrefix(clean, "/") + if clean == "." { + return "" + } + return clean +} + +func parseRestoreDecisionMetadata(data []byte) (*restoreDecisionMetadata, error) { + meta := &restoreDecisionMetadata{} + scanner := bufio.NewScanner(strings.NewReader(string(data))) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + switch key { + case "BACKUP_TYPE": + meta.BackupType = parseSystemTypeString(value) + case "PVE_CLUSTER_MODE", "CLUSTER_MODE": + meta.ClusterMode = value + case "HOSTNAME": + meta.Hostname = value + } + } + if err := scanner.Err(); err != nil { + return nil, err + } + return meta, nil +} + +func buildRestoreDecisionInfo(metadata *restoreDecisionMetadata, categories []Category, logger *logging.Logger) *RestoreDecisionInfo { + if logger == nil { + logger = logging.GetDefaultLogger() + } + + info := &RestoreDecisionInfo{ + BackupType: SystemTypeUnknown, + ClusterPayload: hasCategoryID(categories, "pve_cluster"), + Source: RestoreDecisionSourceUnknown, + } + + if metadata != nil { + info.BackupHostname = strings.TrimSpace(metadata.Hostname) + } + + categoryType, ambiguousType := detectBackupTypeFromCategories(categories) + switch { + case ambiguousType: + logger.Warning("Archive contains both PVE and PBS-specific payloads; treating backup type as unknown for compatibility checks") + info.Source = RestoreDecisionSourceAmbiguous + case metadata != nil && metadata.BackupType != SystemTypeUnknown && categoryType == SystemTypeUnknown: + info.BackupType = metadata.BackupType + info.Source = RestoreDecisionSourceInternalMetadata + case metadata != nil && metadata.BackupType != SystemTypeUnknown && categoryType != SystemTypeUnknown && metadata.BackupType != categoryType: + logger.Warning("Internal backup metadata and archive payload disagree on backup type; using archive-derived type %s", strings.ToUpper(string(categoryType))) + info.BackupType = categoryType + info.Source = RestoreDecisionSourceCategories + case categoryType != SystemTypeUnknown: + info.BackupType = categoryType + info.Source = RestoreDecisionSourceCategories + case metadata != nil && metadata.BackupType != SystemTypeUnknown: + info.BackupType = metadata.BackupType + info.Source = RestoreDecisionSourceInternalMetadata + } + + if metadata != nil { + metadataCluster := strings.EqualFold(strings.TrimSpace(metadata.ClusterMode), "cluster") + switch { + case metadataCluster && !info.ClusterPayload: + logger.Warning("Internal backup metadata reports cluster mode, but no pve_cluster payload was found; guarded cluster restore remains disabled") + case !metadataCluster && info.ClusterPayload: + logger.Warning("Cluster payload detected in archive despite metadata reporting non-cluster backup; guarded cluster restore remains enabled") + } + } + + return info +} + +func detectBackupTypeFromCategories(categories []Category) (SystemType, bool) { + var hasPVE, hasPBS bool + for _, cat := range categories { + switch cat.Type { + case CategoryTypePVE: + hasPVE = true + case CategoryTypePBS: + hasPBS = true + } + } + + switch { + case hasPVE && hasPBS: + return SystemTypeUnknown, true + case hasPVE: + return SystemTypePVE, false + case hasPBS: + return SystemTypePBS, false + default: + return SystemTypeUnknown, false + } +} diff --git a/internal/orchestrator/restore_decision_test.go b/internal/orchestrator/restore_decision_test.go new file mode 100644 index 0000000..c91db5d --- /dev/null +++ b/internal/orchestrator/restore_decision_test.go @@ -0,0 +1,75 @@ +package orchestrator + +import ( + "path/filepath" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" +) + +func TestAnalyzeRestoreArchive_UsesInternalMetadataWhenCategoriesAreCommonOnly(t *testing.T) { + origRestoreFS := restoreFS + t.Cleanup(func() { restoreFS = origRestoreFS }) + restoreFS = osFS{} + + archivePath := filepath.Join(t.TempDir(), "backup.tar") + if err := writeTarFile(archivePath, map[string]string{ + "etc/hosts": "127.0.0.1 localhost\n", + "var/lib/proxsave-info/backup_metadata.txt": "# ProxSave Metadata\nBACKUP_TYPE=pbs\nHOSTNAME=pbs-node\nPVE_CLUSTER_MODE=cluster\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + categories, decision, err := AnalyzeRestoreArchive(archivePath, logger) + if err != nil { + t.Fatalf("AnalyzeRestoreArchive() error: %v", err) + } + if backupType, ambiguous := detectBackupTypeFromCategories(categories); backupType != SystemTypeUnknown || ambiguous { + t.Fatalf("detectBackupTypeFromCategories() = (%s, %v); want (%s, false)", backupType, ambiguous, SystemTypeUnknown) + } + if decision == nil { + t.Fatalf("decision info is nil") + } + if decision.BackupType != SystemTypePBS { + t.Fatalf("BackupType=%s; want %s", decision.BackupType, SystemTypePBS) + } + if decision.Source != RestoreDecisionSourceInternalMetadata { + t.Fatalf("Source=%s; want %s", decision.Source, RestoreDecisionSourceInternalMetadata) + } + if decision.BackupHostname != "pbs-node" { + t.Fatalf("BackupHostname=%q; want %q", decision.BackupHostname, "pbs-node") + } + if decision.ClusterPayload { + t.Fatalf("ClusterPayload should stay false without pve_cluster payload") + } +} + +func TestAnalyzeRestoreArchive_ClusterPayloadUsesArchiveContents(t *testing.T) { + origRestoreFS := restoreFS + t.Cleanup(func() { restoreFS = origRestoreFS }) + restoreFS = osFS{} + + archivePath := filepath.Join(t.TempDir(), "backup.tar") + if err := writeTarFile(archivePath, map[string]string{ + "var/lib/pve-cluster/config.db": "db\n", + "var/lib/proxsave-info/backup_metadata.txt": "BACKUP_TYPE=pve\nPVE_CLUSTER_MODE=standalone\nHOSTNAME=node1\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + _, decision, err := AnalyzeRestoreArchive(archivePath, logger) + if err != nil { + t.Fatalf("AnalyzeRestoreArchive() error: %v", err) + } + if decision == nil { + t.Fatalf("decision info is nil") + } + if !decision.ClusterPayload { + t.Fatalf("ClusterPayload should be true when pve_cluster payload exists") + } + if decision.BackupType != SystemTypePVE { + t.Fatalf("BackupType=%s; want %s", decision.BackupType, SystemTypePVE) + } +} diff --git a/internal/orchestrator/restore_plan.go b/internal/orchestrator/restore_plan.go index 6d54716..6eba614 100644 --- a/internal/orchestrator/restore_plan.go +++ b/internal/orchestrator/restore_plan.go @@ -1,11 +1,5 @@ package orchestrator -import ( - "strings" - - "github.com/tis24dev/proxsave/internal/backup" -) - // RestorePlan contains a pure, side-effect-free description of a restore run. type RestorePlan struct { Mode RestoreMode @@ -22,7 +16,7 @@ type RestorePlan struct { // PlanRestore computes the restore plan without performing any I/O or prompts. func PlanRestore( - manifest *backup.Manifest, + clusterBackup bool, selectedCategories []Category, systemType SystemType, mode RestoreMode, @@ -35,7 +29,7 @@ func PlanRestore( NormalCategories: normal, StagedCategories: staged, ExportCategories: export, - ClusterBackup: manifest != nil && strings.EqualFold(strings.TrimSpace(manifest.ClusterMode), "cluster"), + ClusterBackup: clusterBackup, } plan.NeedsClusterRestore = systemType == SystemTypePVE && hasCategoryID(normal, "pve_cluster") diff --git a/internal/orchestrator/restore_plan_test.go b/internal/orchestrator/restore_plan_test.go index c38b562..f27ba99 100644 --- a/internal/orchestrator/restore_plan_test.go +++ b/internal/orchestrator/restore_plan_test.go @@ -2,16 +2,13 @@ package orchestrator import ( "testing" - - "github.com/tis24dev/proxsave/internal/backup" ) func TestPlanRestoreClusterSafeToggle(t *testing.T) { clusterCat := Category{ID: "pve_cluster", Type: CategoryTypePVE} storageCat := Category{ID: "storage_pve", Type: CategoryTypePVE} - manifest := &backup.Manifest{ClusterMode: "cluster"} - plan := PlanRestore(manifest, []Category{clusterCat, storageCat}, SystemTypePVE, RestoreModeCustom) + plan := PlanRestore(true, []Category{clusterCat, storageCat}, SystemTypePVE, RestoreModeCustom) if !plan.NeedsClusterRestore { t.Fatalf("expected NeedsClusterRestore true") @@ -50,7 +47,7 @@ func TestPlanRestorePBSCategories(t *testing.T) { pbsCat := Category{ID: "pbs_config", Type: CategoryTypePBS, ExportOnly: true} normalCat := Category{ID: "network", Type: CategoryTypeCommon} - plan := PlanRestore(nil, []Category{pbsCat, normalCat}, SystemTypePBS, RestoreModeCustom) + plan := PlanRestore(false, []Category{pbsCat, normalCat}, SystemTypePBS, RestoreModeCustom) if len(plan.ExportCategories) != 1 || !hasCategoryID(plan.ExportCategories, "pbs_config") { t.Fatalf("expected pbs_config to be exported, got %+v", plan.ExportCategories) } @@ -66,7 +63,7 @@ func TestPlanRestoreKeepsExportCategoriesFromFullSelection(t *testing.T) { exportCat := Category{ID: "pve_config_export", ExportOnly: true} normalCat := Category{ID: "network"} - plan := PlanRestore(nil, []Category{normalCat, exportCat}, SystemTypePVE, RestoreModeFull) + plan := PlanRestore(false, []Category{normalCat, exportCat}, SystemTypePVE, RestoreModeFull) if len(plan.StagedCategories) != 1 || plan.StagedCategories[0].ID != "network" { t.Fatalf("expected staged categories to keep network, got %+v", plan.StagedCategories) } diff --git a/internal/orchestrator/restore_workflow_decision_test.go b/internal/orchestrator/restore_workflow_decision_test.go new file mode 100644 index 0000000..a710372 --- /dev/null +++ b/internal/orchestrator/restore_workflow_decision_test.go @@ -0,0 +1,237 @@ +package orchestrator + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +func stubPreparedRestoreBundle(archivePath string, manifest *backup.Manifest) func(context.Context, *config.Config, *logging.Logger, string, RestoreWorkflowUI) (*decryptCandidate, *preparedBundle, error) { + return func(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui RestoreWorkflowUI) (*decryptCandidate, *preparedBundle, error) { + return &decryptCandidate{ + DisplayBase: "test", + Manifest: manifest, + }, &preparedBundle{ + ArchivePath: archivePath, + Manifest: backup.Manifest{ArchivePath: archivePath}, + cleanup: func() {}, + }, nil + } +} + +func TestRunRestoreWorkflow_ClusterPromptUsesArchivePayloadNotManifest(t *testing.T) { + origRestoreFS := restoreFS + origRestoreCmd := restoreCmd + origRestoreSystem := restoreSystem + origCompatFS := compatFS + origPrepare := prepareRestoreBundleFunc + origSafetyFS := safetyFS + t.Cleanup(func() { + restoreFS = origRestoreFS + restoreCmd = origRestoreCmd + restoreSystem = origRestoreSystem + compatFS = origCompatFS + prepareRestoreBundleFunc = origPrepare + safetyFS = origSafetyFS + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + compatFS = fakeFS + safetyFS = fakeFS + restoreCmd = runOnlyRunner{} + restoreSystem = fakeSystemDetector{systemType: SystemTypePVE} + + if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil { + t.Fatalf("fakeFS.AddFile: %v", err) + } + + tmpTar := filepath.Join(t.TempDir(), "bundle.tar") + if err := writeTarFile(tmpTar, map[string]string{ + "etc/hosts": "127.0.0.1 localhost\n", + "var/lib/pve-cluster/config.db": "db\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + tarBytes, err := os.ReadFile(tmpTar) + if err != nil { + t.Fatalf("ReadFile tar: %v", err) + } + if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil { + t.Fatalf("fakeFS.WriteFile: %v", err) + } + + prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{ + CreatedAt: time.Unix(1700000000, 0), + ClusterMode: "standalone", + ProxmoxType: "pve", + ScriptVersion: "vtest", + }) + + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{BaseDir: "/base"} + ui := &fakeRestoreWorkflowUI{ + mode: RestoreModeCustom, + categories: []Category{ + mustCategoryByID(t, "pve_cluster"), + }, + confirmRestore: true, + clusterMode: ClusterRestoreSafe, + } + + if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil { + t.Fatalf("runRestoreWorkflowWithUI error: %v", err) + } + if ui.clusterRestoreModeCalls != 1 { + t.Fatalf("clusterRestoreModeCalls=%d; want 1", ui.clusterRestoreModeCalls) + } +} + +func TestRunRestoreWorkflow_CompatibilityUsesArchivePayloadNotManifest(t *testing.T) { + origRestoreFS := restoreFS + origRestoreCmd := restoreCmd + origRestoreSystem := restoreSystem + origCompatFS := compatFS + origPrepare := prepareRestoreBundleFunc + origSafetyFS := safetyFS + t.Cleanup(func() { + restoreFS = origRestoreFS + restoreCmd = origRestoreCmd + restoreSystem = origRestoreSystem + compatFS = origCompatFS + prepareRestoreBundleFunc = origPrepare + safetyFS = origSafetyFS + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + compatFS = fakeFS + safetyFS = fakeFS + restoreCmd = runOnlyRunner{} + restoreSystem = fakeSystemDetector{systemType: SystemTypePVE} + + if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil { + t.Fatalf("fakeFS.AddFile: %v", err) + } + + tmpTar := filepath.Join(t.TempDir(), "bundle.tar") + if err := writeTarFile(tmpTar, map[string]string{ + "etc/ssh/sshd_config": "Port 22\n", + "etc/pve/jobs.cfg": "jobs\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + tarBytes, err := os.ReadFile(tmpTar) + if err != nil { + t.Fatalf("ReadFile tar: %v", err) + } + if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil { + t.Fatalf("fakeFS.WriteFile: %v", err) + } + + prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{ + CreatedAt: time.Unix(1700000000, 0), + ClusterMode: "standalone", + ProxmoxType: "pbs", + ScriptVersion: "vtest", + }) + + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{BaseDir: "/base"} + ui := &fakeRestoreWorkflowUI{ + mode: RestoreModeCustom, + categories: []Category{ + mustCategoryByID(t, "ssh"), + }, + confirmRestore: true, + confirmCompatible: false, + } + + if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil { + t.Fatalf("runRestoreWorkflowWithUI error: %v", err) + } + if ui.confirmCompatibilityCalls != 0 { + t.Fatalf("confirmCompatibilityCalls=%d; want 0", ui.confirmCompatibilityCalls) + } +} + +func TestRunRestoreWorkflow_CompatibilityWarnsOnArchiveMismatchDespiteManifest(t *testing.T) { + origRestoreFS := restoreFS + origRestoreCmd := restoreCmd + origRestoreSystem := restoreSystem + origCompatFS := compatFS + origPrepare := prepareRestoreBundleFunc + origSafetyFS := safetyFS + t.Cleanup(func() { + restoreFS = origRestoreFS + restoreCmd = origRestoreCmd + restoreSystem = origRestoreSystem + compatFS = origCompatFS + prepareRestoreBundleFunc = origPrepare + safetyFS = origSafetyFS + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + compatFS = fakeFS + safetyFS = fakeFS + restoreCmd = runOnlyRunner{} + restoreSystem = fakeSystemDetector{systemType: SystemTypePVE} + + if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil { + t.Fatalf("fakeFS.AddFile: %v", err) + } + + tmpTar := filepath.Join(t.TempDir(), "bundle.tar") + if err := writeTarFile(tmpTar, map[string]string{ + "etc/ssh/sshd_config": "Port 22\n", + "etc/proxmox-backup/sync.cfg": "sync\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + tarBytes, err := os.ReadFile(tmpTar) + if err != nil { + t.Fatalf("ReadFile tar: %v", err) + } + if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil { + t.Fatalf("fakeFS.WriteFile: %v", err) + } + + prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{ + CreatedAt: time.Unix(1700000000, 0), + ClusterMode: "standalone", + ProxmoxType: "pve", + ScriptVersion: "vtest", + }) + + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{BaseDir: "/base"} + ui := &fakeRestoreWorkflowUI{ + mode: RestoreModeCustom, + categories: []Category{ + mustCategoryByID(t, "ssh"), + }, + confirmRestore: true, + confirmCompatible: true, + } + + if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil { + t.Fatalf("runRestoreWorkflowWithUI error: %v", err) + } + if ui.confirmCompatibilityCalls != 1 { + t.Fatalf("confirmCompatibilityCalls=%d; want 1", ui.confirmCompatibilityCalls) + } + if ui.lastCompatibilityWarning == nil { + t.Fatalf("expected compatibility warning to be passed to UI") + } +} diff --git a/internal/orchestrator/restore_workflow_more_test.go b/internal/orchestrator/restore_workflow_more_test.go index 41fa569..e60eb56 100644 --- a/internal/orchestrator/restore_workflow_more_test.go +++ b/internal/orchestrator/restore_workflow_more_test.go @@ -321,7 +321,8 @@ func TestRunRestoreWorkflow_IncompatibilityAndSafetyBackupFailureCanContinue(t * tmpTar := filepath.Join(t.TempDir(), "bundle.tar") if err := writeTarFile(tmpTar, map[string]string{ - "etc/hosts": "127.0.0.1 localhost\n", + "etc/hosts": "127.0.0.1 localhost\n", + "etc/proxmox-backup/sync.cfg": "sync\n", }); err != nil { t.Fatalf("writeTarFile: %v", err) } diff --git a/internal/orchestrator/restore_workflow_ui.go b/internal/orchestrator/restore_workflow_ui.go index 9ce2057..45e9099 100644 --- a/internal/orchestrator/restore_workflow_ui.go +++ b/internal/orchestrator/restore_workflow_ui.go @@ -76,7 +76,17 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l systemType := restoreSystem.DetectCurrentSystem() logger.Info("Detected system type: %s", GetSystemTypeString(systemType)) - if warn := ValidateCompatibility(candidate.Manifest); warn != nil { + availableCategories, decisionInfo, err := AnalyzeRestoreArchive(prepared.ArchivePath, logger) + if err != nil { + logger.Warning("Could not analyze categories: %v", err) + logger.Info("Falling back to full restore mode") + return runFullRestoreWithUI(ctx, ui, candidate, prepared, destRoot, logger, cfg.DryRun) + } + if decisionInfo == nil { + decisionInfo = &RestoreDecisionInfo{} + } + + if warn := ValidateCompatibility(systemType, decisionInfo.BackupType); warn != nil { logger.Warning("Compatibility check: %v", warn) proceed, perr := ui.ConfirmCompatibility(ctx, warn) if perr != nil { @@ -87,14 +97,6 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l } } - logger.Info("Analyzing backup contents...") - availableCategories, err := AnalyzeBackupCategories(prepared.ArchivePath, logger) - if err != nil { - logger.Warning("Could not analyze categories: %v", err) - logger.Info("Falling back to full restore mode") - return runFullRestoreWithUI(ctx, ui, candidate, prepared, destRoot, logger, cfg.DryRun) - } - var ( mode RestoreMode selectedCategories []Category @@ -127,7 +129,7 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l } } - plan := PlanRestore(candidate.Manifest, selectedCategories, systemType, mode) + plan := PlanRestore(decisionInfo.ClusterPayload, selectedCategories, systemType, mode) if plan.SystemType == SystemTypePBS && (plan.HasCategoryID("pbs_host") || @@ -145,9 +147,8 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l logger.Info("PBS restore behavior: %s", behavior.DisplayName()) } - clusterBackup := strings.EqualFold(strings.TrimSpace(candidate.Manifest.ClusterMode), "cluster") - if plan.NeedsClusterRestore && clusterBackup { - logger.Info("Backup marked as cluster node; enabling guarded restore options for pve_cluster") + if plan.NeedsClusterRestore && plan.ClusterBackup { + logger.Info("Cluster payload detected in backup; enabling guarded restore options for pve_cluster") choice, promptErr := ui.SelectClusterRestoreMode(ctx) if promptErr != nil { return promptErr @@ -168,8 +169,8 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l if plan.HasCategoryID("pve_access_control") || plan.HasCategoryID("pbs_access_control") { currentHost, hostErr := os.Hostname() - if hostErr == nil && strings.TrimSpace(candidate.Manifest.Hostname) != "" && strings.TrimSpace(currentHost) != "" { - backupHost := strings.TrimSpace(candidate.Manifest.Hostname) + if hostErr == nil && strings.TrimSpace(decisionInfo.BackupHostname) != "" && strings.TrimSpace(currentHost) != "" { + backupHost := strings.TrimSpace(decisionInfo.BackupHostname) if !strings.EqualFold(strings.TrimSpace(currentHost), backupHost) { logger.Warning("Access control/TFA: backup hostname=%s current hostname=%s; WebAuthn users may require re-enrollment if the UI origin (FQDN/port) changes", backupHost, currentHost) } diff --git a/internal/orchestrator/restore_workflow_ui_helpers_test.go b/internal/orchestrator/restore_workflow_ui_helpers_test.go index f2a03724..ca1a8fe 100644 --- a/internal/orchestrator/restore_workflow_ui_helpers_test.go +++ b/internal/orchestrator/restore_workflow_ui_helpers_test.go @@ -35,6 +35,10 @@ type fakeRestoreWorkflowUI struct { confirmActionErr error repairNICNamesErr error networkCommitErr error + + confirmCompatibilityCalls int + clusterRestoreModeCalls int + lastCompatibilityWarning error } func (f *fakeRestoreWorkflowUI) RunTask(ctx context.Context, title, initialMessage string, run func(ctx context.Context, report ProgressReporter) error) error { @@ -84,10 +88,13 @@ func (f *fakeRestoreWorkflowUI) ConfirmRestore(ctx context.Context) (bool, error } func (f *fakeRestoreWorkflowUI) ConfirmCompatibility(ctx context.Context, warning error) (bool, error) { + f.confirmCompatibilityCalls++ + f.lastCompatibilityWarning = warning return f.confirmCompatible, f.confirmCompatibleErr } func (f *fakeRestoreWorkflowUI) SelectClusterRestoreMode(ctx context.Context) (ClusterRestoreMode, error) { + f.clusterRestoreModeCalls++ return f.clusterMode, f.clusterModeErr } diff --git a/internal/orchestrator/selective.go b/internal/orchestrator/selective.go index 4b05ea2..5b0b8af 100644 --- a/internal/orchestrator/selective.go +++ b/internal/orchestrator/selective.go @@ -27,40 +27,8 @@ type SelectiveRestoreConfig struct { // AnalyzeBackupCategories detects which categories are available in the backup func AnalyzeBackupCategories(archivePath string, logger *logging.Logger) (categories []Category, err error) { - done := logging.DebugStart(logger, "analyze backup categories", "archive=%s", archivePath) - defer func() { done(err) }() - logger.Info("Analyzing backup categories...") - - // Open the archive and read all entry names - file, err := restoreFS.Open(archivePath) - if err != nil { - return nil, fmt.Errorf("open archive: %w", err) - } - defer file.Close() - - // Create appropriate reader based on compression - reader, err := createDecompressionReader(context.Background(), file, archivePath) - if err != nil { - return nil, err - } - defer func() { - if closer, ok := reader.(interface{ Close() error }); ok { - closer.Close() - } - }() - - tarReader := tar.NewReader(reader) - - archivePaths := collectArchivePaths(tarReader) - logger.Debug("Found %d entries in archive", len(archivePaths)) - - availableCategories := AnalyzeArchivePaths(archivePaths, GetAllCategories()) - for _, cat := range availableCategories { - logger.Debug("Category available: %s (%s)", cat.ID, cat.Name) - } - - logger.Info("Detected %d available categories", len(availableCategories)) - return availableCategories, nil + availableCategories, _, err := AnalyzeRestoreArchive(archivePath, logger) + return availableCategories, err } // AnalyzeArchivePaths determines available categories from the provided archive entries. From 1f45f1b41c5a6f95a977acae886df9c946eaf664 Mon Sep 17 00:00:00 2001 From: tis24dev Date: Mon, 9 Mar 2026 15:26:44 +0100 Subject: [PATCH 08/18] Enforce staged backup checksum verification in restore/decrypt flows Require checksum verification for staged backup archives before any restore or decrypt operation proceeds. Add strict SHA256 parsing and normalization, enforce checksum source agreement between sidecar and manifest metadata, reject raw backup candidates with no verifiable integrity source, centralize shared prepare-path verification for bundle and raw inputs, preserve source artifact checksum separately from the prepared plain archive checksum, and update regression tests to cover missing, mismatched, and legacy checksum scenarios. --- internal/backup/checksum.go | 38 ++++- internal/backup/checksum_legacy_test.go | 7 +- internal/backup/checksum_test.go | 5 +- internal/orchestrator/backup_sources.go | 18 ++- internal/orchestrator/backup_sources_test.go | 15 +- internal/orchestrator/decrypt.go | 9 +- internal/orchestrator/decrypt_integrity.go | 78 ++++++++++ .../orchestrator/decrypt_integrity_test.go | 131 +++++++++++++++++ .../decrypt_integrity_test_helpers_test.go | 15 ++ .../orchestrator/decrypt_prepare_common.go | 133 ++++++++++++++++++ internal/orchestrator/decrypt_test.go | 30 ++-- internal/orchestrator/decrypt_tui.go | 106 +------------- internal/orchestrator/decrypt_tui_test.go | 2 +- .../orchestrator/decrypt_workflow_test.go | 15 +- internal/orchestrator/decrypt_workflow_ui.go | 110 +-------------- 15 files changed, 454 insertions(+), 258 deletions(-) create mode 100644 internal/orchestrator/decrypt_integrity.go create mode 100644 internal/orchestrator/decrypt_integrity_test.go create mode 100644 internal/orchestrator/decrypt_integrity_test_helpers_test.go create mode 100644 internal/orchestrator/decrypt_prepare_common.go diff --git a/internal/backup/checksum.go b/internal/backup/checksum.go index fc917ad..0e08b39 100644 --- a/internal/backup/checksum.go +++ b/internal/backup/checksum.go @@ -35,6 +35,30 @@ type Manifest struct { ClusterMode string `json:"cluster_mode,omitempty"` } +// NormalizeChecksum validates and normalizes a SHA256 checksum string. +func NormalizeChecksum(value string) (string, error) { + checksum := strings.ToLower(strings.TrimSpace(value)) + if checksum == "" { + return "", fmt.Errorf("checksum is empty") + } + if len(checksum) != sha256.Size*2 { + return "", fmt.Errorf("checksum must be %d hex characters, got %d", sha256.Size*2, len(checksum)) + } + if _, err := hex.DecodeString(checksum); err != nil { + return "", fmt.Errorf("checksum is not valid hex: %w", err) + } + return checksum, nil +} + +// ParseChecksumData extracts a SHA256 checksum from checksum file contents. +func ParseChecksumData(data []byte) (string, error) { + fields := strings.Fields(string(data)) + if len(fields) == 0 { + return "", fmt.Errorf("checksum file is empty") + } + return NormalizeChecksum(fields[0]) +} + // GenerateChecksum calculates SHA256 checksum of a file func GenerateChecksum(ctx context.Context, logger *logging.Logger, filePath string) (string, error) { logger.Debug("Generating SHA256 checksum for: %s", filePath) @@ -105,16 +129,21 @@ func CreateManifest(ctx context.Context, logger *logging.Logger, manifest *Manif func VerifyChecksum(ctx context.Context, logger *logging.Logger, filePath, expectedChecksum string) (bool, error) { logger.Debug("Verifying checksum for: %s", filePath) + normalizedExpected, err := NormalizeChecksum(expectedChecksum) + if err != nil { + return false, fmt.Errorf("invalid expected checksum: %w", err) + } + actualChecksum, err := GenerateChecksum(ctx, logger, filePath) if err != nil { return false, fmt.Errorf("failed to generate checksum: %w", err) } - matches := actualChecksum == expectedChecksum + matches := actualChecksum == normalizedExpected if matches { logger.Debug("Checksum verification passed") } else { - logger.Warning("Checksum mismatch! Expected: %s, Got: %s", expectedChecksum, actualChecksum) + logger.Warning("Checksum mismatch! Expected: %s, Got: %s", normalizedExpected, actualChecksum) } return matches, nil @@ -205,9 +234,8 @@ func parseLegacyMetadata(scanner *bufio.Scanner, legacy *Manifest) { func loadLegacyChecksum(archivePath string, legacy *Manifest) { // Attempt to load checksum from legacy .sha256 file if shaData, err := os.ReadFile(archivePath + ".sha256"); err == nil { - fields := strings.Fields(string(shaData)) - if len(fields) > 0 { - legacy.SHA256 = fields[0] + if checksum, parseErr := ParseChecksumData(shaData); parseErr == nil { + legacy.SHA256 = checksum } } } diff --git a/internal/backup/checksum_legacy_test.go b/internal/backup/checksum_legacy_test.go index 1013ae8..c6738cf 100644 --- a/internal/backup/checksum_legacy_test.go +++ b/internal/backup/checksum_legacy_test.go @@ -28,7 +28,8 @@ func TestLoadLegacyManifestWithShaAndFallbackEncryption(t *testing.T) { t.Fatalf("write metadata: %v", err) } - shaLine := "deadbeef " + filepath.Base(archive) + "\n" + expectedSHA := strings.Repeat("a", 64) + shaLine := expectedSHA + " " + filepath.Base(archive) + "\n" if err := os.WriteFile(archive+".sha256", []byte(shaLine), 0o640); err != nil { t.Fatalf("write sha256: %v", err) } @@ -50,8 +51,8 @@ func TestLoadLegacyManifestWithShaAndFallbackEncryption(t *testing.T) { if m.EncryptionMode != "plain" { t.Fatalf("expected fallback encryption mode plain, got %s", m.EncryptionMode) } - if m.SHA256 != "deadbeef" { - t.Fatalf("expected sha256 deadbeef, got %s", m.SHA256) + if m.SHA256 != expectedSHA { + t.Fatalf("expected sha256 %s, got %s", expectedSHA, m.SHA256) } if time.Since(m.CreatedAt) > time.Minute { t.Fatalf("unexpected CreatedAt too old: %v", m.CreatedAt) diff --git a/internal/backup/checksum_test.go b/internal/backup/checksum_test.go index bc12bd0..676506f 100644 --- a/internal/backup/checksum_test.go +++ b/internal/backup/checksum_test.go @@ -145,8 +145,9 @@ ENCRYPTION_MODE=age if err := os.WriteFile(metadataPath, []byte(metadata), 0644); err != nil { t.Fatalf("failed to write metadata: %v", err) } + expectedSHA := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" shaPath := archivePath + ".sha256" - if err := os.WriteFile(shaPath, []byte("deadbeef "+filepath.Base(archivePath)), 0644); err != nil { + if err := os.WriteFile(shaPath, []byte(expectedSHA+" "+filepath.Base(archivePath)), 0644); err != nil { t.Fatalf("failed to write sha file: %v", err) } @@ -163,7 +164,7 @@ ENCRYPTION_MODE=age if manifest.Hostname != "legacy-host" || manifest.ScriptVersion != "legacy-1.0" { t.Fatalf("legacy metadata not parsed correctly: %+v", manifest) } - if manifest.SHA256 != "deadbeef" { + if manifest.SHA256 != expectedSHA { t.Fatalf("expected SHA256 from sidecar, got %q", manifest.SHA256) } if manifest.EncryptionMode != "age" { diff --git a/internal/orchestrator/backup_sources.go b/internal/orchestrator/backup_sources.go index aac895e..5ba7ac9 100644 --- a/internal/orchestrator/backup_sources.go +++ b/internal/orchestrator/backup_sources.go @@ -130,6 +130,7 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s emptyEntries := 0 nonCandidateEntries := 0 manifestErrors := 0 + integrityMissing := 0 logDebug(logger, "Cloud (rclone): scanned %d entries from rclone lsf output", totalEntries) snapshot := make(map[string]struct{}, len(lines)) @@ -259,6 +260,11 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s if strings.TrimSpace(displayBase) == "" { displayBase = filepath.Base(baseNameFromRemoteRef(item.remoteArchive)) } + if item.remoteChecksum == "" && strings.TrimSpace(manifest.SHA256) == "" { + integrityMissing++ + logWarning(logger, "Skipping rclone backup %s: no checksum verification available", item.filename) + continue + } candidates = append(candidates, &decryptCandidate{ Manifest: manifest, Source: sourceRaw, @@ -292,11 +298,12 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s logging.DebugStep( logger, "discover rclone backups", - "summary entries=%d empty=%d non_candidate=%d manifest_errors=%d accepted=%d elapsed=%s", + "summary entries=%d empty=%d non_candidate=%d manifest_errors=%d integrity_missing=%d accepted=%d elapsed=%s", totalEntries, emptyEntries, nonCandidateEntries, manifestErrors, + integrityMissing, len(candidates), time.Since(start), ) @@ -336,6 +343,7 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [ metadataMissingArchive := 0 metadataManifestErrors := 0 checksumMissing := 0 + integrityUnavailable := 0 for _, entry := range entries { if entry.IsDir() { @@ -394,9 +402,10 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [ } logging.DebugStep(logger, "discover backup candidates", "raw candidate accepted: %s created_at=%s", name, manifest.CreatedAt.Format(time.RFC3339)) - // If checksum is missing from both file and manifest, warn user if !hasChecksum && manifest.SHA256 == "" { - logWarning(logger, "Backup %s has no checksum verification available", baseName) + integrityUnavailable++ + logWarning(logger, "Skipping backup %s: no checksum verification available", baseName) + continue } rawBases[baseName] = struct{}{} @@ -418,7 +427,7 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [ logging.DebugStep( logger, "discover backup candidates", - "summary entries=%d files=%d dirs=%d bundles=%d bundle_manifest_errors=%d metadata=%d metadata_duplicate=%d metadata_missing_archive=%d metadata_manifest_errors=%d checksum_missing=%d candidates=%d", + "summary entries=%d files=%d dirs=%d bundles=%d bundle_manifest_errors=%d metadata=%d metadata_duplicate=%d metadata_missing_archive=%d metadata_manifest_errors=%d checksum_missing=%d integrity_unavailable=%d candidates=%d", len(entries), filesSeen, dirsSkipped, @@ -429,6 +438,7 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [ metadataMissingArchive, metadataManifestErrors, checksumMissing, + integrityUnavailable, len(candidates), ) return candidates, nil diff --git a/internal/orchestrator/backup_sources_test.go b/internal/orchestrator/backup_sources_test.go index 30b9f0d..f943d10 100644 --- a/internal/orchestrator/backup_sources_test.go +++ b/internal/orchestrator/backup_sources_test.go @@ -235,6 +235,7 @@ func TestDiscoverRcloneBackups_IncludesRawMetadata(t *testing.T) { ProxmoxVersion: "8.1", CreatedAt: time.Date(2025, 12, 5, 12, 0, 0, 0, time.UTC), EncryptionMode: "none", + SHA256: checksumHexForBytes([]byte("node-backup-20251205")), } metaBytes, err := json.Marshal(&manifest) if err != nil { @@ -317,6 +318,7 @@ func TestDiscoverRcloneBackups_MixedCandidatesSortedByCreatedAt(t *testing.T) { CreatedAt: time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), EncryptionMode: "none", ProxmoxType: "pve", + SHA256: checksumHexForBytes([]byte("x")), } rawNewestData, _ := json.Marshal(&rawNewest) if err := os.WriteFile(rawNewestMeta, rawNewestData, 0o600); err != nil { @@ -354,6 +356,7 @@ func TestDiscoverRcloneBackups_MixedCandidatesSortedByCreatedAt(t *testing.T) { CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), EncryptionMode: "none", ProxmoxType: "pve", + SHA256: checksumHexForBytes([]byte("x")), } rawOldData, _ := json.Marshal(&rawOld) if err := os.WriteFile(rawOldMeta, rawOldData, 0o600); err != nil { @@ -629,7 +632,7 @@ esac return manifest, cleanup } -func TestDiscoverBackupCandidates_NoLoggerStillCollectsRawArtifacts(t *testing.T) { +func TestDiscoverBackupCandidates_NoLoggerSkipsRawArtifactsWithoutChecksumVerification(t *testing.T) { tmpDir := t.TempDir() archivePath := filepath.Join(tmpDir, "config.tar.xz") if err := os.WriteFile(archivePath, []byte("dummy"), 0o600); err != nil { @@ -656,13 +659,7 @@ func TestDiscoverBackupCandidates_NoLoggerStillCollectsRawArtifacts(t *testing.T if err != nil { t.Fatalf("discoverBackupCandidates() error = %v", err) } - if len(candidates) != 1 { - t.Fatalf("discoverBackupCandidates() returned %d candidates; want 1", len(candidates)) - } - if candidates[0].RawArchivePath != archivePath { - t.Fatalf("RawArchivePath = %q; want %q", candidates[0].RawArchivePath, archivePath) - } - if candidates[0].RawChecksumPath != "" { - t.Fatalf("RawChecksumPath should be empty when checksum missing; got %q", candidates[0].RawChecksumPath) + if len(candidates) != 0 { + t.Fatalf("discoverBackupCandidates() returned %d candidates; want 0", len(candidates)) } } diff --git a/internal/orchestrator/decrypt.go b/internal/orchestrator/decrypt.go index c995530..2d09922 100644 --- a/internal/orchestrator/decrypt.go +++ b/internal/orchestrator/decrypt.go @@ -54,10 +54,11 @@ type stagedFiles struct { } type preparedBundle struct { - ArchivePath string - Manifest backup.Manifest - Checksum string - cleanup func() + ArchivePath string + Manifest backup.Manifest + Checksum string + SourceChecksum string + cleanup func() } func (p *preparedBundle) Cleanup() { diff --git a/internal/orchestrator/decrypt_integrity.go b/internal/orchestrator/decrypt_integrity.go new file mode 100644 index 0000000..87869c8 --- /dev/null +++ b/internal/orchestrator/decrypt_integrity.go @@ -0,0 +1,78 @@ +package orchestrator + +import ( + "context" + "fmt" + "strings" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/logging" +) + +type stagedIntegrityExpectation struct { + Checksum string + Source string +} + +func resolveStagedIntegrityExpectation(staged stagedFiles, manifest *backup.Manifest) (*stagedIntegrityExpectation, error) { + var ( + checksumFromFile string + checksumFromManifest string + ) + + if strings.TrimSpace(staged.ChecksumPath) != "" { + data, err := restoreFS.ReadFile(staged.ChecksumPath) + if err != nil { + return nil, fmt.Errorf("read checksum file: %w", err) + } + checksumFromFile, err = backup.ParseChecksumData(data) + if err != nil { + return nil, fmt.Errorf("parse checksum file %s: %w", staged.ChecksumPath, err) + } + } + + if manifest != nil && strings.TrimSpace(manifest.SHA256) != "" { + normalized, err := backup.NormalizeChecksum(manifest.SHA256) + if err != nil { + return nil, fmt.Errorf("parse manifest checksum: %w", err) + } + checksumFromManifest = normalized + } + + switch { + case checksumFromFile != "" && checksumFromManifest != "" && checksumFromFile != checksumFromManifest: + return nil, fmt.Errorf("checksum mismatch between checksum file and manifest") + case checksumFromFile != "" && checksumFromManifest != "": + return &stagedIntegrityExpectation{Checksum: checksumFromFile, Source: "checksum file and manifest"}, nil + case checksumFromFile != "": + return &stagedIntegrityExpectation{Checksum: checksumFromFile, Source: "checksum file"}, nil + case checksumFromManifest != "": + return &stagedIntegrityExpectation{Checksum: checksumFromManifest, Source: "manifest"}, nil + default: + return nil, fmt.Errorf("backup has no checksum verification available") + } +} + +func verifyStagedArchiveIntegrity(ctx context.Context, logger *logging.Logger, staged stagedFiles, manifest *backup.Manifest) (string, error) { + if staged.ArchivePath == "" { + return "", fmt.Errorf("staged archive path is empty") + } + if logger == nil { + logger = logging.GetDefaultLogger() + } + + expectation, err := resolveStagedIntegrityExpectation(staged, manifest) + if err != nil { + return "", err + } + + logger.Info("Verifying staged archive integrity using %s", expectation.Source) + ok, err := backup.VerifyChecksum(ctx, logger, staged.ArchivePath, expectation.Checksum) + if err != nil { + return "", fmt.Errorf("verify staged archive: %w", err) + } + if !ok { + return "", fmt.Errorf("staged archive checksum mismatch") + } + return expectation.Checksum, nil +} diff --git a/internal/orchestrator/decrypt_integrity_test.go b/internal/orchestrator/decrypt_integrity_test.go new file mode 100644 index 0000000..1fcfa23 --- /dev/null +++ b/internal/orchestrator/decrypt_integrity_test.go @@ -0,0 +1,131 @@ +package orchestrator + +import ( + "bufio" + "context" + "encoding/json" + "os" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +func TestResolveStagedIntegrityExpectation_RejectsConflictingSources(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + checksumPath := dir + "/backup.tar.sha256" + if err := os.WriteFile(checksumPath, checksumLineForBytes("backup.tar", []byte("archive")), 0o640); err != nil { + t.Fatalf("write checksum: %v", err) + } + + _, err := resolveStagedIntegrityExpectation(stagedFiles{ChecksumPath: checksumPath}, &backup.Manifest{ + SHA256: checksumHexForBytes([]byte("different")), + }) + if err == nil { + t.Fatalf("expected conflict error") + } + if !strings.Contains(err.Error(), "checksum mismatch") { + t.Fatalf("expected checksum mismatch error, got %v", err) + } +} + +func TestPreparePlainBundle_RejectsMissingChecksumVerification(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + archivePath := dir + "/backup.tar" + if err := os.WriteFile(archivePath, []byte("archive"), 0o640); err != nil { + t.Fatalf("write archive: %v", err) + } + metaPath := archivePath + ".metadata" + manifest := &backup.Manifest{ + ArchivePath: archivePath, + CreatedAt: time.Now(), + EncryptionMode: "none", + } + data, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("marshal metadata: %v", err) + } + if err := os.WriteFile(metaPath, data, 0o640); err != nil { + t.Fatalf("write metadata: %v", err) + } + + cand := &decryptCandidate{ + Manifest: manifest, + Source: sourceRaw, + RawArchivePath: archivePath, + RawMetadataPath: metaPath, + DisplayBase: "backup.tar", + } + + reader := bufio.NewReader(strings.NewReader("")) + logger := logging.New(types.LogLevelError, false) + + _, err = preparePlainBundle(context.Background(), reader, cand, "", logger) + if err == nil { + t.Fatalf("expected missing checksum verification error") + } + if !strings.Contains(err.Error(), "no checksum verification available") { + t.Fatalf("expected checksum availability error, got %v", err) + } +} + +func TestPreparePlainBundle_RejectsChecksumMismatch(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + archiveData := []byte("archive") + archivePath := dir + "/backup.tar" + if err := os.WriteFile(archivePath, archiveData, 0o640); err != nil { + t.Fatalf("write archive: %v", err) + } + metaPath := archivePath + ".metadata" + manifest := &backup.Manifest{ + ArchivePath: archivePath, + CreatedAt: time.Now(), + EncryptionMode: "none", + } + data, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("marshal metadata: %v", err) + } + if err := os.WriteFile(metaPath, data, 0o640); err != nil { + t.Fatalf("write metadata: %v", err) + } + checksumPath := archivePath + ".sha256" + if err := os.WriteFile(checksumPath, checksumLineForBytes("backup.tar", []byte("tampered")), 0o640); err != nil { + t.Fatalf("write checksum: %v", err) + } + + cand := &decryptCandidate{ + Manifest: manifest, + Source: sourceRaw, + RawArchivePath: archivePath, + RawMetadataPath: metaPath, + RawChecksumPath: checksumPath, + DisplayBase: "backup.tar", + } + + reader := bufio.NewReader(strings.NewReader("")) + logger := logging.New(types.LogLevelError, false) + + _, err = preparePlainBundle(context.Background(), reader, cand, "", logger) + if err == nil { + t.Fatalf("expected checksum mismatch error") + } + if !strings.Contains(err.Error(), "checksum mismatch") { + t.Fatalf("expected checksum mismatch error, got %v", err) + } +} diff --git a/internal/orchestrator/decrypt_integrity_test_helpers_test.go b/internal/orchestrator/decrypt_integrity_test_helpers_test.go new file mode 100644 index 0000000..f8f022a --- /dev/null +++ b/internal/orchestrator/decrypt_integrity_test_helpers_test.go @@ -0,0 +1,15 @@ +package orchestrator + +import ( + "crypto/sha256" + "fmt" +) + +func checksumHexForBytes(data []byte) string { + sum := sha256.Sum256(data) + return fmt.Sprintf("%x", sum) +} + +func checksumLineForBytes(filename string, data []byte) []byte { + return []byte(fmt.Sprintf("%s %s", checksumHexForBytes(data), filename)) +} diff --git a/internal/orchestrator/decrypt_prepare_common.go b/internal/orchestrator/decrypt_prepare_common.go new file mode 100644 index 0000000..0b21436 --- /dev/null +++ b/internal/orchestrator/decrypt_prepare_common.go @@ -0,0 +1,133 @@ +package orchestrator + +import ( + "context" + "fmt" + "path/filepath" + "strings" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/logging" +) + +type archiveDecryptFunc func(ctx context.Context, encryptedPath, outputPath, displayName string) error + +func preparePlainBundleCommon(ctx context.Context, cand *decryptCandidate, version string, logger *logging.Logger, decryptArchive archiveDecryptFunc) (bundle *preparedBundle, err error) { + if cand == nil || cand.Manifest == nil { + return nil, fmt.Errorf("invalid backup candidate") + } + if logger == nil { + logger = logging.GetDefaultLogger() + } + + var rcloneCleanup func() + if cand.IsRclone && cand.Source == sourceBundle { + logger.Debug("Detected rclone backup, downloading...") + localPath, cleanup, err := downloadRcloneBackup(ctx, cand.BundlePath, logger) + if err != nil { + return nil, fmt.Errorf("failed to download rclone backup: %w", err) + } + rcloneCleanup = cleanup + cand.BundlePath = localPath + } + + tempRoot := filepath.Join("/tmp", "proxsave") + if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil { + if rcloneCleanup != nil { + rcloneCleanup() + } + return nil, fmt.Errorf("create temp root: %w", err) + } + + workDir, err := restoreFS.MkdirTemp(tempRoot, "proxmox-decrypt-*") + if err != nil { + if rcloneCleanup != nil { + rcloneCleanup() + } + return nil, fmt.Errorf("create temp dir: %w", err) + } + + cleanup := func() { + _ = restoreFS.RemoveAll(workDir) + if rcloneCleanup != nil { + rcloneCleanup() + } + } + + var staged stagedFiles + switch cand.Source { + case sourceBundle: + logger.Info("Extracting bundle %s", filepath.Base(cand.BundlePath)) + staged, err = extractBundleToWorkdirWithLogger(cand.BundlePath, workDir, logger) + case sourceRaw: + logger.Info("Staging raw artifacts for %s", filepath.Base(cand.RawArchivePath)) + staged, err = copyRawArtifactsToWorkdirWithLogger(ctx, cand, workDir, logger) + default: + err = fmt.Errorf("unsupported candidate source") + } + if err != nil { + cleanup() + return nil, err + } + + sourceChecksum, err := verifyStagedArchiveIntegrity(ctx, logger, staged, cand.Manifest) + if err != nil { + cleanup() + return nil, err + } + + manifestCopy := *cand.Manifest + currentEncryption := strings.ToLower(manifestCopy.EncryptionMode) + logger.Info("Preparing archive %s for decryption (mode: %s)", manifestCopy.ArchivePath, statusFromManifest(&manifestCopy)) + + plainArchiveName := strings.TrimSuffix(filepath.Base(staged.ArchivePath), ".age") + plainArchivePath := filepath.Join(workDir, plainArchiveName) + + if currentEncryption == "age" { + if decryptArchive == nil { + cleanup() + return nil, fmt.Errorf("decrypt function not available") + } + displayName := cand.DisplayBase + if strings.TrimSpace(displayName) == "" { + displayName = filepath.Base(manifestCopy.ArchivePath) + } + if err := decryptArchive(ctx, staged.ArchivePath, plainArchivePath, displayName); err != nil { + cleanup() + return nil, err + } + } else if staged.ArchivePath != plainArchivePath { + if err := copyFile(restoreFS, staged.ArchivePath, plainArchivePath); err != nil { + cleanup() + return nil, fmt.Errorf("copy archive: %w", err) + } + } + + archiveInfo, err := restoreFS.Stat(plainArchivePath) + if err != nil { + cleanup() + return nil, fmt.Errorf("stat decrypted archive: %w", err) + } + + plainChecksum, err := backup.GenerateChecksum(ctx, logger, plainArchivePath) + if err != nil { + cleanup() + return nil, fmt.Errorf("generate checksum: %w", err) + } + + manifestCopy.ArchivePath = plainArchivePath + manifestCopy.ArchiveSize = archiveInfo.Size() + manifestCopy.SHA256 = plainChecksum + manifestCopy.EncryptionMode = "none" + if version != "" { + manifestCopy.ScriptVersion = version + } + + return &preparedBundle{ + ArchivePath: plainArchivePath, + Manifest: manifestCopy, + Checksum: plainChecksum, + SourceChecksum: sourceChecksum, + cleanup: cleanup, + }, nil +} diff --git a/internal/orchestrator/decrypt_test.go b/internal/orchestrator/decrypt_test.go index 84fcee1..dabad3e 100644 --- a/internal/orchestrator/decrypt_test.go +++ b/internal/orchestrator/decrypt_test.go @@ -1575,14 +1575,15 @@ func TestPreparePlainBundle_SourceBundleSuccess(t *testing.T) { dir := t.TempDir() // Create bundle with required files + archiveData := []byte("archive data") manifestData, _ := json.Marshal(&backup.Manifest{ ArchivePath: filepath.Join(dir, "archive.tar.xz"), EncryptionMode: "none", }) bundlePath := createTestBundle(t, []bundleEntry{ - {name: "archive.tar.xz", data: []byte("archive data")}, + {name: "archive.tar.xz", data: archiveData}, {name: "backup.metadata", data: manifestData}, - {name: "backup.sha256", data: []byte("abc123 archive.tar.xz")}, + {name: "backup.sha256", data: checksumLineForBytes("archive.tar.xz", archiveData)}, }) cand := &decryptCandidate{ @@ -2819,7 +2820,7 @@ func TestPreparePlainBundle_CopyFileSamePath(t *testing.T) { t.Fatalf("write metadata: %v", err) } checksumPath := archivePath + ".sha256" - if err := os.WriteFile(checksumPath, []byte("abc123 backup.tar.xz"), 0o644); err != nil { + if err := os.WriteFile(checksumPath, checksumLineForBytes("backup.tar.xz", []byte("archive content")), 0o644); err != nil { t.Fatalf("write checksum: %v", err) } @@ -2890,7 +2891,7 @@ func TestPreparePlainBundle_AgeDecryptionWithRclone(t *testing.T) { tw.Write(manifestData) // Add checksum - checksumData := []byte("abc123 backup.tar.xz.age") + checksumData := checksumLineForBytes("backup.tar.xz.age", archiveContent) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600}) tw.Write(checksumData) @@ -3017,7 +3018,7 @@ func TestPreparePlainBundle_SourceBundleAdditional(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(manifestData)), Mode: 0o600}) tw.Write(manifestData) - checksumData := []byte("abc123 backup.tar.xz") + checksumData := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600}) tw.Write(checksumData) @@ -3393,7 +3394,7 @@ func TestExtractBundleToWorkdir_OpenFileErrorOnExtract(t *testing.T) { } // Add checksum - checksum := []byte("checksum backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) if err := tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}); err != nil { t.Fatalf("write checksum header: %v", err) } @@ -3608,7 +3609,7 @@ func TestSelectDecryptCandidate_RequireEncryptedAllPlain(t *testing.T) { tw.Write(metaJSON) // Add checksum - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -3720,7 +3721,7 @@ exit 1 tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -3771,7 +3772,7 @@ func TestPreparePlainBundle_StatErrorAfterExtract(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -3869,8 +3870,9 @@ func TestPreparePlainBundle_MkdirTempErrorWithRcloneCleanup(t *testing.T) { metaJSON, _ := json.Marshal(backup.Manifest{EncryptionMode: "none", ArchivePath: "backup.tar.xz"}) tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: 5, Mode: 0o640}) - tw.Write([]byte("hash\n")) + checksum := checksumLineForBytes("backup.tar.xz", archiveData) + tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) + tw.Write(checksum) tw.Close() bundleFile.Close() @@ -4066,7 +4068,7 @@ func TestPreparePlainBundle_CopyFileError(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -4175,7 +4177,7 @@ func TestPreparePlainBundle_StatErrorOnPlainArchive(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -4311,7 +4313,7 @@ func TestPreparePlainBundle_GenerateChecksumErrorPath(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() diff --git a/internal/orchestrator/decrypt_tui.go b/internal/orchestrator/decrypt_tui.go index 655f7f5..7602d8a 100644 --- a/internal/orchestrator/decrypt_tui.go +++ b/internal/orchestrator/decrypt_tui.go @@ -233,109 +233,9 @@ func promptNewPathInput(defaultPath, configPath, buildSig string) (string, error } func preparePlainBundleTUI(ctx context.Context, cand *decryptCandidate, version string, logger *logging.Logger, configPath, buildSig string) (*preparedBundle, error) { - if cand == nil || cand.Manifest == nil { - return nil, fmt.Errorf("invalid backup candidate") - } - - // If this is an rclone-backed bundle, download it first into the local temp area. - var rcloneCleanup func() - if cand.IsRclone && cand.Source == sourceBundle { - logger.Debug("Detected rclone backup, downloading for TUI workflow...") - localPath, cleanupFn, err := downloadRcloneBackup(ctx, cand.BundlePath, logger) - if err != nil { - return nil, fmt.Errorf("failed to download rclone backup: %w", err) - } - rcloneCleanup = cleanupFn - cand.BundlePath = localPath - } - - tempRoot := filepath.Join("/tmp", "proxsave") - if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil { - if rcloneCleanup != nil { - rcloneCleanup() - } - return nil, fmt.Errorf("create temp root: %w", err) - } - workDir, err := restoreFS.MkdirTemp(tempRoot, "proxmox-decrypt-*") - if err != nil { - if rcloneCleanup != nil { - rcloneCleanup() - } - return nil, fmt.Errorf("create temp dir: %w", err) - } - cleanup := func() { - _ = restoreFS.RemoveAll(workDir) - if rcloneCleanup != nil { - rcloneCleanup() - } - } - - var staged stagedFiles - switch cand.Source { - case sourceBundle: - logger.Debug("Extracting bundle %s", filepath.Base(cand.BundlePath)) - staged, err = extractBundleToWorkdirWithLogger(cand.BundlePath, workDir, logger) - case sourceRaw: - logger.Debug("Staging raw artifacts for %s", filepath.Base(cand.RawArchivePath)) - staged, err = copyRawArtifactsToWorkdirWithLogger(ctx, cand, workDir, logger) - default: - err = fmt.Errorf("unsupported candidate source") - } - if err != nil { - cleanup() - return nil, err - } - - manifestCopy := *cand.Manifest - currentEncryption := strings.ToLower(manifestCopy.EncryptionMode) - - logger.Debug("Preparing archive %s for decryption (mode: %s)", filepath.Base(manifestCopy.ArchivePath), statusFromManifest(&manifestCopy)) - - plainArchiveName := strings.TrimSuffix(filepath.Base(staged.ArchivePath), ".age") - plainArchivePath := filepath.Join(workDir, plainArchiveName) - - if currentEncryption == "age" { - displayName := cand.DisplayBase - if displayName == "" { - displayName = filepath.Base(manifestCopy.ArchivePath) - } - if err := decryptArchiveWithTUIPrompts(ctx, staged.ArchivePath, plainArchivePath, displayName, configPath, buildSig, logger); err != nil { - cleanup() - return nil, err - } - } else if staged.ArchivePath != plainArchivePath { - if err := copyFile(restoreFS, staged.ArchivePath, plainArchivePath); err != nil { - cleanup() - return nil, fmt.Errorf("copy archive: %w", err) - } - } - - archiveInfo, err := restoreFS.Stat(plainArchivePath) - if err != nil { - cleanup() - return nil, fmt.Errorf("stat decrypted archive: %w", err) - } - - checksum, err := backup.GenerateChecksum(ctx, logger, plainArchivePath) - if err != nil { - cleanup() - return nil, fmt.Errorf("generate checksum: %w", err) - } - - manifestCopy.ArchivePath = plainArchivePath - manifestCopy.ArchiveSize = archiveInfo.Size() - manifestCopy.SHA256 = checksum - manifestCopy.EncryptionMode = "none" - if version != "" { - manifestCopy.ScriptVersion = version - } - - return &preparedBundle{ - ArchivePath: plainArchivePath, - Manifest: manifestCopy, - Checksum: checksum, - cleanup: cleanup, - }, nil + return preparePlainBundleCommon(ctx, cand, version, logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { + return decryptArchiveWithTUIPrompts(ctx, encryptedPath, outputPath, displayName, configPath, buildSig, logger) + }) } func decryptArchiveWithTUIPrompts(ctx context.Context, encryptedPath, outputPath, displayName, configPath, buildSig string, logger *logging.Logger) error { diff --git a/internal/orchestrator/decrypt_tui_test.go b/internal/orchestrator/decrypt_tui_test.go index 6404b76..f08f9a0 100644 --- a/internal/orchestrator/decrypt_tui_test.go +++ b/internal/orchestrator/decrypt_tui_test.go @@ -204,7 +204,7 @@ func TestPreparePlainBundleTUICopiesRawArtifacts(t *testing.T) { if err := os.WriteFile(rawMetadata, []byte(`{"manifest":true}`), 0o640); err != nil { t.Fatalf("write metadata: %v", err) } - if err := os.WriteFile(rawChecksum, []byte("checksum backup.tar\n"), 0o640); err != nil { + if err := os.WriteFile(rawChecksum, checksumLineForBytes("backup.tar", []byte("payload-data")), 0o640); err != nil { t.Fatalf("write checksum: %v", err) } diff --git a/internal/orchestrator/decrypt_workflow_test.go b/internal/orchestrator/decrypt_workflow_test.go index 3f68fa4..e145b7d 100644 --- a/internal/orchestrator/decrypt_workflow_test.go +++ b/internal/orchestrator/decrypt_workflow_test.go @@ -18,7 +18,8 @@ import ( func writeRawBackup(t *testing.T, dir, name string) *backup.Manifest { t.Helper() archive := filepath.Join(dir, name) - if err := os.WriteFile(archive, []byte("data"), 0o640); err != nil { + archiveData := []byte("data") + if err := os.WriteFile(archive, archiveData, 0o640); err != nil { t.Fatalf("write archive: %v", err) } manifest := &backup.Manifest{ @@ -31,7 +32,7 @@ func writeRawBackup(t *testing.T, dir, name string) *backup.Manifest { if err := os.WriteFile(manifestPath, data, 0o640); err != nil { t.Fatalf("write metadata: %v", err) } - if err := os.WriteFile(archive+".sha256", []byte("checksum file"), 0o640); err != nil { + if err := os.WriteFile(archive+".sha256", checksumLineForBytes(filepath.Base(archive), archiveData), 0o640); err != nil { t.Fatalf("write checksum: %v", err) } return manifest @@ -70,21 +71,23 @@ func TestRunDecryptWorkflow_BundleNotFound(t *testing.T) { func TestPreparePlainBundle_AllowsMissingRawChecksumSidecar(t *testing.T) { dir := t.TempDir() archive := filepath.Join(dir, "bad.bundle.tar") - if err := os.WriteFile(archive, []byte("data"), 0o640); err != nil { + archiveData := []byte("data") + if err := os.WriteFile(archive, archiveData, 0o640); err != nil { t.Fatalf("write archive: %v", err) } manifest := &backup.Manifest{ ArchivePath: archive, CreatedAt: time.Now(), Hostname: "host", + SHA256: checksumHexForBytes(archiveData), } metaPath := archive + ".metadata" data, _ := json.Marshal(manifest) if err := os.WriteFile(metaPath, data, 0o640); err != nil { t.Fatalf("write metadata: %v", err) } - // No checksum file: ProxSave should still allow restore/decrypt to proceed - // (it re-computes checksums on the staged/plain archive anyway). + // No checksum sidecar: restore/decrypt should still proceed when the manifest + // already carries the expected archive checksum. cand := &decryptCandidate{ Manifest: manifest, @@ -100,7 +103,7 @@ func TestPreparePlainBundle_AllowsMissingRawChecksumSidecar(t *testing.T) { t.Cleanup(func() { restoreFS = osFS{} }) if _, err := preparePlainBundle(context.Background(), reader, cand, "", logging.New(types.LogLevelInfo, false)); err != nil { - t.Fatalf("expected missing checksum to be tolerated, got error: %v", err) + t.Fatalf("expected manifest checksum to cover missing sidecar, got error: %v", err) } } diff --git a/internal/orchestrator/decrypt_workflow_ui.go b/internal/orchestrator/decrypt_workflow_ui.go index 2ae37d7..7de3f69 100644 --- a/internal/orchestrator/decrypt_workflow_ui.go +++ b/internal/orchestrator/decrypt_workflow_ui.go @@ -179,113 +179,9 @@ func preparePlainBundleWithUI(ctx context.Context, cand *decryptCandidate, versi }) (bundle *preparedBundle, err error) { done := logging.DebugStart(logger, "prepare plain bundle (ui)", "source=%v rclone=%v", cand.Source, cand.IsRclone) defer func() { done(err) }() - - if cand == nil || cand.Manifest == nil { - return nil, fmt.Errorf("invalid backup candidate") - } - - var rcloneCleanup func() - if cand.IsRclone && cand.Source == sourceBundle { - logger.Debug("Detected rclone backup, downloading...") - localPath, cleanup, err := downloadRcloneBackup(ctx, cand.BundlePath, logger) - if err != nil { - return nil, fmt.Errorf("failed to download rclone backup: %w", err) - } - rcloneCleanup = cleanup - cand.BundlePath = localPath - } - - tempRoot := filepath.Join("/tmp", "proxsave") - if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil { - if rcloneCleanup != nil { - rcloneCleanup() - } - return nil, fmt.Errorf("create temp root: %w", err) - } - - workDir, err := restoreFS.MkdirTemp(tempRoot, "proxmox-decrypt-*") - if err != nil { - if rcloneCleanup != nil { - rcloneCleanup() - } - return nil, fmt.Errorf("create temp dir: %w", err) - } - - cleanup := func() { - _ = restoreFS.RemoveAll(workDir) - if rcloneCleanup != nil { - rcloneCleanup() - } - } - - var staged stagedFiles - switch cand.Source { - case sourceBundle: - logger.Info("Extracting bundle %s", filepath.Base(cand.BundlePath)) - staged, err = extractBundleToWorkdirWithLogger(cand.BundlePath, workDir, logger) - case sourceRaw: - logger.Info("Staging raw artifacts for %s", filepath.Base(cand.RawArchivePath)) - staged, err = copyRawArtifactsToWorkdirWithLogger(ctx, cand, workDir, logger) - default: - err = fmt.Errorf("unsupported candidate source") - } - if err != nil { - cleanup() - return nil, err - } - - manifestCopy := *cand.Manifest - currentEncryption := strings.ToLower(manifestCopy.EncryptionMode) - logger.Info("Preparing archive %s for decryption (mode: %s)", manifestCopy.ArchivePath, statusFromManifest(&manifestCopy)) - - plainArchiveName := strings.TrimSuffix(filepath.Base(staged.ArchivePath), ".age") - plainArchivePath := filepath.Join(workDir, plainArchiveName) - - if currentEncryption == "age" { - displayName := cand.DisplayBase - if strings.TrimSpace(displayName) == "" { - displayName = filepath.Base(manifestCopy.ArchivePath) - } - if err := decryptArchiveWithSecretPrompt(ctx, staged.ArchivePath, plainArchivePath, displayName, ui.PromptDecryptSecret); err != nil { - cleanup() - return nil, err - } - } else { - if staged.ArchivePath != plainArchivePath { - if err := copyFile(restoreFS, staged.ArchivePath, plainArchivePath); err != nil { - cleanup() - return nil, fmt.Errorf("copy archive: %w", err) - } - } - } - - archiveInfo, err := restoreFS.Stat(plainArchivePath) - if err != nil { - cleanup() - return nil, fmt.Errorf("stat decrypted archive: %w", err) - } - - checksum, err := backup.GenerateChecksum(ctx, logger, plainArchivePath) - if err != nil { - cleanup() - return nil, fmt.Errorf("generate checksum: %w", err) - } - - manifestCopy.ArchivePath = plainArchivePath - manifestCopy.ArchiveSize = archiveInfo.Size() - manifestCopy.SHA256 = checksum - manifestCopy.EncryptionMode = "none" - if version != "" { - manifestCopy.ScriptVersion = version - } - - bundle = &preparedBundle{ - ArchivePath: plainArchivePath, - Manifest: manifestCopy, - Checksum: checksum, - cleanup: cleanup, - } - return bundle, nil + return preparePlainBundleCommon(ctx, cand, version, logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { + return decryptArchiveWithSecretPrompt(ctx, encryptedPath, outputPath, displayName, ui.PromptDecryptSecret) + }) } func runDecryptWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui DecryptWorkflowUI) (err error) { From e5992cc8d6be2c16794e25247d58aa52b62cf10e Mon Sep 17 00:00:00 2001 From: tis24dev Date: Mon, 9 Mar 2026 17:57:13 +0100 Subject: [PATCH 09/18] fix(orchestrator): preserve staged network symlinks Use Lstat in network staged apply so symlink entries are handled as symlinks instead of being dereferenced and copied as regular file contents. Reject symlinked staged network directories, extend the FS abstraction with Lstat, and add regression tests covering preserved symlinks and invalid symlinked stage roots. --- internal/orchestrator/deps.go | 8 +- internal/orchestrator/deps_test.go | 10 ++ internal/orchestrator/network_staged_apply.go | 75 ++++++++++++++- .../orchestrator/network_staged_apply_test.go | 92 +++++++++++++++++++ internal/orchestrator/restore_errors_test.go | 9 +- 5 files changed, 184 insertions(+), 10 deletions(-) create mode 100644 internal/orchestrator/network_staged_apply_test.go diff --git a/internal/orchestrator/deps.go b/internal/orchestrator/deps.go index 648e20b..6530c30 100644 --- a/internal/orchestrator/deps.go +++ b/internal/orchestrator/deps.go @@ -16,6 +16,7 @@ import ( // FS abstracts filesystem operations to simplify testing. type FS interface { Stat(path string) (os.FileInfo, error) + Lstat(path string) (os.FileInfo, error) ReadFile(path string) ([]byte, error) Open(path string) (*os.File, error) OpenFile(path string, flag int, perm fs.FileMode) (*os.File, error) @@ -69,9 +70,10 @@ type Deps struct { type osFS struct{} -func (osFS) Stat(path string) (os.FileInfo, error) { return os.Stat(path) } -func (osFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(path) } -func (osFS) Open(path string) (*os.File, error) { return os.Open(path) } +func (osFS) Stat(path string) (os.FileInfo, error) { return os.Stat(path) } +func (osFS) Lstat(path string) (os.FileInfo, error) { return os.Lstat(path) } +func (osFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(path) } +func (osFS) Open(path string) (*os.File, error) { return os.Open(path) } func (osFS) OpenFile(path string, flag int, perm fs.FileMode) (*os.File, error) { return os.OpenFile(path, flag, perm) } diff --git a/internal/orchestrator/deps_test.go b/internal/orchestrator/deps_test.go index ba4d404..986051f 100644 --- a/internal/orchestrator/deps_test.go +++ b/internal/orchestrator/deps_test.go @@ -60,6 +60,16 @@ func (f *FakeFS) Stat(path string) (os.FileInfo, error) { return os.Stat(f.onDisk(path)) } +func (f *FakeFS) Lstat(path string) (os.FileInfo, error) { + if err, ok := f.StatErr[filepath.Clean(path)]; ok { + return nil, err + } + if err, ok := f.StatErrors[filepath.Clean(path)]; ok { + return nil, err + } + return os.Lstat(f.onDisk(path)) +} + func (f *FakeFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(f.onDisk(path)) } diff --git a/internal/orchestrator/network_staged_apply.go b/internal/orchestrator/network_staged_apply.go index 3412d15..99f7f37 100644 --- a/internal/orchestrator/network_staged_apply.go +++ b/internal/orchestrator/network_staged_apply.go @@ -59,13 +59,16 @@ func applyNetworkFilesFromStage(logger *logging.Logger, stageRoot string) (appli } func copyDirOverlay(srcDir, destDir string) ([]string, error) { - info, err := restoreFS.Stat(srcDir) + info, err := restoreFS.Lstat(srcDir) if err != nil { if os.IsNotExist(err) { return nil, nil } return nil, fmt.Errorf("stat %s: %w", srcDir, err) } + if info.Mode()&os.ModeSymlink != 0 { + return nil, fmt.Errorf("staged directory must not be a symlink: %s", srcDir) + } if !info.IsDir() { return nil, nil } @@ -91,7 +94,26 @@ func copyDirOverlay(srcDir, destDir string) ([]string, error) { src := filepath.Join(srcDir, name) dest := filepath.Join(destDir, name) - if entry.IsDir() { + info, err := restoreFS.Lstat(src) + if err != nil { + if os.IsNotExist(err) { + continue + } + return applied, fmt.Errorf("stat %s: %w", src, err) + } + + if info.Mode()&os.ModeSymlink != 0 { + ok, err := copySymlinkOverlay(src, dest) + if err != nil { + return applied, err + } + if ok { + applied = append(applied, dest) + } + continue + } + + if info.IsDir() { paths, err := copyDirOverlay(src, dest) if err != nil { return applied, err @@ -113,16 +135,22 @@ func copyDirOverlay(srcDir, destDir string) ([]string, error) { } func copyFileOverlay(src, dest string) (bool, error) { - info, err := restoreFS.Stat(src) + info, err := restoreFS.Lstat(src) if err != nil { if os.IsNotExist(err) { return false, nil } return false, fmt.Errorf("stat %s: %w", src, err) } + if info.Mode()&os.ModeSymlink != 0 { + return copySymlinkOverlay(src, dest) + } if info.IsDir() { return false, nil } + if !info.Mode().IsRegular() { + return false, fmt.Errorf("unsupported staged file type %s (mode=%s)", src, info.Mode()) + } data, err := restoreFS.ReadFile(src) if err != nil { @@ -141,3 +169,44 @@ func copyFileOverlay(src, dest string) (bool, error) { } return true, nil } + +func copySymlinkOverlay(src, dest string) (bool, error) { + info, err := restoreFS.Lstat(src) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, fmt.Errorf("stat %s: %w", src, err) + } + if info.Mode()&os.ModeSymlink == 0 { + return false, fmt.Errorf("source is not a symlink: %s", src) + } + + target, err := restoreFS.Readlink(src) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, fmt.Errorf("readlink %s: %w", src, err) + } + + if err := ensureDirExistsWithInheritedMeta(filepath.Dir(dest)); err != nil { + return false, fmt.Errorf("ensure %s: %w", filepath.Dir(dest), err) + } + + if existing, err := restoreFS.Lstat(dest); err == nil { + if existing.IsDir() { + return false, fmt.Errorf("destination exists as directory: %s", dest) + } + if err := restoreFS.Remove(dest); err != nil && !os.IsNotExist(err) { + return false, fmt.Errorf("remove %s: %w", dest, err) + } + } else if !os.IsNotExist(err) { + return false, fmt.Errorf("stat %s: %w", dest, err) + } + + if err := restoreFS.Symlink(target, dest); err != nil { + return false, fmt.Errorf("symlink %s -> %s: %w", dest, target, err) + } + return true, nil +} diff --git a/internal/orchestrator/network_staged_apply_test.go b/internal/orchestrator/network_staged_apply_test.go new file mode 100644 index 0000000..99122f9 --- /dev/null +++ b/internal/orchestrator/network_staged_apply_test.go @@ -0,0 +1,92 @@ +package orchestrator + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +type preservingSymlinkFS struct { + *FakeFS +} + +func newPreservingSymlinkFS() *preservingSymlinkFS { + return &preservingSymlinkFS{FakeFS: NewFakeFS()} +} + +func (f *preservingSymlinkFS) Symlink(oldname, newname string) error { + if err := os.MkdirAll(filepath.Dir(f.onDisk(newname)), 0o755); err != nil { + return err + } + return os.Symlink(oldname, f.onDisk(newname)) +} + +func TestApplyNetworkFilesFromStage_PreservesSymlinkEntries(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.MkdirAll("/stage/etc/network", 0o755); err != nil { + t.Fatalf("create stage dir: %v", err) + } + if err := fakeFS.Symlink("/stage/etc/network/interfaces.real", "/stage/etc/network/interfaces"); err != nil { + t.Fatalf("create staged symlink: %v", err) + } + + applied, err := applyNetworkFilesFromStage(newTestLogger(), "/stage") + if err != nil { + t.Fatalf("applyNetworkFilesFromStage: %v", err) + } + + info, err := fakeFS.Lstat("/etc/network/interfaces") + if err != nil { + t.Fatalf("lstat dest symlink: %v", err) + } + if info.Mode()&os.ModeSymlink == 0 { + t.Fatalf("expected symlink mode, got %s", info.Mode()) + } + + gotTarget, err := fakeFS.Readlink("/etc/network/interfaces") + if err != nil { + t.Fatalf("readlink dest symlink: %v", err) + } + if gotTarget != "/stage/etc/network/interfaces.real" { + t.Fatalf("symlink target=%q, want %q", gotTarget, "/stage/etc/network/interfaces.real") + } + + found := false + for _, path := range applied { + if path == "/etc/network/interfaces" { + found = true + break + } + } + if !found { + t.Fatalf("expected applied paths to include /etc/network/interfaces, got %#v", applied) + } +} + +func TestApplyNetworkFilesFromStage_RejectsSymlinkStageDirectory(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.MkdirAll("/stage/etc", 0o755); err != nil { + t.Fatalf("create stage parent dir: %v", err) + } + if err := fakeFS.Symlink("/outside/network", "/stage/etc/network"); err != nil { + t.Fatalf("create staged dir symlink: %v", err) + } + + _, err := applyNetworkFilesFromStage(newTestLogger(), "/stage") + if err == nil || !strings.Contains(err.Error(), "must not be a symlink") { + t.Fatalf("expected staged directory symlink error, got %v", err) + } +} diff --git a/internal/orchestrator/restore_errors_test.go b/internal/orchestrator/restore_errors_test.go index cad33a8..0faf024 100644 --- a/internal/orchestrator/restore_errors_test.go +++ b/internal/orchestrator/restore_errors_test.go @@ -798,10 +798,11 @@ type ErrorInjectingFS struct { linkErr error } -func (f *ErrorInjectingFS) Stat(path string) (os.FileInfo, error) { return f.base.Stat(path) } -func (f *ErrorInjectingFS) ReadFile(path string) ([]byte, error) { return f.base.ReadFile(path) } -func (f *ErrorInjectingFS) Open(path string) (*os.File, error) { return f.base.Open(path) } -func (f *ErrorInjectingFS) Create(name string) (*os.File, error) { return f.base.Create(name) } +func (f *ErrorInjectingFS) Stat(path string) (os.FileInfo, error) { return f.base.Stat(path) } +func (f *ErrorInjectingFS) Lstat(path string) (os.FileInfo, error) { return f.base.Lstat(path) } +func (f *ErrorInjectingFS) ReadFile(path string) ([]byte, error) { return f.base.ReadFile(path) } +func (f *ErrorInjectingFS) Open(path string) (*os.File, error) { return f.base.Open(path) } +func (f *ErrorInjectingFS) Create(name string) (*os.File, error) { return f.base.Create(name) } func (f *ErrorInjectingFS) WriteFile(path string, data []byte, perm os.FileMode) error { return f.base.WriteFile(path, data, perm) } From dbf61c75d2d4bbc998f834603fac8c8651f96d3d Mon Sep 17 00:00:00 2001 From: tis24dev Date: Mon, 9 Mar 2026 18:54:32 +0100 Subject: [PATCH 10/18] fix(orchestrator): harden restore path validation against broken symlink escapes Replace the restore path safety checks with a shared FS-aware resolver that walks paths component by component using Lstat and Readlink, correctly handling missing tails and rejecting intermediate symlink escapes. Apply the new validation to regular restore, safety restore, symlink and hardlink handling, and add regression tests for broken escaping symlinks, symlink loops, and symlinked restore roots. --- internal/orchestrator/backup_safety.go | 73 +------ internal/orchestrator/backup_safety_test.go | 57 ++++++ internal/orchestrator/path_security.go | 214 ++++++++++++++++++++ internal/orchestrator/path_security_test.go | 78 +++++++ internal/orchestrator/restore.go | 74 +------ internal/orchestrator/restore_test.go | 58 ++++++ 6 files changed, 424 insertions(+), 130 deletions(-) create mode 100644 internal/orchestrator/path_security.go create mode 100644 internal/orchestrator/path_security_test.go diff --git a/internal/orchestrator/backup_safety.go b/internal/orchestrator/backup_safety.go index 5402897..b992052 100644 --- a/internal/orchestrator/backup_safety.go +++ b/internal/orchestrator/backup_safety.go @@ -23,38 +23,8 @@ type safetyBackupSpec struct { WriteLocationFile bool } -// resolveAndCheckPath cleans and resolves symlinks for candidate extraction paths -// and verifies the resolved path is still within destRoot. func resolveAndCheckPath(destRoot, candidate string) (string, error) { - joined := candidate - if !filepath.IsAbs(candidate) { - joined = filepath.Join(destRoot, candidate) - } - - resolved, err := filepath.EvalSymlinks(joined) - if err != nil { - // If the path doesn't exist yet, EvalSymlinks will fail; fallback to the cleaned path. - resolved = filepath.Clean(joined) - } - - absDestRoot, err := filepath.Abs(destRoot) - if err != nil { - return "", fmt.Errorf("cannot resolve destination root: %w", err) - } - absResolved, err := filepath.Abs(resolved) - if err != nil { - return "", fmt.Errorf("cannot resolve candidate path: %w", err) - } - - rel, err := filepath.Rel(absDestRoot, absResolved) - if err != nil { - return "", fmt.Errorf("cannot compute relative path: %w", err) - } - if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." || filepath.IsAbs(rel) { - return "", fmt.Errorf("resolved path escapes destination: %s", absResolved) - } - - return absResolved, nil + return resolvePathWithinRootFS(safetyFS, destRoot, candidate) } // SafetyBackupResult contains information about the safety backup @@ -391,7 +361,7 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str return fmt.Errorf("read tar entry: %w", err) } - target, _, err := sanitizeRestoreEntryTarget(absDestRoot, header.Name) + target, _, err := sanitizeRestoreEntryTargetWithFS(safetyFS, absDestRoot, header.Name) if err != nil { logger.Warning("Skipping archive entry %s: %v", header.Name, err) continue @@ -425,14 +395,7 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str if header.Typeflag == tar.TypeSymlink { linkTarget := header.Linkname - // Resolve intended target relative to the sanitized symlink directory inside the archive - sanitizedDir := filepath.Dir(relTarget) - resolvedLinkPath := linkTarget - if !filepath.IsAbs(linkTarget) { - resolvedLinkPath = filepath.Join(sanitizedDir, linkTarget) - } - - if _, pathErr := resolveAndCheckPath(destRoot, resolvedLinkPath); pathErr != nil { + if _, pathErr := resolvePathRelativeToBaseWithinRootFS(safetyFS, absDestRoot, filepath.Dir(target), linkTarget); pathErr != nil { logger.Warning("Skipping symlink %s -> %s: target escapes root: %v", target, linkTarget, pathErr) continue } @@ -454,33 +417,9 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str continue } - // Resolve the symlink target relative to the symlink's directory - symlinkTargetDir := filepath.Dir(target) - resolvedTarget := actualTarget - if !filepath.IsAbs(actualTarget) { - resolvedTarget = filepath.Join(symlinkTargetDir, actualTarget) - } - - // Validate the resolved target stays within destRoot - absDestRoot, err := filepath.Abs(destRoot) - if err != nil { - logger.Warning("Cannot resolve destination root: %v", err) - safetyFS.Remove(target) - continue - } - - absResolvedTarget, err := filepath.Abs(resolvedTarget) - if err != nil { - logger.Warning("Cannot resolve symlink target: %v", err) - safetyFS.Remove(target) - continue - } - - // Check if resolved target is within destRoot - rel, err := filepath.Rel(absDestRoot, absResolvedTarget) - if err != nil || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." { - logger.Warning("Removing symlink %s -> %s: target escapes root after creation (resolves to %s)", - target, actualTarget, absResolvedTarget) + if _, err := resolvePathRelativeToBaseWithinRootFS(safetyFS, absDestRoot, filepath.Dir(target), actualTarget); err != nil { + logger.Warning("Removing symlink %s -> %s: target escapes root after creation: %v", + target, actualTarget, err) safetyFS.Remove(target) continue } diff --git a/internal/orchestrator/backup_safety_test.go b/internal/orchestrator/backup_safety_test.go index 2529e2e..68c1eba 100644 --- a/internal/orchestrator/backup_safety_test.go +++ b/internal/orchestrator/backup_safety_test.go @@ -400,6 +400,51 @@ func TestRestoreSafetyBackup_DoesNotFollowExistingSymlinkTargetPath(t *testing.T } } +func TestRestoreSafetyBackup_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + var logBuf bytes.Buffer + logger.SetOutput(&logBuf) + + tmpDir := t.TempDir() + backupPath := filepath.Join(tmpDir, "broken-escape.tar.gz") + restoreDir := filepath.Join(tmpDir, "restore") + outside := t.TempDir() + + if err := os.MkdirAll(restoreDir, 0o755); err != nil { + t.Fatalf("mkdir restore dir: %v", err) + } + if err := os.Symlink(outside, filepath.Join(restoreDir, "escape-link")); err != nil { + t.Fatalf("create escape symlink: %v", err) + } + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + if err := tw.WriteHeader(&tar.Header{Name: "escape-link/missing/", Mode: 0o755, Typeflag: tar.TypeDir}); err != nil { + t.Fatalf("write dir header failed: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("tar close failed: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip close failed: %v", err) + } + if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive failed: %v", err) + } + + if err := RestoreSafetyBackup(logger, backupPath, restoreDir); err != nil { + t.Fatalf("RestoreSafetyBackup failed: %v", err) + } + + if _, err := os.Stat(filepath.Join(outside, "missing")); !os.IsNotExist(err) { + t.Fatalf("outside path should not be created, got err=%v", err) + } + if !strings.Contains(logBuf.String(), "Skipping archive entry") { + t.Fatalf("expected archive entry skip warning, logs=%q", logBuf.String()) + } +} + func TestCleanupOldSafetyBackups(t *testing.T) { logger := logging.New(types.LogLevelInfo, false) @@ -667,6 +712,18 @@ func TestResolveAndCheckPathRejectsSymlinkEscape(t *testing.T) { } } +func TestResolveAndCheckPathRejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(root, "escape-link")); err != nil { + t.Fatalf("create symlink: %v", err) + } + + if _, err := resolveAndCheckPath(root, filepath.Join("escape-link", "missing", "data.txt")); err == nil { + t.Fatalf("expected broken symlink escape to be rejected") + } +} + // ===================================== // walkFS / walkFSRecursive tests // ===================================== diff --git a/internal/orchestrator/path_security.go b/internal/orchestrator/path_security.go new file mode 100644 index 0000000..938bc50 --- /dev/null +++ b/internal/orchestrator/path_security.go @@ -0,0 +1,214 @@ +package orchestrator + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +const maxPathSecuritySymlinkHops = 40 + +func resolvePathWithinRootFS(fsys FS, destRoot, candidate string) (string, error) { + lexicalRoot, canonicalRoot, err := prepareRootPathsFS(fsys, destRoot) + if err != nil { + return "", err + } + + candidateAbs, err := normalizeCandidateWithinRoot(lexicalRoot, canonicalRoot, candidate) + if err != nil { + return "", err + } + + return resolvePathWithinPreparedRootFS(fsys, lexicalRoot, canonicalRoot, candidateAbs, true, maxPathSecuritySymlinkHops) +} + +func resolvePathRelativeToBaseWithinRootFS(fsys FS, destRoot, baseDir, candidate string) (string, error) { + lexicalRoot, canonicalRoot, err := prepareRootPathsFS(fsys, destRoot) + if err != nil { + return "", err + } + + baseAbs, err := filepath.Abs(filepath.Clean(baseDir)) + if err != nil { + return "", fmt.Errorf("resolve base directory: %w", err) + } + canonicalBaseDir, err := normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, baseAbs) + if err != nil { + return "", err + } + + var candidateAbs string + if filepath.IsAbs(candidate) { + candidateAbs, err = normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, filepath.Clean(candidate)) + if err != nil { + return "", err + } + } else { + candidateAbs = filepath.Clean(filepath.Join(canonicalBaseDir, candidate)) + } + + return resolvePathWithinPreparedRootFS(fsys, lexicalRoot, canonicalRoot, candidateAbs, true, maxPathSecuritySymlinkHops) +} + +func prepareRootPathsFS(fsys FS, destRoot string) (string, string, error) { + fsys = pathSecurityFS(fsys) + + cleanRoot := filepath.Clean(destRoot) + absRoot, err := filepath.Abs(cleanRoot) + if err != nil { + return "", "", fmt.Errorf("resolve destination root: %w", err) + } + + canonicalRoot, err := resolvePathFromFilesystemRootFS(fsys, absRoot, true, maxPathSecuritySymlinkHops) + if err != nil { + return "", "", fmt.Errorf("resolve destination root: %w", err) + } + + return absRoot, canonicalRoot, nil +} + +func normalizeCandidateWithinRoot(lexicalRoot, canonicalRoot, candidate string) (string, error) { + if filepath.IsAbs(candidate) { + return normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, filepath.Clean(candidate)) + } + + return filepath.Clean(filepath.Join(canonicalRoot, candidate)), nil +} + +func normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, candidateAbs string) (string, error) { + rel, ok, err := relativePathWithinRoot(lexicalRoot, candidateAbs) + if err != nil { + return "", fmt.Errorf("cannot compute relative path: %w", err) + } + if ok { + return filepath.Clean(filepath.Join(canonicalRoot, rel)), nil + } + + rel, ok, err = relativePathWithinRoot(canonicalRoot, candidateAbs) + if err != nil { + return "", fmt.Errorf("cannot compute relative path: %w", err) + } + if ok { + return filepath.Clean(filepath.Join(canonicalRoot, rel)), nil + } + + return "", fmt.Errorf("resolved path escapes destination: %s", candidateAbs) +} + +func resolvePathFromFilesystemRootFS(fsys FS, candidateAbs string, allowMissingTail bool, hopsRemaining int) (string, error) { + root := string(os.PathSeparator) + return resolvePathWithinPreparedRootFS(fsys, root, root, candidateAbs, allowMissingTail, hopsRemaining) +} + +func resolvePathWithinPreparedRootFS(fsys FS, lexicalRoot, canonicalRoot, candidateAbs string, allowMissingTail bool, hopsRemaining int) (string, error) { + fsys = pathSecurityFS(fsys) + + lexicalRoot = filepath.Clean(lexicalRoot) + canonicalRoot = filepath.Clean(canonicalRoot) + candidateAbs = filepath.Clean(candidateAbs) + + if !filepath.IsAbs(lexicalRoot) { + return "", fmt.Errorf("destination root must be absolute: %s", lexicalRoot) + } + if !filepath.IsAbs(canonicalRoot) { + return "", fmt.Errorf("destination root must be absolute: %s", canonicalRoot) + } + if !filepath.IsAbs(candidateAbs) { + return "", fmt.Errorf("candidate path must be absolute: %s", candidateAbs) + } + + rel, ok, err := relativePathWithinRoot(canonicalRoot, candidateAbs) + if err != nil { + return "", fmt.Errorf("cannot compute relative path: %w", err) + } + if !ok { + return "", fmt.Errorf("resolved path escapes destination: %s", candidateAbs) + } + if rel == "." { + return canonicalRoot, nil + } + + current := canonicalRoot + parts := splitRelativePath(rel) + for idx, part := range parts { + next := filepath.Join(current, part) + info, err := fsys.Lstat(next) + if err != nil { + if allowMissingTail && os.IsNotExist(err) { + return filepath.Clean(filepath.Join(current, filepath.Join(parts[idx:]...))), nil + } + return "", fmt.Errorf("lstat %s: %w", next, err) + } + + if info.Mode()&os.ModeSymlink != 0 { + if hopsRemaining <= 0 { + return "", fmt.Errorf("too many symlink resolutions for %s", candidateAbs) + } + target, err := fsys.Readlink(next) + if err != nil { + return "", fmt.Errorf("readlink %s: %w", next, err) + } + + var resolvedLink string + if filepath.IsAbs(target) { + resolvedLink, err = normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, filepath.Clean(target)) + if err != nil { + return "", err + } + } else { + resolvedLink = filepath.Join(current, target) + } + resolvedLink = filepath.Clean(resolvedLink) + + if _, ok, err := relativePathWithinRoot(canonicalRoot, resolvedLink); err != nil { + return "", fmt.Errorf("cannot compute relative path: %w", err) + } else if !ok { + return "", fmt.Errorf("resolved path escapes destination: %s", resolvedLink) + } + + remainder := filepath.Join(parts[idx+1:]...) + if remainder != "" { + resolvedLink = filepath.Join(resolvedLink, remainder) + } + + return resolvePathWithinPreparedRootFS(fsys, lexicalRoot, canonicalRoot, resolvedLink, allowMissingTail, hopsRemaining-1) + } + + if !info.IsDir() && idx < len(parts)-1 { + return "", fmt.Errorf("path component is not a directory: %s", next) + } + + current = next + } + + return current, nil +} + +func relativePathWithinRoot(root, candidate string) (string, bool, error) { + rel, err := filepath.Rel(root, candidate) + if err != nil { + return "", false, err + } + if rel == "." { + return rel, true, nil + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || filepath.IsAbs(rel) { + return rel, false, nil + } + return rel, true, nil +} + +func splitRelativePath(rel string) []string { + if rel == "" || rel == "." { + return nil + } + return strings.Split(rel, string(os.PathSeparator)) +} + +func pathSecurityFS(fsys FS) FS { + if fsys == nil { + return osFS{} + } + return fsys +} diff --git a/internal/orchestrator/path_security_test.go b/internal/orchestrator/path_security_test.go new file mode 100644 index 0000000..c211000 --- /dev/null +++ b/internal/orchestrator/path_security_test.go @@ -0,0 +1,78 @@ +package orchestrator + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestResolvePathWithinRootFS_AllowsMissingTailAfterSafeSymlink(t *testing.T) { + root := t.TempDir() + if err := os.MkdirAll(filepath.Join(root, "safe-target"), 0o755); err != nil { + t.Fatalf("mkdir safe-target: %v", err) + } + if err := os.Symlink("safe-target", filepath.Join(root, "safe-link")); err != nil { + t.Fatalf("create safe symlink: %v", err) + } + + resolved, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("safe-link", "missing", "file.txt")) + if err != nil { + t.Fatalf("resolvePathWithinRootFS returned error: %v", err) + } + + want := filepath.Join(root, "safe-target", "missing", "file.txt") + if resolved != want { + t.Fatalf("resolved path = %q, want %q", resolved, want) + } +} + +func TestResolvePathWithinRootFS_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(root, "escape-link")); err != nil { + t.Fatalf("create escape symlink: %v", err) + } + + _, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("escape-link", "missing", "file.txt")) + if err == nil || !strings.Contains(err.Error(), "escapes destination") { + t.Fatalf("expected escape rejection, got %v", err) + } +} + +func TestResolvePathWithinRootFS_RejectsSymlinkLoop(t *testing.T) { + root := t.TempDir() + if err := os.Symlink("loop", filepath.Join(root, "loop")); err != nil { + t.Fatalf("create loop symlink: %v", err) + } + + _, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("loop", "file.txt")) + if err == nil || !strings.Contains(err.Error(), "too many symlink resolutions") { + t.Fatalf("expected symlink loop rejection, got %v", err) + } +} + +func TestResolvePathWithinRootFS_AllowsAbsoluteSymlinkTargetViaLexicalRoot(t *testing.T) { + parent := t.TempDir() + realRoot := filepath.Join(parent, "real-root") + linkRoot := filepath.Join(parent, "link-root") + if err := os.MkdirAll(filepath.Join(realRoot, "etc-real"), 0o755); err != nil { + t.Fatalf("mkdir real root: %v", err) + } + if err := os.Symlink(realRoot, linkRoot); err != nil { + t.Fatalf("create root symlink: %v", err) + } + if err := os.Symlink(filepath.Join(linkRoot, "etc-real"), filepath.Join(realRoot, "etc")); err != nil { + t.Fatalf("create nested symlink: %v", err) + } + + resolved, err := resolvePathWithinRootFS(osFS{}, linkRoot, filepath.Join("etc", "missing", "file.txt")) + if err != nil { + t.Fatalf("resolvePathWithinRootFS returned error: %v", err) + } + + want := filepath.Join(realRoot, "etc-real", "missing", "file.txt") + if resolved != want { + t.Fatalf("resolved path = %q, want %q", resolved, want) + } +} diff --git a/internal/orchestrator/restore.go b/internal/orchestrator/restore.go index 2f14d20..00a7a14 100644 --- a/internal/orchestrator/restore.go +++ b/internal/orchestrator/restore.go @@ -1448,6 +1448,10 @@ func runRestoreCommandStream(ctx context.Context, name string, stdin io.Reader, } func sanitizeRestoreEntryTarget(destRoot, entryName string) (string, string, error) { + return sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, entryName) +} + +func sanitizeRestoreEntryTargetWithFS(fsys FS, destRoot, entryName string) (string, string, error) { cleanDestRoot := filepath.Clean(destRoot) if cleanDestRoot == "" { cleanDestRoot = string(os.PathSeparator) @@ -1490,24 +1494,9 @@ func sanitizeRestoreEntryTarget(destRoot, entryName string) (string, string, err return "", "", fmt.Errorf("illegal path: %s", entryName) } - parentDir := filepath.Dir(absTarget) - resolvedParentDir, err := filepath.EvalSymlinks(parentDir) - if err != nil { - // If the path doesn't exist yet, EvalSymlinks will fail; fallback to the cleaned path. - resolvedParentDir = filepath.Clean(parentDir) - } - absResolvedParentDir, err := filepath.Abs(resolvedParentDir) - if err != nil { - return "", "", fmt.Errorf("resolve extraction parent directory: %w", err) - } - - relParent, err := filepath.Rel(absDestRoot, absResolvedParentDir) - if err != nil { + if _, err := resolvePathWithinRootFS(fsys, absDestRoot, absTarget); err != nil { return "", "", fmt.Errorf("illegal path: %s: %w", entryName, err) } - if strings.HasPrefix(relParent, ".."+string(os.PathSeparator)) || relParent == ".." || filepath.IsAbs(relParent) { - return "", "", fmt.Errorf("illegal path: %s", entryName) - } return absTarget, absDestRoot, nil } @@ -1537,7 +1526,7 @@ func shouldSkipProxmoxSystemRestore(relTarget string) (bool, string) { // extractTarEntry extracts a single TAR entry, preserving all attributes including atime/ctime func extractTarEntry(tarReader *tar.Reader, header *tar.Header, destRoot string, logger *logging.Logger) error { - target, cleanDestRoot, err := sanitizeRestoreEntryTarget(destRoot, header.Name) + target, cleanDestRoot, err := sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, header.Name) if err != nil { return err } @@ -1647,26 +1636,8 @@ func extractRegularFile(tarReader *tar.Reader, target string, header *tar.Header func extractSymlink(target string, header *tar.Header, destRoot string, logger *logging.Logger) error { linkTarget := header.Linkname - // Pre-validation: ensure the resolved target would stay within destRoot before creating the symlink - relativeTarget, err := filepath.Rel(destRoot, target) - if err != nil { - return fmt.Errorf("determine relative path for symlink %s: %w", target, err) - } - if strings.HasPrefix(relativeTarget, ".."+string(os.PathSeparator)) || relativeTarget == ".." { - return fmt.Errorf("sanitized symlink path escapes root: %s", target) - } - - symlinkArchivePath := path.Clean(filepath.ToSlash(relativeTarget)) - symlinkArchiveDir := path.Dir(symlinkArchivePath) - if symlinkArchiveDir == "." { - symlinkArchiveDir = "" - } - potentialTarget := linkTarget - if !filepath.IsAbs(linkTarget) { - potentialTarget = filepath.FromSlash(path.Join(symlinkArchiveDir, linkTarget)) - } - - if _, err := resolveAndCheckPath(destRoot, potentialTarget); err != nil { + // Pre-validation: ensure the symlink target resolves within destRoot before creation. + if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), linkTarget); err != nil { return fmt.Errorf("symlink target escapes root before creation: %s -> %s: %w", header.Name, linkTarget, err) } @@ -1685,32 +1656,9 @@ func extractSymlink(target string, header *tar.Header, destRoot string, logger * return fmt.Errorf("read created symlink %s: %w", target, err) } - // Resolve the symlink target relative to the symlink's directory - symlinkDir := filepath.Dir(target) - resolvedTarget := actualTarget - if !filepath.IsAbs(actualTarget) { - resolvedTarget = filepath.Join(symlinkDir, actualTarget) - } - - // Validate the resolved target stays within destRoot using absolute paths - absDestRoot, err := filepath.Abs(destRoot) - if err != nil { - restoreFS.Remove(target) - return fmt.Errorf("resolve destination root: %w", err) - } - - absResolvedTarget, err := filepath.Abs(resolvedTarget) - if err != nil { - restoreFS.Remove(target) - return fmt.Errorf("resolve symlink target: %w", err) - } - - // Check if resolved target is within destRoot - rel, err := filepath.Rel(absDestRoot, absResolvedTarget) - if err != nil || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." { + if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), actualTarget); err != nil { restoreFS.Remove(target) - return fmt.Errorf("symlink target escapes root after creation: %s -> %s (resolves to %s)", - header.Name, linkTarget, absResolvedTarget) + return fmt.Errorf("symlink target escapes root after creation: %s -> %s: %w", header.Name, actualTarget, err) } // Set ownership (on the symlink itself, not the target) @@ -1733,7 +1681,7 @@ func extractHardlink(target string, header *tar.Header, destRoot string) error { } // Validate the hard link target stays within extraction root - if _, err := resolveAndCheckPath(destRoot, linkName); err != nil { + if _, err := resolvePathWithinRootFS(restoreFS, destRoot, linkName); err != nil { return fmt.Errorf("hardlink target escapes root: %s -> %s: %w", header.Name, linkName, err) } diff --git a/internal/orchestrator/restore_test.go b/internal/orchestrator/restore_test.go index 80a1f10..9e80594 100644 --- a/internal/orchestrator/restore_test.go +++ b/internal/orchestrator/restore_test.go @@ -334,6 +334,64 @@ func TestExtractSymlink_SecurityValidation(t *testing.T) { } } +func TestExtractTarEntry_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + orig := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = orig }) + + destRoot := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(destRoot, "escape-link")); err != nil { + t.Fatalf("create escape symlink: %v", err) + } + + header := &tar.Header{ + Name: "escape-link/missing", + Typeflag: tar.TypeDir, + Mode: 0o755, + } + + var tr *tar.Reader + err := extractTarEntry(tr, header, destRoot, logger) + if err == nil || !strings.Contains(err.Error(), "illegal path") { + t.Fatalf("expected illegal path error, got %v", err) + } + + if _, err := os.Stat(filepath.Join(outside, "missing")); !os.IsNotExist(err) { + t.Fatalf("outside path should not be created, got err=%v", err) + } +} + +func TestExtractSymlink_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + orig := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = orig }) + + destRoot := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(destRoot, "escape-link")); err != nil { + t.Fatalf("create escape symlink: %v", err) + } + + header := &tar.Header{ + Name: "link", + Typeflag: tar.TypeSymlink, + Linkname: "escape-link/missing/file.txt", + } + target := filepath.Join(destRoot, header.Name) + + err := extractSymlink(target, header, destRoot, logger) + if err == nil || !strings.Contains(err.Error(), "escapes root") { + t.Fatalf("expected escapes root error, got %v", err) + } + + if _, err := os.Lstat(target); !os.IsNotExist(err) { + t.Fatalf("symlink should not be created, got err=%v", err) + } +} + func TestExtractTarEntry_DoesNotFollowExistingSymlinkTargetPath(t *testing.T) { logger := logging.New(types.LogLevelDebug, false) orig := restoreFS From e612754f57d85238bf679e5c81570f1131cebbc8 Mon Sep 17 00:00:00 2001 From: tis24dev Date: Mon, 9 Mar 2026 19:11:06 +0100 Subject: [PATCH 11/18] fix(backup): sanitize datastore and storage path segments Introduce a shared path-safe key for PBS datastore names and PVE storage names when building collector output paths, preventing path traversal and filename collisions from raw config values. Keep raw names unchanged in metadata and command invocations, update all affected PBS/PVE collector call sites, and add regression tests covering unsafe names while preserving the existing layout for already-safe names. --- internal/backup/collector.go | 17 ++ .../backup/collector_helpers_extra_test.go | 27 ++- internal/backup/collector_pbs.go | 3 +- internal/backup/collector_pbs_datastore.go | 27 +-- internal/backup/collector_pbs_test.go | 161 ++++++++++++++++++ internal/backup/collector_pve.go | 16 +- internal/backup/collector_pve_test.go | 48 ++++++ 7 files changed, 282 insertions(+), 17 deletions(-) diff --git a/internal/backup/collector.go b/internal/backup/collector.go index 466fed2..e3de49f 100644 --- a/internal/backup/collector.go +++ b/internal/backup/collector.go @@ -3,6 +3,8 @@ package backup import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "io" @@ -1180,6 +1182,21 @@ func sanitizeFilename(name string) string { return clean } +func collectorPathKey(name string) string { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return "entry" + } + + safe := sanitizeFilename(trimmed) + if safe == trimmed { + return safe + } + + sum := sha256.Sum256([]byte(trimmed)) + return fmt.Sprintf("%s_%s", safe, hex.EncodeToString(sum[:4])) +} + // GetStats returns current collection statistics func (c *Collector) GetStats() *CollectionStats { c.statsMu.Lock() diff --git a/internal/backup/collector_helpers_extra_test.go b/internal/backup/collector_helpers_extra_test.go index ae58f2a..a0aa38b 100644 --- a/internal/backup/collector_helpers_extra_test.go +++ b/internal/backup/collector_helpers_extra_test.go @@ -1,6 +1,9 @@ package backup -import "testing" +import ( + "strings" + "testing" +) func TestSummarizeCommandOutputText(t *testing.T) { if got := summarizeCommandOutputText(""); got != "(no stdout/stderr)" { @@ -38,3 +41,25 @@ func TestSanitizeFilenameExtra(t *testing.T) { } } } + +func TestCollectorPathKey(t *testing.T) { + if got := collectorPathKey("store1"); got != "store1" { + t.Fatalf("collectorPathKey(store1)=%q want %q", got, "store1") + } + + unsafe := "../evil" + got := collectorPathKey(unsafe) + if got == unsafe { + t.Fatalf("collectorPathKey(%q) should not keep unsafe value", unsafe) + } + if got == sanitizeFilename(unsafe) { + t.Fatalf("collectorPathKey(%q) should add a disambiguating suffix", unsafe) + } + if !strings.HasPrefix(got, "__evil") { + t.Fatalf("collectorPathKey(%q)=%q should start with sanitized prefix", unsafe, got) + } + + if a, b := collectorPathKey("a/b"), collectorPathKey("a_b"); a == b { + t.Fatalf("collectorPathKey should avoid collisions: %q == %q", a, b) + } +} diff --git a/internal/backup/collector_pbs.go b/internal/backup/collector_pbs.go index 98212e9..dc053b3 100644 --- a/internal/backup/collector_pbs.go +++ b/internal/backup/collector_pbs.go @@ -383,9 +383,10 @@ func (c *Collector) collectPBSCommands(ctx context.Context, datastores []pbsData // Datastore usage details if c.config.BackupDatastoreConfigs && len(datastores) > 0 { for _, ds := range datastores { + dsKey := ds.pathKey() c.safeCmdOutput(ctx, fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", ds.Name), - filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", ds.Name)), + filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", dsKey)), fmt.Sprintf("Datastore %s status", ds.Name), false) } diff --git a/internal/backup/collector_pbs_datastore.go b/internal/backup/collector_pbs_datastore.go index 0b5c2ba..9ba87b1 100644 --- a/internal/backup/collector_pbs_datastore.go +++ b/internal/backup/collector_pbs_datastore.go @@ -22,6 +22,10 @@ type pbsDatastore struct { Comment string } +func (ds pbsDatastore) pathKey() string { + return collectorPathKey(ds.Name) +} + var listNamespacesFunc = pbs.ListNamespaces // collectDatastoreConfigs collects detailed datastore configurations @@ -38,10 +42,12 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb } for _, ds := range datastores { + dsKey := ds.pathKey() + // Get datastore configuration details c.safeCmdOutput(ctx, fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", ds.Name), - filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", ds.Name)), + filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", dsKey)), fmt.Sprintf("Datastore %s configuration", ds.Name), false) @@ -60,7 +66,7 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb func (c *Collector) collectDatastoreNamespaces(ctx context.Context, ds pbsDatastore, datastoreDir string) error { c.logger.Debug("Collecting namespaces for datastore %s (path: %s)", ds.Name, ds.Path) // Write location is deterministic; if excluded, skip the whole operation. - outputPath := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.Name)) + outputPath := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey())) if c.shouldExclude(outputPath) { c.incFilesSkipped() return nil @@ -214,16 +220,17 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m start := time.Now() c.logger.Debug("PXAR: scanning datastore %s at %s", ds.Name, ds.Path) - dsDir := filepath.Join(metaRoot, ds.Name) + dsKey := ds.pathKey() + dsDir := filepath.Join(metaRoot, dsKey) if err := c.ensureDir(dsDir); err != nil { return fmt.Errorf("failed to create PXAR metadata directory for %s: %w", ds.Name, err) } for _, base := range []string{ - filepath.Join(selectedRoot, ds.Name, "vm"), - filepath.Join(selectedRoot, ds.Name, "ct"), - filepath.Join(smallRoot, ds.Name, "vm"), - filepath.Join(smallRoot, ds.Name, "ct"), + filepath.Join(selectedRoot, dsKey, "vm"), + filepath.Join(selectedRoot, dsKey, "ct"), + filepath.Join(smallRoot, dsKey, "vm"), + filepath.Join(smallRoot, dsKey, "ct"), } { if err := c.ensureDir(base); err != nil { c.logger.Debug("Failed to prepare PXAR directory %s: %v", base, err) @@ -278,7 +285,7 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m return err } - if err := c.writePxarSubdirReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_subdirs.txt", ds.Name)), ds, ioTimeout); err != nil { + if err := c.writePxarSubdirReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_subdirs.txt", dsKey)), ds, ioTimeout); err != nil { if errors.Is(err, safefs.ErrTimeout) { c.logger.Warning("Skipping PXAR metadata for datastore %s (path=%s): subdir report timed out (%v)", ds.Name, ds.Path, err) return nil @@ -286,7 +293,7 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m return err } - if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_vm_pxar_list.txt", ds.Name)), ds, "vm", ioTimeout); err != nil { + if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_vm_pxar_list.txt", dsKey)), ds, "vm", ioTimeout); err != nil { if errors.Is(err, safefs.ErrTimeout) { c.logger.Warning("Skipping PXAR metadata for datastore %s (path=%s): VM list report timed out (%v)", ds.Name, ds.Path, err) return nil @@ -294,7 +301,7 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m return err } - if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_ct_pxar_list.txt", ds.Name)), ds, "ct", ioTimeout); err != nil { + if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_ct_pxar_list.txt", dsKey)), ds, "ct", ioTimeout); err != nil { if errors.Is(err, safefs.ErrTimeout) { c.logger.Warning("Skipping PXAR metadata for datastore %s (path=%s): CT list report timed out (%v)", ds.Name, ds.Path, err) return nil diff --git a/internal/backup/collector_pbs_test.go b/internal/backup/collector_pbs_test.go index e44fddf..3a96797 100644 --- a/internal/backup/collector_pbs_test.go +++ b/internal/backup/collector_pbs_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/tis24dev/proxsave/internal/pbs" + "github.com/tis24dev/proxsave/internal/types" ) func TestGetDatastoreListSuccessWithOverrides(t *testing.T) { @@ -333,6 +334,166 @@ func TestCollectDatastoreConfigsDryRun(t *testing.T) { } } +func TestCollectDatastoreConfigs_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) { + unsafeName := "../escape" + expectedKey := collectorPathKey(unsafeName) + + stubListNamespaces(t, func(_ context.Context, name, path string, _ time.Duration) ([]pbs.Namespace, bool, error) { + if name != unsafeName || path != "/fake" { + t.Fatalf("unexpected datastore args name=%q path=%q", name, path) + } + return []pbs.Namespace{{Ns: ""}}, false, nil + }) + + var seenArgs []string + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) { + if name != "proxmox-backup-manager" { + return nil, fmt.Errorf("unexpected command %s", name) + } + seenArgs = append([]string(nil), args...) + return []byte(`{"ok":true}`), nil + }, + }) + + ds := pbsDatastore{Name: unsafeName, Path: "/fake"} + if err := collector.collectDatastoreConfigs(context.Background(), []pbsDatastore{ds}); err != nil { + t.Fatalf("collectDatastoreConfigs failed: %v", err) + } + + if len(seenArgs) < 3 || seenArgs[0] != "datastore" || seenArgs[1] != "show" || seenArgs[2] != unsafeName { + t.Fatalf("expected raw datastore name in command args, got %v", seenArgs) + } + + datastoreDir := filepath.Join(collector.tempDir, "var", "lib", "proxsave-info", "pbs", "datastores") + safeConfig := filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", expectedKey)) + safeNamespaces := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", expectedKey)) + for _, path := range []string{safeConfig, safeNamespaces} { + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected safe output %s: %v", path, err) + } + } + + rawConfig := filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", unsafeName)) + rawNamespaces := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", unsafeName)) + for _, path := range []string{rawConfig, rawNamespaces} { + if path == safeConfig || path == safeNamespaces { + continue + } + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Fatalf("raw output path should not exist (%s), got err=%v", path, err) + } + } +} + +func TestCollectPBSPxarMetadata_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) { + tmp := t.TempDir() + cfg := GetDefaultCollectorConfig() + collector := NewCollector(newTestLogger(), cfg, tmp, types.ProxmoxBS, false) + + dsPath := filepath.Join(tmp, "datastore") + for _, sub := range []string{"vm", "ct"} { + if err := os.MkdirAll(filepath.Join(dsPath, sub), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", sub, err) + } + } + if err := os.WriteFile(filepath.Join(dsPath, "vm", "backup1.pxar"), []byte("data"), 0o640); err != nil { + t.Fatalf("write vm pxar: %v", err) + } + if err := os.WriteFile(filepath.Join(dsPath, "ct", "backup2.pxar"), []byte("data"), 0o640); err != nil { + t.Fatalf("write ct pxar: %v", err) + } + + ds := pbsDatastore{Name: "../escape", Path: dsPath, Comment: "unsafe"} + if err := collector.collectPBSPxarMetadata(context.Background(), []pbsDatastore{ds}); err != nil { + t.Fatalf("collectPBSPxarMetadata failed: %v", err) + } + + dsKey := collectorPathKey(ds.Name) + base := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "metadata", dsKey) + for _, path := range []string{ + filepath.Join(base, "metadata.json"), + filepath.Join(base, fmt.Sprintf("%s_subdirs.txt", dsKey)), + filepath.Join(base, fmt.Sprintf("%s_vm_pxar_list.txt", dsKey)), + filepath.Join(base, fmt.Sprintf("%s_ct_pxar_list.txt", dsKey)), + } { + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected safe PXAR output %s: %v", path, err) + } + } + + metaBytes, err := os.ReadFile(filepath.Join(base, "metadata.json")) + if err != nil { + t.Fatalf("read metadata.json: %v", err) + } + if !strings.Contains(string(metaBytes), ds.Name) { + t.Fatalf("metadata should keep raw datastore name, got %s", string(metaBytes)) + } + + selectedVM := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "selected", dsKey, "vm") + smallVM := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "small", dsKey, "vm") + for _, path := range []string{selectedVM, smallVM} { + if info, err := os.Stat(path); err != nil || !info.IsDir() { + t.Fatalf("expected safe PXAR directory %s, got err=%v", path, err) + } + } + + rawBase := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "metadata", ds.Name) + if rawBase != base { + if _, err := os.Stat(rawBase); !os.IsNotExist(err) { + t.Fatalf("raw PXAR directory should not exist (%s), got err=%v", rawBase, err) + } + } +} + +func TestCollectPBSCommands_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) { + pbsRoot := t.TempDir() + if err := os.WriteFile(filepath.Join(pbsRoot, "tape.cfg"), []byte("ok"), 0o640); err != nil { + t.Fatalf("write tape.cfg: %v", err) + } + + cfg := GetDefaultCollectorConfig() + cfg.PBSConfigPath = pbsRoot + + collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{ + LookPath: func(name string) (string, error) { + return "/bin/" + name, nil + }, + RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) { + return []byte(fmt.Sprintf("%s %s", name, strings.Join(args, " "))), nil + }, + }) + + ds := pbsDatastore{Name: "../escape", Path: "/data/escape"} + if err := collector.collectPBSCommands(context.Background(), []pbsDatastore{ds}); err != nil { + t.Fatalf("collectPBSCommands error: %v", err) + } + + key := collectorPathKey(ds.Name) + commandsDir := filepath.Join(collector.tempDir, "var/lib/proxsave-info", "commands", "pbs") + safePath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", key)) + if _, err := os.Stat(safePath); err != nil { + t.Fatalf("expected safe datastore status file: %v", err) + } + data, err := os.ReadFile(safePath) + if err != nil { + t.Fatalf("read datastore status file: %v", err) + } + if !strings.Contains(string(data), ds.Name) { + t.Fatalf("status file should reflect raw datastore name in command output, got %s", string(data)) + } + + rawPath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", ds.Name)) + if rawPath != safePath { + if _, err := os.Stat(rawPath); !os.IsNotExist(err) { + t.Fatalf("raw datastore status path should not exist (%s), got err=%v", rawPath, err) + } + } +} + func TestCollectUserConfigsWithTokens(t *testing.T) { collector := newTestCollectorWithDeps(t, CollectorDeps{ LookPath: func(cmd string) (string, error) { diff --git a/internal/backup/collector_pve.go b/internal/backup/collector_pve.go index b8f9c7a..8fe16eb 100644 --- a/internal/backup/collector_pve.go +++ b/internal/backup/collector_pve.go @@ -30,6 +30,10 @@ type pveStorageEntry struct { Status string } +func (s pveStorageEntry) pathKey() string { + return collectorPathKey(s.Name) +} + type pveRuntimeInfo struct { Nodes []string Storages []pveStorageEntry @@ -1176,7 +1180,7 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv storage.Path, storage.Content)) - metaDir := filepath.Join(baseDir, storage.Name) + metaDir := filepath.Join(baseDir, storage.pathKey()) if err := c.ensureDir(metaDir); err != nil { c.logger.Warning("Failed to create metadata directory for %s: %v", storage.Name, err) continue @@ -1318,9 +1322,11 @@ func (c *Collector) collectDetailedPVEBackups(ctx context.Context, storage pveSt var totalFiles int64 var totalSize int64 + storageKey := storage.pathKey() + var smallDir string if c.config.BackupSmallPVEBackups && c.config.MaxPVEBackupSizeBytes > 0 { - smallDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/small_backups", storage.Name) + smallDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/small_backups", storageKey) if err := c.ensureDir(smallDir); err != nil { c.logger.Warning("Cannot create small backups directory %s: %v", smallDir, err) smallDir = "" @@ -1330,7 +1336,7 @@ func (c *Collector) collectDetailedPVEBackups(ctx context.Context, storage pveSt includePattern := strings.TrimSpace(c.config.PVEBackupIncludePattern) var includeDir string if includePattern != "" { - includeDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/selected_backups", storage.Name) + includeDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/selected_backups", storageKey) if err := c.ensureDir(includeDir); err != nil { c.logger.Warning("Cannot create selected backups directory %s: %v", includeDir, err) includeDir = "" @@ -1467,7 +1473,7 @@ type patternWriter struct { func newPatternWriter(storageName, storagePath, analysisDir, pattern string, dryRun bool) (*patternWriter, error) { clean := cleanPatternName(pattern) - filename := fmt.Sprintf("%s_%s_list.txt", storageName, clean) + filename := fmt.Sprintf("%s_%s_list.txt", collectorPathKey(storageName), clean) filePath := filepath.Join(analysisDir, filename) // In dry-run mode, create a writer without an actual file @@ -1593,7 +1599,7 @@ func (c *Collector) writePatternSummary(storage pveStorageEntry, analysisDir str return nil } - summaryPath := filepath.Join(analysisDir, fmt.Sprintf("%s_backup_summary.txt", storage.Name)) + summaryPath := filepath.Join(analysisDir, fmt.Sprintf("%s_backup_summary.txt", storage.pathKey())) file, err := os.OpenFile(summaryPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0640) if err != nil { return err diff --git a/internal/backup/collector_pve_test.go b/internal/backup/collector_pve_test.go index 2171b92..1a5e127 100644 --- a/internal/backup/collector_pve_test.go +++ b/internal/backup/collector_pve_test.go @@ -830,6 +830,54 @@ func TestCollectPVEStorageMetadata(t *testing.T) { } } +func TestCollectPVEStorageMetadata_UsesPathSafeKeyForUnsafeStorageName(t *testing.T) { + collector := newPVECollector(t) + collector.config.BackupSmallPVEBackups = true + collector.config.MaxPVEBackupSizeBytes = 1024 * 1024 + collector.config.PVEBackupIncludePattern = "vm-100" + + ctx := context.Background() + storageDir := t.TempDir() + if err := os.WriteFile(filepath.Join(storageDir, "vm-100-backup.vma"), []byte("content"), 0o644); err != nil { + t.Fatalf("write backup file: %v", err) + } + + storage := pveStorageEntry{Name: "../escape", Path: storageDir, Type: "dir"} + if err := collector.collectPVEStorageMetadata(ctx, []pveStorageEntry{storage}); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + key := collectorPathKey(storage.Name) + baseDir := filepath.Join(collector.tempDir, "var/lib/pve-cluster/info/datastores", key) + for _, path := range []string{ + filepath.Join(baseDir, "metadata.json"), + filepath.Join(baseDir, "backup_analysis", fmt.Sprintf("%s_backup_summary.txt", key)), + filepath.Join(baseDir, "backup_analysis", fmt.Sprintf("%s__vma_list.txt", key)), + filepath.Join(collector.tempDir, "var/lib/pve-cluster/small_backups", key, "vm-100-backup.vma"), + filepath.Join(collector.tempDir, "var/lib/pve-cluster/selected_backups", key, "vm-100-backup.vma"), + } { + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected safe output %s: %v", path, err) + } + } + + metaPath := filepath.Join(baseDir, "metadata.json") + metaBytes, err := os.ReadFile(metaPath) + if err != nil { + t.Fatalf("read metadata.json: %v", err) + } + if !strings.Contains(string(metaBytes), storage.Name) { + t.Fatalf("metadata should keep raw storage name, got %s", string(metaBytes)) + } + + rawMetaPath := filepath.Join(collector.tempDir, "var/lib/pve-cluster/info/datastores", storage.Name, "metadata.json") + if rawMetaPath != metaPath { + if _, err := os.Stat(rawMetaPath); !os.IsNotExist(err) { + t.Fatalf("raw storage metadata path should not exist (%s), got err=%v", rawMetaPath, err) + } + } +} + func TestCollectPVEStorageMetadata_SkipReasonsAreUserFriendly(t *testing.T) { boolPtr := func(v bool) *bool { b := v From 036e573e1a355c97fc0c0840017dae21d6e96e2c Mon Sep 17 00:00:00 2001 From: tis24dev Date: Mon, 9 Mar 2026 23:16:35 +0100 Subject: [PATCH 12/18] fix(io): bound timed fs and input goroutine buildup Prevent unbounded goroutine accumulation in timeout/cancel wrappers by limiting in-flight safefs operations and reusing a single in-flight read per reader/file descriptor in the input package. Keep public APIs unchanged, preserve current timeout semantics, and add regression coverage for limiter saturation, repeated timeouts, and race-safe cleanup of timed operations. --- internal/input/input.go | 137 ++++++++++++++++++++++++++++----- internal/input/input_test.go | 104 +++++++++++++++++++++++++ internal/safefs/safefs.go | 131 +++++++++++++++---------------- internal/safefs/safefs_test.go | 89 +++++++++++++++++++-- 4 files changed, 372 insertions(+), 89 deletions(-) diff --git a/internal/input/input.go b/internal/input/input.go index 15c87b1..6de043e 100644 --- a/internal/input/input.go +++ b/internal/input/input.go @@ -7,6 +7,7 @@ import ( "io" "os" "strings" + "sync" ) // ErrInputAborted signals that interactive input was interrupted (typically via Ctrl+C @@ -15,6 +16,39 @@ import ( // Callers should translate this into the appropriate workflow-level abort error. var ErrInputAborted = errors.New("input aborted") +type lineResult struct { + line string + err error +} + +type lineState struct { + mu sync.Mutex + inflight *lineInflight +} + +type lineInflight struct { + done chan lineResult +} + +type passwordResult struct { + b []byte + err error +} + +type passwordState struct { + mu sync.Mutex + inflight *passwordInflight +} + +type passwordInflight struct { + done chan passwordResult +} + +var ( + lineStates sync.Map + passwordStates sync.Map +) + // IsAborted reports whether an operation was aborted by the user (typically via Ctrl+C), // by checking for ErrInputAborted and context cancellation. func IsAborted(err error) bool { @@ -41,28 +75,77 @@ func MapInputError(err error) error { return err } +func mapContextInputError(ctx context.Context) error { + if ctx == nil || ctx.Err() == nil { + return nil + } + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return context.DeadlineExceeded + } + return ErrInputAborted +} + +func getLineState(reader *bufio.Reader) *lineState { + if state, ok := lineStates.Load(reader); ok { + return state.(*lineState) + } + state := &lineState{} + actual, _ := lineStates.LoadOrStore(reader, state) + return actual.(*lineState) +} + +func getPasswordState(fd int) *passwordState { + if state, ok := passwordStates.Load(fd); ok { + return state.(*passwordState) + } + state := &passwordState{} + actual, _ := passwordStates.LoadOrStore(fd, state) + return actual.(*passwordState) +} + // ReadLineWithContext reads a single line and supports cancellation. On ctx cancellation // or stdin closure it returns ErrInputAborted. On ctx deadline it returns context.DeadlineExceeded. +// Cancellation stops waiting but does not interrupt an already-started reader.ReadString call; +// at most one in-flight read is kept per reader to avoid goroutine buildup across retries. func ReadLineWithContext(ctx context.Context, reader *bufio.Reader) (string, error) { if ctx == nil { ctx = context.Background() } - type result struct { - line string - err error + if err := mapContextInputError(ctx); err != nil { + return "", err + } + if reader == nil { + return "", errors.New("reader is nil") } - ch := make(chan result, 1) - go func() { - line, err := reader.ReadString('\n') - ch <- result{line: line, err: MapInputError(err)} - }() + state := getLineState(reader) + + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight == nil { + inflight := &lineInflight{done: make(chan lineResult, 1)} + state.inflight = inflight + go func() { + line, err := reader.ReadString('\n') + inflight.done <- lineResult{line: line, err: MapInputError(err)} + state.mu.Lock() + if state.inflight == inflight { + state.inflight = nil + } + state.mu.Unlock() + }() + } + inflight := state.inflight + select { case <-ctx.Done(): if errors.Is(ctx.Err(), context.DeadlineExceeded) { return "", context.DeadlineExceeded } return "", ErrInputAborted - case res := <-ch: + case res := <-inflight.done: + if state.inflight == inflight { + state.inflight = nil + } return res.line, res.err } } @@ -70,29 +153,47 @@ func ReadLineWithContext(ctx context.Context, reader *bufio.Reader) (string, err // ReadPasswordWithContext reads a password (no echo) and supports cancellation. On ctx // cancellation or stdin closure it returns ErrInputAborted. On ctx deadline it returns // context.DeadlineExceeded. +// Cancellation stops waiting but does not interrupt an already-started password read; +// at most one in-flight password read is kept per file descriptor to avoid goroutine buildup. func ReadPasswordWithContext(ctx context.Context, readPassword func(int) ([]byte, error), fd int) ([]byte, error) { if ctx == nil { ctx = context.Background() } + if err := mapContextInputError(ctx); err != nil { + return nil, err + } if readPassword == nil { return nil, errors.New("readPassword function is nil") } - type result struct { - b []byte - err error + state := getPasswordState(fd) + + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight == nil { + inflight := &passwordInflight{done: make(chan passwordResult, 1)} + state.inflight = inflight + go func() { + b, err := readPassword(fd) + inflight.done <- passwordResult{b: b, err: MapInputError(err)} + state.mu.Lock() + if state.inflight == inflight { + state.inflight = nil + } + state.mu.Unlock() + }() } - ch := make(chan result, 1) - go func() { - b, err := readPassword(fd) - ch <- result{b: b, err: MapInputError(err)} - }() + inflight := state.inflight + select { case <-ctx.Done(): if errors.Is(ctx.Err(), context.DeadlineExceeded) { return nil, context.DeadlineExceeded } return nil, ErrInputAborted - case res := <-ch: + case res := <-inflight.done: + if state.inflight == inflight { + state.inflight = nil + } return res.b, res.err } } diff --git a/internal/input/input_test.go b/internal/input/input_test.go index 4113024..16d4feb 100644 --- a/internal/input/input_test.go +++ b/internal/input/input_test.go @@ -7,10 +7,28 @@ import ( "io" "os" "strings" + "sync/atomic" "testing" "time" ) +type blockingLineReader struct { + release chan struct{} + payload string + calls atomic.Int32 +} + +func (r *blockingLineReader) Read(p []byte) (int, error) { + r.calls.Add(1) + <-r.release + if r.payload == "" { + return 0, io.EOF + } + n := copy(p, r.payload) + r.payload = r.payload[n:] + return n, nil +} + func TestMapInputError(t *testing.T) { if MapInputError(nil) != nil { t.Fatalf("expected nil") @@ -208,3 +226,89 @@ func TestReadPasswordWithContext_DeadlineReturnsDeadlineExceeded(t *testing.T) { t.Fatalf("err=%v; want %v", err, context.DeadlineExceeded) } } + +func TestReadLineWithContext_ReusesInflightReadAfterTimeout(t *testing.T) { + src := &blockingLineReader{ + release: make(chan struct{}), + payload: "hello\n", + } + reader := bufio.NewReader(src) + + ctx1, cancel1 := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel1() + _, err := ReadLineWithContext(ctx1, reader) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + ctx2, cancel2 := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel2() + _, err = ReadLineWithContext(ctx2, reader) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("second err=%v; want %v", err, context.DeadlineExceeded) + } + + if got := src.calls.Load(); got != 1 { + t.Fatalf("underlying Read calls=%d; want 1", got) + } + + close(src.release) + + line, err := ReadLineWithContext(context.Background(), reader) + if err != nil { + t.Fatalf("third ReadLineWithContext error: %v", err) + } + if line != "hello\n" { + t.Fatalf("line=%q; want %q", line, "hello\n") + } + if got := src.calls.Load(); got != 1 { + t.Fatalf("underlying Read calls after release=%d; want 1", got) + } +} + +func TestReadPasswordWithContext_ReusesInflightReadAfterTimeout(t *testing.T) { + unblock := make(chan struct{}) + var calls atomic.Int32 + readPassword := func(fd int) ([]byte, error) { + calls.Add(1) + <-unblock + return []byte("secret"), nil + } + + ctx1, cancel1 := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel1() + got, err := ReadPasswordWithContext(ctx1, readPassword, 42) + if got != nil { + t.Fatalf("expected nil bytes on first deadline") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + ctx2, cancel2 := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel2() + got, err = ReadPasswordWithContext(ctx2, readPassword, 42) + if got != nil { + t.Fatalf("expected nil bytes on second deadline") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("second err=%v; want %v", err, context.DeadlineExceeded) + } + + if gotCalls := calls.Load(); gotCalls != 1 { + t.Fatalf("readPassword calls=%d; want 1", gotCalls) + } + + close(unblock) + + got, err = ReadPasswordWithContext(context.Background(), readPassword, 42) + if err != nil { + t.Fatalf("third ReadPasswordWithContext error: %v", err) + } + if string(got) != "secret" { + t.Fatalf("got=%q; want %q", string(got), "secret") + } + if gotCalls := calls.Load(); gotCalls != 1 { + t.Fatalf("readPassword calls after release=%d; want 1", gotCalls) + } +} diff --git a/internal/safefs/safefs.go b/internal/safefs/safefs.go index 36b001a..d64cfa0 100644 --- a/internal/safefs/safefs.go +++ b/internal/safefs/safefs.go @@ -14,6 +14,7 @@ var ( osStat = os.Stat osReadDir = os.ReadDir syscallStatfs = syscall.Statfs + fsOpLimiter = newOperationLimiter(32) ) // ErrTimeout is a sentinel error used to classify filesystem operations that did not @@ -56,100 +57,100 @@ func effectiveTimeout(ctx context.Context, timeout time.Duration) time.Duration return timeout } -func Stat(ctx context.Context, path string, timeout time.Duration) (fs.FileInfo, error) { - if err := ctx.Err(); err != nil { - return nil, err +// operationLimiter bounds the number of in-flight filesystem goroutines whose +// callers may already have returned due to timeout/cancellation. +type operationLimiter struct { + slots chan struct{} +} + +func newOperationLimiter(capacity int) *operationLimiter { + if capacity < 1 { + capacity = 1 } - timeout = effectiveTimeout(ctx, timeout) - if timeout <= 0 { - return osStat(path) + return &operationLimiter{ + slots: make(chan struct{}, capacity), } +} - type result struct { - info fs.FileInfo - err error +func (l *operationLimiter) acquire(ctx context.Context, timer <-chan time.Time) error { + select { + case l.slots <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-timer: + return ErrTimeout } - ch := make(chan result, 1) - go func() { - info, err := osStat(path) - ch <- result{info: info, err: err} - }() - - timer := time.NewTimer(timeout) - defer timer.Stop() +} +func (l *operationLimiter) release() { select { - case r := <-ch: - return r.info, r.err - case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - return nil, &TimeoutError{Op: "stat", Path: path, Timeout: timeout} + case <-l.slots: + default: } } -func ReadDir(ctx context.Context, path string, timeout time.Duration) ([]os.DirEntry, error) { +func (l *operationLimiter) inflight() int { + return len(l.slots) +} + +func runLimited[T any](ctx context.Context, timeout time.Duration, timeoutErr *TimeoutError, run func() (T, error)) (T, error) { + var zero T if err := ctx.Err(); err != nil { - return nil, err + return zero, err } timeout = effectiveTimeout(ctx, timeout) if timeout <= 0 { - return osReadDir(path) - } - - type result struct { - entries []os.DirEntry - err error + return run() } - ch := make(chan result, 1) - go func() { - entries, err := osReadDir(path) - ch <- result{entries: entries, err: err} - }() timer := time.NewTimer(timeout) defer timer.Stop() - select { - case r := <-ch: - return r.entries, r.err - case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - return nil, &TimeoutError{Op: "readdir", Path: path, Timeout: timeout} - } -} - -func Statfs(ctx context.Context, path string, timeout time.Duration) (syscall.Statfs_t, error) { - if err := ctx.Err(); err != nil { - return syscall.Statfs_t{}, err - } - timeout = effectiveTimeout(ctx, timeout) - if timeout <= 0 { - var stat syscall.Statfs_t - return stat, syscallStatfs(path, &stat) + if err := fsOpLimiter.acquire(ctx, timer.C); err != nil { + if errors.Is(err, ErrTimeout) { + return zero, timeoutErr + } + return zero, err } type result struct { - stat syscall.Statfs_t - err error + value T + err error } ch := make(chan result, 1) go func() { - var stat syscall.Statfs_t - err := syscallStatfs(path, &stat) - ch <- result{stat: stat, err: err} + defer fsOpLimiter.release() + value, err := run() + ch <- result{value: value, err: err} }() - timer := time.NewTimer(timeout) - defer timer.Stop() - select { case r := <-ch: - return r.stat, r.err + return r.value, r.err case <-ctx.Done(): - return syscall.Statfs_t{}, ctx.Err() + return zero, ctx.Err() case <-timer.C: - return syscall.Statfs_t{}, &TimeoutError{Op: "statfs", Path: path, Timeout: timeout} + return zero, timeoutErr } } + +func Stat(ctx context.Context, path string, timeout time.Duration) (fs.FileInfo, error) { + return runLimited(ctx, timeout, &TimeoutError{Op: "stat", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() (fs.FileInfo, error) { + return osStat(path) + }) +} + +func ReadDir(ctx context.Context, path string, timeout time.Duration) ([]os.DirEntry, error) { + return runLimited(ctx, timeout, &TimeoutError{Op: "readdir", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() ([]os.DirEntry, error) { + return osReadDir(path) + }) +} + +func Statfs(ctx context.Context, path string, timeout time.Duration) (syscall.Statfs_t, error) { + return runLimited(ctx, timeout, &TimeoutError{Op: "statfs", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() (syscall.Statfs_t, error) { + var stat syscall.Statfs_t + err := syscallStatfs(path, &stat) + return stat, err + }) +} diff --git a/internal/safefs/safefs_test.go b/internal/safefs/safefs_test.go index 30646ae..19ce1e7 100644 --- a/internal/safefs/safefs_test.go +++ b/internal/safefs/safefs_test.go @@ -4,17 +4,35 @@ import ( "context" "errors" "os" + "sync/atomic" "syscall" "testing" "time" ) +func waitForSignal(t *testing.T, ch <-chan struct{}, name string) { + t.Helper() + select { + case <-ch: + case <-time.After(500 * time.Millisecond): + t.Fatalf("timeout waiting for %s", name) + } +} + func TestStat_ReturnsTimeoutError(t *testing.T) { prev := osStat - defer func() { osStat = prev }() + unblock := make(chan struct{}) + finished := make(chan struct{}) + defer func() { + close(unblock) + waitForSignal(t, finished, "stat completion") + osStat = prev + }() osStat = func(string) (os.FileInfo, error) { - select {} + <-unblock + close(finished) + return nil, os.ErrNotExist } start := time.Now() @@ -29,10 +47,18 @@ func TestStat_ReturnsTimeoutError(t *testing.T) { func TestReadDir_ReturnsTimeoutError(t *testing.T) { prev := osReadDir - defer func() { osReadDir = prev }() + unblock := make(chan struct{}) + finished := make(chan struct{}) + defer func() { + close(unblock) + waitForSignal(t, finished, "readdir completion") + osReadDir = prev + }() osReadDir = func(string) ([]os.DirEntry, error) { - select {} + <-unblock + close(finished) + return nil, nil } start := time.Now() @@ -47,10 +73,18 @@ func TestReadDir_ReturnsTimeoutError(t *testing.T) { func TestStatfs_ReturnsTimeoutError(t *testing.T) { prev := syscallStatfs - defer func() { syscallStatfs = prev }() + unblock := make(chan struct{}) + finished := make(chan struct{}) + defer func() { + close(unblock) + waitForSignal(t, finished, "statfs completion") + syscallStatfs = prev + }() syscallStatfs = func(string, *syscall.Statfs_t) error { - select {} + <-unblock + close(finished) + return nil } start := time.Now() @@ -72,3 +106,46 @@ func TestStat_PropagatesContextCancellation(t *testing.T) { t.Fatalf("Stat err = %v; want context.Canceled", err) } } + +func TestStat_DoesNotSpawnPastLimiterCapacity(t *testing.T) { + prevStat := osStat + prevLimiter := fsOpLimiter + defer func() { + osStat = prevStat + fsOpLimiter = prevLimiter + }() + + fsOpLimiter = newOperationLimiter(1) + + unblock := make(chan struct{}) + finished := make(chan struct{}) + var calls atomic.Int32 + osStat = func(string) (os.FileInfo, error) { + calls.Add(1) + <-unblock + close(finished) + return nil, os.ErrNotExist + } + + _, err := Stat(context.Background(), "/first", 25*time.Millisecond) + if err == nil || !errors.Is(err, ErrTimeout) { + t.Fatalf("first Stat err = %v; want timeout", err) + } + if got := calls.Load(); got != 1 { + t.Fatalf("calls after first timeout = %d; want 1", got) + } + if got := fsOpLimiter.inflight(); got != 1 { + t.Fatalf("inflight after first timeout = %d; want 1", got) + } + + _, err = Stat(context.Background(), "/second", 25*time.Millisecond) + if err == nil || !errors.Is(err, ErrTimeout) { + t.Fatalf("second Stat err = %v; want timeout", err) + } + if got := calls.Load(); got != 1 { + t.Fatalf("calls after limiter saturation = %d; want 1", got) + } + + close(unblock) + waitForSignal(t, finished, "limited stat completion") +} From 74b51239adb340bdc2c4ffb730c376ced5519232 Mon Sep 17 00:00:00 2001 From: tis24dev Date: Mon, 9 Mar 2026 23:39:37 +0100 Subject: [PATCH 13/18] fix(orchestrator): preserve operational restore errors during path validation Teach the restore path resolver to distinguish security violations from local operational failures, so only real root-escape conditions are reported as illegal paths. Keep broken-symlink and traversal hardening intact, allow in-root permission and ENOTDIR failures to surface through the normal restore flow, and add regression tests covering the new security vs operational error classification. --- internal/orchestrator/path_security.go | 71 +++++++++++++++++++-- internal/orchestrator/path_security_test.go | 42 ++++++++++++ internal/orchestrator/restore.go | 7 +- internal/orchestrator/restore_test.go | 25 ++++++++ 4 files changed, 137 insertions(+), 8 deletions(-) diff --git a/internal/orchestrator/path_security.go b/internal/orchestrator/path_security.go index 938bc50..53bd020 100644 --- a/internal/orchestrator/path_security.go +++ b/internal/orchestrator/path_security.go @@ -1,6 +1,7 @@ package orchestrator import ( + "errors" "fmt" "os" "path/filepath" @@ -9,6 +10,62 @@ import ( const maxPathSecuritySymlinkHops = 40 +type pathResolutionErrorKind string + +const ( + pathResolutionErrorSecurity pathResolutionErrorKind = "security" + pathResolutionErrorOperational pathResolutionErrorKind = "operational" +) + +type pathResolutionError struct { + kind pathResolutionErrorKind + msg string + cause error +} + +func (e *pathResolutionError) Error() string { + if e == nil { + return "" + } + if e.cause != nil { + return fmt.Sprintf("%s: %v", e.msg, e.cause) + } + return e.msg +} + +func (e *pathResolutionError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + +func newPathResolutionError(kind pathResolutionErrorKind, cause error, format string, args ...any) error { + return &pathResolutionError{ + kind: kind, + msg: fmt.Sprintf(format, args...), + cause: cause, + } +} + +func newPathSecurityError(format string, args ...any) error { + return newPathResolutionError(pathResolutionErrorSecurity, nil, format, args...) +} + +func wrapPathOperationalError(cause error, format string, args ...any) error { + return newPathResolutionError(pathResolutionErrorOperational, cause, format, args...) +} + +func isPathSecurityError(err error) bool { + var target *pathResolutionError + return errors.As(err, &target) && target.kind == pathResolutionErrorSecurity +} + +func isPathOperationalError(err error) bool { + var target *pathResolutionError + return errors.As(err, &target) && target.kind == pathResolutionErrorOperational +} + func resolvePathWithinRootFS(fsys FS, destRoot, candidate string) (string, error) { lexicalRoot, canonicalRoot, err := prepareRootPathsFS(fsys, destRoot) if err != nil { @@ -93,7 +150,7 @@ func normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, candidateAbs st return filepath.Clean(filepath.Join(canonicalRoot, rel)), nil } - return "", fmt.Errorf("resolved path escapes destination: %s", candidateAbs) + return "", newPathSecurityError("resolved path escapes destination: %s", candidateAbs) } func resolvePathFromFilesystemRootFS(fsys FS, candidateAbs string, allowMissingTail bool, hopsRemaining int) (string, error) { @@ -123,7 +180,7 @@ func resolvePathWithinPreparedRootFS(fsys FS, lexicalRoot, canonicalRoot, candid return "", fmt.Errorf("cannot compute relative path: %w", err) } if !ok { - return "", fmt.Errorf("resolved path escapes destination: %s", candidateAbs) + return "", newPathSecurityError("resolved path escapes destination: %s", candidateAbs) } if rel == "." { return canonicalRoot, nil @@ -138,16 +195,16 @@ func resolvePathWithinPreparedRootFS(fsys FS, lexicalRoot, canonicalRoot, candid if allowMissingTail && os.IsNotExist(err) { return filepath.Clean(filepath.Join(current, filepath.Join(parts[idx:]...))), nil } - return "", fmt.Errorf("lstat %s: %w", next, err) + return "", wrapPathOperationalError(err, "lstat %s", next) } if info.Mode()&os.ModeSymlink != 0 { if hopsRemaining <= 0 { - return "", fmt.Errorf("too many symlink resolutions for %s", candidateAbs) + return "", newPathSecurityError("too many symlink resolutions for %s", candidateAbs) } target, err := fsys.Readlink(next) if err != nil { - return "", fmt.Errorf("readlink %s: %w", next, err) + return "", wrapPathOperationalError(err, "readlink %s", next) } var resolvedLink string @@ -164,7 +221,7 @@ func resolvePathWithinPreparedRootFS(fsys FS, lexicalRoot, canonicalRoot, candid if _, ok, err := relativePathWithinRoot(canonicalRoot, resolvedLink); err != nil { return "", fmt.Errorf("cannot compute relative path: %w", err) } else if !ok { - return "", fmt.Errorf("resolved path escapes destination: %s", resolvedLink) + return "", newPathSecurityError("resolved path escapes destination: %s", resolvedLink) } remainder := filepath.Join(parts[idx+1:]...) @@ -176,7 +233,7 @@ func resolvePathWithinPreparedRootFS(fsys FS, lexicalRoot, canonicalRoot, candid } if !info.IsDir() && idx < len(parts)-1 { - return "", fmt.Errorf("path component is not a directory: %s", next) + return "", newPathResolutionError(pathResolutionErrorOperational, nil, "path component is not a directory: %s", next) } current = next diff --git a/internal/orchestrator/path_security_test.go b/internal/orchestrator/path_security_test.go index c211000..a98873e 100644 --- a/internal/orchestrator/path_security_test.go +++ b/internal/orchestrator/path_security_test.go @@ -52,6 +52,48 @@ func TestResolvePathWithinRootFS_RejectsSymlinkLoop(t *testing.T) { } } +func TestResolvePathWithinRootFS_ClassifiesPathComponentNotDirectoryAsOperational(t *testing.T) { + root := t.TempDir() + blocker := filepath.Join(root, "deep") + if err := os.WriteFile(blocker, []byte("block"), 0o644); err != nil { + t.Fatalf("write blocker: %v", err) + } + + _, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("deep", "nested", "file.txt")) + if err == nil { + t.Fatal("expected error for non-directory path component") + } + if !isPathOperationalError(err) { + t.Fatalf("expected operational error classification, got %v", err) + } + if isPathSecurityError(err) { + t.Fatalf("expected non-security error classification, got %v", err) + } +} + +func TestResolvePathWithinRootFS_ClassifiesPermissionDeniedAsOperational(t *testing.T) { + fsys := NewFakeFS() + root := filepath.Join(string(os.PathSeparator), "restore-root") + if err := fsys.AddDir(root); err != nil { + t.Fatalf("add root dir: %v", err) + } + if err := fsys.AddDir(filepath.Join(root, "subdir")); err != nil { + t.Fatalf("add subdir: %v", err) + } + fsys.StatErrors[filepath.Clean(filepath.Join(root, "subdir", "file.txt"))] = os.ErrPermission + + _, err := resolvePathWithinRootFS(fsys, root, filepath.Join("subdir", "file.txt")) + if err == nil { + t.Fatal("expected permission error") + } + if !isPathOperationalError(err) { + t.Fatalf("expected operational error classification, got %v", err) + } + if isPathSecurityError(err) { + t.Fatalf("expected non-security error classification, got %v", err) + } +} + func TestResolvePathWithinRootFS_AllowsAbsoluteSymlinkTargetViaLexicalRoot(t *testing.T) { parent := t.TempDir() realRoot := filepath.Join(parent, "real-root") diff --git a/internal/orchestrator/restore.go b/internal/orchestrator/restore.go index 00a7a14..a390360 100644 --- a/internal/orchestrator/restore.go +++ b/internal/orchestrator/restore.go @@ -1495,7 +1495,12 @@ func sanitizeRestoreEntryTargetWithFS(fsys FS, destRoot, entryName string) (stri } if _, err := resolvePathWithinRootFS(fsys, absDestRoot, absTarget); err != nil { - return "", "", fmt.Errorf("illegal path: %s: %w", entryName, err) + if isPathSecurityError(err) { + return "", "", fmt.Errorf("illegal path: %s: %w", entryName, err) + } + if !isPathOperationalError(err) { + return "", "", fmt.Errorf("resolve extraction target: %w", err) + } } return absTarget, absDestRoot, nil diff --git a/internal/orchestrator/restore_test.go b/internal/orchestrator/restore_test.go index 9e80594..b5c6ae5 100644 --- a/internal/orchestrator/restore_test.go +++ b/internal/orchestrator/restore_test.go @@ -363,6 +363,31 @@ func TestExtractTarEntry_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { } } +func TestSanitizeRestoreEntryTargetWithFS_AllowsOperationalResolverErrorsWithinRoot(t *testing.T) { + fsys := NewFakeFS() + destRoot := filepath.Join(string(os.PathSeparator), "restore-root") + if err := fsys.AddDir(destRoot); err != nil { + t.Fatalf("add root dir: %v", err) + } + if err := fsys.AddDir(filepath.Join(destRoot, "subdir")); err != nil { + t.Fatalf("add subdir: %v", err) + } + fsys.StatErrors[filepath.Clean(filepath.Join(destRoot, "subdir", "file.txt"))] = os.ErrPermission + + target, cleanRoot, err := sanitizeRestoreEntryTargetWithFS(fsys, destRoot, filepath.Join("subdir", "file.txt")) + if err != nil { + t.Fatalf("sanitizeRestoreEntryTargetWithFS returned error: %v", err) + } + + wantTarget := filepath.Join(destRoot, "subdir", "file.txt") + if target != wantTarget { + t.Fatalf("target = %q, want %q", target, wantTarget) + } + if cleanRoot != filepath.Clean(destRoot) { + t.Fatalf("cleanRoot = %q, want %q", cleanRoot, filepath.Clean(destRoot)) + } +} + func TestExtractSymlink_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { logger := logging.New(types.LogLevelDebug, false) orig := restoreFS From 78dd8cb73b14372a06248b987308473e0be18ecd Mon Sep 17 00:00:00 2001 From: tis24dev Date: Tue, 10 Mar 2026 16:09:33 +0100 Subject: [PATCH 14/18] Separate PBS override scan roots from real datastore identities. - add explicit PBS datastore metadata for source, CLI identity and output key - derive stable path-based output keys for PBS_DATASTORE_PATH overrides - keep existing output names for real CLI-discovered PBS datastores - skip datastore show/status CLI calls for override-only entries - use filesystem-only namespace discovery for override paths - keep PXAR outputs isolated for colliding override basenames - preserve distinct override entries in PBS datastore inventory - exclude override-only inventory entries from datastore.cfg fallback restore - add regression coverage for basename collisions and restore safety - update PBS_DATASTORE_PATH documentation to reflect scan-root semantics --- docs/CONFIGURATION.md | 2 + docs/RESTORE_GUIDE.md | 3 +- internal/backup/collector_pbs.go | 11 +- .../collector_pbs_commands_coverage_test.go | 116 ++++++++++++++ internal/backup/collector_pbs_datastore.go | 147 +++++++++++++++--- .../collector_pbs_datastore_inventory.go | 135 +++++++++++++--- .../collector_pbs_datastore_inventory_test.go | 54 +++++++ internal/backup/collector_pbs_test.go | 133 ++++++++++++++++ internal/config/templates/backup.env | 2 +- internal/orchestrator/pbs_staged_apply.go | 9 ++ .../pbs_staged_apply_additional_test.go | 34 ++++ internal/pbs/namespaces.go | 6 + internal/pbs/namespaces_test.go | 15 ++ 13 files changed, 621 insertions(+), 46 deletions(-) diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index a3ef784..e0f3d82 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -1053,6 +1053,8 @@ VZDUMP_CONFIG_PATH=/etc/vzdump.conf # PBS datastore paths (comma/space separated) PBS_DATASTORE_PATH= # e.g., "/mnt/pbs1,/mnt/pbs2" +# Extra filesystem scan roots for datastore/PXAR discovery; these do not create +# real PBS datastore definitions and may use path-derived output keys. # System root override (testing/chroot) SYSTEM_ROOT_PREFIX= # Optional alternate root for system collection. Empty or "/" = real root. diff --git a/docs/RESTORE_GUIDE.md b/docs/RESTORE_GUIDE.md index 4c04c4c..53312fb 100644 --- a/docs/RESTORE_GUIDE.md +++ b/docs/RESTORE_GUIDE.md @@ -117,7 +117,7 @@ API apply is automatic for supported PBS staged categories; ProxSave may fall ba |----------|------|-------------|-------| | `pbs_config` | PBS Config Export | **Export-only** copy of /etc/proxmox-backup (never written to system) | `./etc/proxmox-backup/` | | `pbs_host` | PBS Host & Integrations | **Staged** node settings, ACME, proxy, metric servers and traffic control (API/file apply) | `./etc/proxmox-backup/node.cfg`
`./etc/proxmox-backup/proxy.cfg`
`./etc/proxmox-backup/acme/accounts.cfg`
`./etc/proxmox-backup/acme/plugins.cfg`
`./etc/proxmox-backup/metricserver.cfg`
`./etc/proxmox-backup/traffic-control.cfg`
`./var/lib/proxsave-info/commands/pbs/node_config.json`
`./var/lib/proxsave-info/commands/pbs/acme_accounts.json`
`./var/lib/proxsave-info/commands/pbs/acme_plugins.json`
`./var/lib/proxsave-info/commands/pbs/acme_account_*_info.json`
`./var/lib/proxsave-info/commands/pbs/acme_plugin_*_config.json`
`./var/lib/proxsave-info/commands/pbs/traffic_control.json` | -| `datastore_pbs` | PBS Datastore Configuration | **Staged** datastore definitions (incl. S3 endpoints) (API/file apply) | `./etc/proxmox-backup/datastore.cfg`
`./etc/proxmox-backup/s3.cfg`
`./var/lib/proxsave-info/commands/pbs/datastore_list.json`
`./var/lib/proxsave-info/commands/pbs/datastore_*_status.json`
`./var/lib/proxsave-info/commands/pbs/s3_endpoints.json`
`./var/lib/proxsave-info/commands/pbs/s3_endpoint_*_buckets.json`
`./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json` | +| `datastore_pbs` | PBS Datastore Configuration | **Staged** datastore definitions (incl. S3 endpoints) (API/file apply) | `./etc/proxmox-backup/datastore.cfg`
`./etc/proxmox-backup/s3.cfg`
`./var/lib/proxsave-info/commands/pbs/datastore_list.json`
`./var/lib/proxsave-info/commands/pbs/datastore_*_status.json`
`./var/lib/proxsave-info/commands/pbs/s3_endpoints.json`
`./var/lib/proxsave-info/commands/pbs/s3_endpoint_*_buckets.json`
`./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json`
Note: `PBS_DATASTORE_PATH` override scan roots are inventory context only and are not recreated as datastore definitions during restore. | | `maintenance_pbs` | PBS Maintenance | Maintenance settings | `./etc/proxmox-backup/maintenance.cfg` | | `pbs_jobs` | PBS Jobs | **Staged** sync/verify/prune jobs (API/file apply) | `./etc/proxmox-backup/sync.cfg`
`./etc/proxmox-backup/verification.cfg`
`./etc/proxmox-backup/prune.cfg`
`./var/lib/proxsave-info/commands/pbs/sync_jobs.json`
`./var/lib/proxsave-info/commands/pbs/verification_jobs.json`
`./var/lib/proxsave-info/commands/pbs/prune_jobs.json`
`./var/lib/proxsave-info/commands/pbs/gc_jobs.json` | | `pbs_remotes` | PBS Remotes | **Staged** remotes for sync/verify (may include credentials) (API/file apply) | `./etc/proxmox-backup/remote.cfg`
`./var/lib/proxsave-info/commands/pbs/remote_list.json` | @@ -2384,6 +2384,7 @@ systemctl restart proxmox-backup proxmox-backup-proxy **Restore behavior**: - ProxSave detects this condition during staged apply. - If `var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json` is available in the backup, ProxSave will use its embedded snapshot of the original `datastore.cfg` to recover a valid configuration. +- Inventory entries that came only from `PBS_DATASTORE_PATH` scan roots are treated as diagnostic context and are excluded from regenerated `datastore.cfg`. - If recovery is not possible, ProxSave will **leave the existing** `/etc/proxmox-backup/datastore.cfg` unchanged to avoid breaking PBS. **Manual diagnosis**: diff --git a/internal/backup/collector_pbs.go b/internal/backup/collector_pbs.go index dc053b3..6be89e7 100644 --- a/internal/backup/collector_pbs.go +++ b/internal/backup/collector_pbs.go @@ -383,9 +383,18 @@ func (c *Collector) collectPBSCommands(ctx context.Context, datastores []pbsData // Datastore usage details if c.config.BackupDatastoreConfigs && len(datastores) > 0 { for _, ds := range datastores { + if ds.isOverride() { + c.logger.Debug("Skipping datastore status for %s (path=%s): no PBS datastore identity", ds.Name, ds.Path) + continue + } + cliName := ds.cliName() + if cliName == "" { + c.logger.Debug("Skipping datastore status for %s (path=%s): empty PBS datastore identity", ds.Name, ds.Path) + continue + } dsKey := ds.pathKey() c.safeCmdOutput(ctx, - fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", ds.Name), + fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", cliName), filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", dsKey)), fmt.Sprintf("Datastore %s status", ds.Name), false) diff --git a/internal/backup/collector_pbs_commands_coverage_test.go b/internal/backup/collector_pbs_commands_coverage_test.go index b96c44f..65c37a1 100644 --- a/internal/backup/collector_pbs_commands_coverage_test.go +++ b/internal/backup/collector_pbs_commands_coverage_test.go @@ -92,6 +92,42 @@ func TestCollectPBSCommandsWritesExpectedOutputs(t *testing.T) { } } +func TestCollectPBSCommandsSkipsStatusForOverrideOnlyEntries(t *testing.T) { + pbsRoot := t.TempDir() + if err := os.WriteFile(filepath.Join(pbsRoot, "tape.cfg"), []byte("ok"), 0o640); err != nil { + t.Fatalf("write tape.cfg: %v", err) + } + + cfg := GetDefaultCollectorConfig() + cfg.PBSConfigPath = pbsRoot + + collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{ + LookPath: func(name string) (string, error) { + return "/bin/" + name, nil + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + return []byte(fmt.Sprintf("%s %s", name, strings.Join(args, " "))), nil + }, + }) + + override := pbsDatastore{ + Name: "backup", + Path: "/mnt/a/backup", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath("/mnt/a/backup"), + OutputKey: buildPBSOverrideOutputKey("/mnt/a/backup"), + } + if err := collector.collectPBSCommands(context.Background(), []pbsDatastore{override}); err != nil { + t.Fatalf("collectPBSCommands error: %v", err) + } + + commandsDir := filepath.Join(collector.tempDir, "var/lib/proxsave-info", "commands", "pbs") + statusPath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", override.pathKey())) + if _, err := os.Stat(statusPath); !os.IsNotExist(err) { + t.Fatalf("override datastore status file should not exist (%s), got err=%v", statusPath, err) + } +} + func TestCollectPBSCommandsReturnsErrorWhenCriticalVersionFails(t *testing.T) { cfg := GetDefaultCollectorConfig() cfg.PBSConfigPath = t.TempDir() @@ -189,6 +225,50 @@ func TestCollectPBSPxarMetadataProcessesMultipleDatastores(t *testing.T) { } } +func TestCollectPBSPxarMetadataSeparatesOverrideBasenameCollisions(t *testing.T) { + tmp := t.TempDir() + cfg := GetDefaultCollectorConfig() + cfg.PxarDatastoreConcurrency = 2 + + collector := NewCollector(newTestLogger(), cfg, tmp, types.ProxmoxBS, false) + + makeOverride := func(path string) pbsDatastore { + for _, sub := range []string{"vm", "ct"} { + if err := os.MkdirAll(filepath.Join(path, sub), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", sub, err) + } + } + if err := os.WriteFile(filepath.Join(path, "vm", "backup.pxar"), []byte("data"), 0o640); err != nil { + t.Fatalf("write pxar: %v", err) + } + return pbsDatastore{ + Name: "backup", + Path: path, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(path), + OutputKey: buildPBSOverrideOutputKey(path), + } + } + + ds1 := makeOverride(filepath.Join(tmp, "mnt", "a", "backup")) + ds2 := makeOverride(filepath.Join(tmp, "srv", "b", "backup")) + if ds1.pathKey() == ds2.pathKey() { + t.Fatalf("expected distinct path keys, got %q", ds1.pathKey()) + } + + if err := collector.collectPBSPxarMetadata(context.Background(), []pbsDatastore{ds1, ds2}); err != nil { + t.Fatalf("collectPBSPxarMetadata error: %v", err) + } + + for _, ds := range []pbsDatastore{ds1, ds2} { + base := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "metadata", ds.pathKey()) + if _, err := os.Stat(filepath.Join(base, "metadata.json")); err != nil { + t.Fatalf("expected metadata for %s: %v", ds.pathKey(), err) + } + } +} + func TestCollectPBSPxarMetadataReturnsErrorWhenTempVarIsFile(t *testing.T) { tmp := t.TempDir() if err := os.WriteFile(filepath.Join(tmp, "var"), []byte("not-a-dir"), 0o640); err != nil { @@ -234,6 +314,42 @@ func TestCollectDatastoreConfigsCreatesConfigAndNamespaceFiles(t *testing.T) { } } +func TestCollectDatastoreConfigsCreatesDistinctNamespaceFilesForOverrideCollisions(t *testing.T) { + cfg := GetDefaultCollectorConfig() + tmp := t.TempDir() + collector := NewCollectorWithDeps(newTestLogger(), cfg, tmp, types.ProxmoxBS, false, CollectorDeps{}) + + makeOverride := func(path string) pbsDatastore { + if err := os.MkdirAll(filepath.Join(path, "local", "vm"), 0o755); err != nil { + t.Fatalf("mkdir override namespace fixture: %v", err) + } + return pbsDatastore{ + Name: "backup", + Path: path, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(path), + OutputKey: buildPBSOverrideOutputKey(path), + } + } + + ds1 := makeOverride(filepath.Join(tmp, "mnt", "a", "backup")) + ds2 := makeOverride(filepath.Join(tmp, "srv", "b", "backup")) + if err := collector.collectDatastoreConfigs(context.Background(), []pbsDatastore{ds1, ds2}); err != nil { + t.Fatalf("collectDatastoreConfigs error: %v", err) + } + + datastoreDir := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "datastores") + for _, ds := range []pbsDatastore{ds1, ds2} { + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey()))); err != nil { + t.Fatalf("expected namespaces file for %s: %v", ds.pathKey(), err) + } + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", ds.pathKey()))); !os.IsNotExist(err) { + t.Fatalf("override config file should not exist for %s, got err=%v", ds.pathKey(), err) + } + } +} + func TestCollectUserTokensSkipsInvalidUserListJSON(t *testing.T) { tmp := t.TempDir() collector := NewCollector(newTestLogger(), GetDefaultCollectorConfig(), tmp, types.ProxmoxBS, false) diff --git a/internal/backup/collector_pbs_datastore.go b/internal/backup/collector_pbs_datastore.go index 9ba87b1..62199a8 100644 --- a/internal/backup/collector_pbs_datastore.go +++ b/internal/backup/collector_pbs_datastore.go @@ -2,6 +2,8 @@ package backup import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -16,17 +18,91 @@ import ( "github.com/tis24dev/proxsave/internal/safefs" ) +const ( + pbsDatastoreSourceCLI = "cli" + pbsDatastoreSourceOverride = "override" + pbsDatastoreSourceConfig = "config" + pbsDatastoreOriginMerged = "merged" +) + +var ( + pbsDatastoreNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + listNamespacesFunc = pbs.ListNamespaces + discoverNamespacesFunc = pbs.DiscoverNamespacesFromFilesystem +) + type pbsDatastore struct { - Name string - Path string - Comment string + Name string + Path string + Comment string + Source string + CLIName string + NormalizedPath string + OutputKey string +} + +func normalizePBSDatastorePath(path string) string { + trimmed := strings.TrimSpace(path) + if trimmed == "" { + return "" + } + return filepath.Clean(trimmed) +} + +func buildPBSOverrideDisplayName(path string, idx int) string { + name := filepath.Base(normalizePBSDatastorePath(path)) + if name == "" || name == "." || name == string(os.PathSeparator) || !pbsDatastoreNamePattern.MatchString(name) { + return fmt.Sprintf("datastore_%d", idx+1) + } + return name +} + +func buildPBSOverrideOutputKey(path string) string { + normalized := normalizePBSDatastorePath(path) + if normalized == "" { + return "entry" + } + + label := filepath.Base(normalized) + if label == "" || label == "." || label == string(os.PathSeparator) || !pbsDatastoreNamePattern.MatchString(label) { + label = "datastore" + } + + sum := sha256.Sum256([]byte(normalized)) + return fmt.Sprintf("%s_%s", sanitizeFilename(label), hex.EncodeToString(sum[:4])) +} + +func (ds pbsDatastore) normalizedPath() string { + if path := strings.TrimSpace(ds.NormalizedPath); path != "" { + return path + } + return normalizePBSDatastorePath(ds.Path) } func (ds pbsDatastore) pathKey() string { + if key := strings.TrimSpace(ds.OutputKey); key != "" { + return key + } return collectorPathKey(ds.Name) } -var listNamespacesFunc = pbs.ListNamespaces +func (ds pbsDatastore) cliName() string { + if name := strings.TrimSpace(ds.CLIName); name != "" { + return name + } + return strings.TrimSpace(ds.Name) +} + +func (ds pbsDatastore) isOverride() bool { + return strings.TrimSpace(ds.Source) == pbsDatastoreSourceOverride +} + +func (ds pbsDatastore) inventoryOrigin() string { + if origin := strings.TrimSpace(ds.Source); origin != "" { + return origin + } + return pbsDatastoreSourceCLI +} // collectDatastoreConfigs collects detailed datastore configurations func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pbsDatastore) error { @@ -44,12 +120,16 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb for _, ds := range datastores { dsKey := ds.pathKey() - // Get datastore configuration details - c.safeCmdOutput(ctx, - fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", ds.Name), - filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", dsKey)), - fmt.Sprintf("Datastore %s configuration", ds.Name), - false) + if cliName := ds.cliName(); cliName != "" && !ds.isOverride() { + // Get datastore configuration details for CLI-backed datastores only. + c.safeCmdOutput(ctx, + fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", cliName), + filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", dsKey)), + fmt.Sprintf("Datastore %s configuration", ds.Name), + false) + } else { + c.logger.Debug("Skipping datastore CLI config for %s (path=%s): no PBS datastore identity", ds.Name, ds.Path) + } // Get namespace list using CLI/Filesystem fallback if err := c.collectDatastoreNamespaces(ctx, ds, datastoreDir); err != nil { @@ -77,7 +157,17 @@ func (c *Collector) collectDatastoreNamespaces(ctx context.Context, ds pbsDatast ioTimeout = time.Duration(c.config.FsIoTimeoutSeconds) * time.Second } - namespaces, fromFallback, err := listNamespacesFunc(ctx, ds.Name, ds.Path, ioTimeout) + var ( + namespaces []pbs.Namespace + fromFallback bool + err error + ) + if ds.isOverride() { + namespaces, err = discoverNamespacesFunc(ctx, ds.normalizedPath(), ioTimeout) + fromFallback = true + } else { + namespaces, fromFallback, err = listNamespacesFunc(ctx, ds.cliName(), ds.Path, ioTimeout) + } if err != nil { return err } @@ -458,10 +548,15 @@ func (c *Collector) getDatastoreList(ctx context.Context) ([]pbsDatastore, error for _, entry := range entries { name := strings.TrimSpace(entry.Name) if name != "" { + path := strings.TrimSpace(entry.Path) datastores = append(datastores, pbsDatastore{ - Name: name, - Path: strings.TrimSpace(entry.Path), - Comment: strings.TrimSpace(entry.Comment), + Name: name, + Path: path, + Comment: strings.TrimSpace(entry.Comment), + Source: pbsDatastoreSourceCLI, + CLIName: name, + NormalizedPath: normalizePBSDatastorePath(path), + OutputKey: collectorPathKey(name), }) } } @@ -469,27 +564,31 @@ func (c *Collector) getDatastoreList(ctx context.Context) ([]pbsDatastore, error if len(c.config.PBSDatastorePaths) > 0 { existing := make(map[string]struct{}, len(datastores)) for _, ds := range datastores { - if ds.Path != "" { - existing[ds.Path] = struct{}{} + if normalized := ds.normalizedPath(); normalized != "" { + existing[normalized] = struct{}{} } } - validName := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) for idx, override := range c.config.PBSDatastorePaths { override = strings.TrimSpace(override) if override == "" { continue } - if _, ok := existing[override]; ok { + normalized := normalizePBSDatastorePath(override) + if normalized == "" { continue } - name := filepath.Base(filepath.Clean(override)) - if name == "" || name == "." || name == string(os.PathSeparator) || !validName.MatchString(name) { - name = fmt.Sprintf("datastore_%d", idx+1) + if _, ok := existing[normalized]; ok { + continue } + existing[normalized] = struct{}{} + name := buildPBSOverrideDisplayName(normalized, idx) datastores = append(datastores, pbsDatastore{ - Name: name, - Path: override, - Comment: "configured via PBS_DATASTORE_PATH", + Name: name, + Path: override, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalized, + OutputKey: buildPBSOverrideOutputKey(normalized), }) } } diff --git a/internal/backup/collector_pbs_datastore_inventory.go b/internal/backup/collector_pbs_datastore_inventory.go index 72bbb02..7e334d0 100644 --- a/internal/backup/collector_pbs_datastore_inventory.go +++ b/internal/backup/collector_pbs_datastore_inventory.go @@ -66,6 +66,9 @@ type pbsDatastoreInventoryEntry struct { Path string `json:"path,omitempty"` Comment string `json:"comment,omitempty"` Sources []string `json:"sources,omitempty"` + Origin string `json:"origin,omitempty"` + CLIName string `json:"cli_name,omitempty"` + OutputKey string `json:"output_key,omitempty"` StatPath string `json:"stat_path,omitempty"` PathOK bool `json:"path_ok,omitempty"` PathIsDir bool `json:"path_is_dir,omitempty"` @@ -216,10 +219,13 @@ func (c *Collector) collectPBSDatastoreInventory(ctx context.Context, cliDatasto for _, def := range merged { entry := pbsDatastoreInventoryEntry{ - Name: def.Name, - Path: def.Path, - Comment: def.Comment, - Sources: append([]string(nil), def.Sources...), + Name: def.Name, + Path: def.Path, + Comment: def.Comment, + Sources: append([]string(nil), def.Sources...), + Origin: def.Origin, + CLIName: def.CLIName, + OutputKey: def.OutputKey, } statPath := def.Path @@ -447,34 +453,85 @@ func (c *Collector) captureInventoryCommand(ctx context.Context, pretty string, } type pbsDatastoreDefinition struct { - Name string - Path string - Comment string - Sources []string + Name string + Path string + Comment string + Sources []string + Origin string + CLIName string + OutputKey string } func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefinition { merged := make(map[string]*pbsDatastoreDefinition) - add := func(ds pbsDatastore, source string) { + defKey := func(ds pbsDatastore) string { + if ds.isOverride() { + if normalized := ds.normalizedPath(); normalized != "" { + return "override-path:" + normalized + } + if path := strings.TrimSpace(ds.Path); path != "" { + return "override-path:" + path + } + return "" + } + name := strings.TrimSpace(ds.Name) if name == "" { + return "" + } + return "name:" + name + } + + add := func(ds pbsDatastore, source string) { + key := defKey(ds) + if key == "" { return } - entry := merged[name] + name := strings.TrimSpace(ds.Name) + path := strings.TrimSpace(ds.Path) + comment := strings.TrimSpace(ds.Comment) + origin := ds.inventoryOrigin() + cliName := strings.TrimSpace(ds.cliName()) + entry := merged[key] if entry == nil { - entry = &pbsDatastoreDefinition{Name: name} - merged[name] = entry + entry = &pbsDatastoreDefinition{ + Name: name, + Origin: origin, + CLIName: cliName, + OutputKey: strings.TrimSpace(ds.pathKey()), + } + merged[key] = entry } entry.Sources = append(entry.Sources, source) - if entry.Path == "" && strings.TrimSpace(ds.Path) != "" { - entry.Path = strings.TrimSpace(ds.Path) + if entry.Name == "" && name != "" { + entry.Name = name + } + if entry.Path == "" && path != "" { + entry.Path = path + } + if entry.Comment == "" && comment != "" { + entry.Comment = comment + } + if entry.CLIName == "" && !ds.isOverride() && cliName != "" { + entry.CLIName = cliName } - if entry.Comment == "" && strings.TrimSpace(ds.Comment) != "" { - entry.Comment = strings.TrimSpace(ds.Comment) + if entry.OutputKey == "" { + entry.OutputKey = strings.TrimSpace(ds.pathKey()) + } + + switch { + case entry.Origin == "": + entry.Origin = origin + case entry.Origin == pbsDatastoreSourceOverride || origin == pbsDatastoreSourceOverride: + if entry.Origin == "" { + entry.Origin = pbsDatastoreSourceOverride + } + case entry.Origin != origin: + entry.Origin = pbsDatastoreOriginMerged } } @@ -482,7 +539,11 @@ func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefi add(ds, "datastore.cfg") } for _, ds := range cli { - add(ds, "cli") + source := "cli" + if ds.isOverride() { + source = "override" + } + add(ds, source) } out := make([]pbsDatastoreDefinition, 0, len(merged)) @@ -495,11 +556,38 @@ func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefi } sort.Slice(out, func(i, j int) bool { - return out[i].Name < out[j].Name + if out[i].Name != out[j].Name { + return out[i].Name < out[j].Name + } + if p1, p2 := pbsDatastoreOriginSortPriority(out[i].Origin), pbsDatastoreOriginSortPriority(out[j].Origin); p1 != p2 { + return p1 < p2 + } + if out[i].Path != out[j].Path { + return out[i].Path < out[j].Path + } + if out[i].OutputKey != out[j].OutputKey { + return out[i].OutputKey < out[j].OutputKey + } + return out[i].CLIName < out[j].CLIName }) return out } +func pbsDatastoreOriginSortPriority(origin string) int { + switch strings.TrimSpace(origin) { + case pbsDatastoreOriginMerged: + return 0 + case pbsDatastoreSourceConfig: + return 1 + case pbsDatastoreSourceCLI: + return 2 + case pbsDatastoreSourceOverride: + return 3 + default: + return 4 + } +} + func parsePBSDatastoreCfg(contents string) []pbsDatastore { contents = strings.TrimSpace(contents) if contents == "" { @@ -536,7 +624,12 @@ func parsePBSDatastoreCfg(contents string) []pbsDatastore { if name == "" { continue } - current = &pbsDatastore{Name: name} + current = &pbsDatastore{ + Name: name, + Source: pbsDatastoreSourceConfig, + CLIName: name, + OutputKey: collectorPathKey(name), + } continue } @@ -560,6 +653,10 @@ func parsePBSDatastoreCfg(contents string) []pbsDatastore { } flush() + for i := range out { + out[i].NormalizedPath = normalizePBSDatastorePath(out[i].Path) + } + return out } diff --git a/internal/backup/collector_pbs_datastore_inventory_test.go b/internal/backup/collector_pbs_datastore_inventory_test.go index ca11128..b2c2345 100644 --- a/internal/backup/collector_pbs_datastore_inventory_test.go +++ b/internal/backup/collector_pbs_datastore_inventory_test.go @@ -295,3 +295,57 @@ func TestCollectPBSDatastoreInventoryCapturesHostCommands(t *testing.T) { t.Fatalf("expected findmnt output to be captured") } } + +func TestMergePBSDatastoreDefinitionsKeepsOverridesSeparate(t *testing.T) { + config := []pbsDatastore{{ + Name: "backup", + Path: "/real/backup", + Comment: "primary", + Source: pbsDatastoreSourceConfig, + CLIName: "backup", + NormalizedPath: normalizePBSDatastorePath("/real/backup"), + OutputKey: collectorPathKey("backup"), + }} + cli := []pbsDatastore{ + { + Name: "backup", + Path: "/real/backup", + Comment: "runtime", + Source: pbsDatastoreSourceCLI, + CLIName: "backup", + NormalizedPath: normalizePBSDatastorePath("/real/backup"), + OutputKey: collectorPathKey("backup"), + }, + { + Name: "backup", + Path: "/mnt/a/backup", + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath("/mnt/a/backup"), + OutputKey: buildPBSOverrideOutputKey("/mnt/a/backup"), + }, + { + Name: "backup", + Path: "/srv/b/backup", + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath("/srv/b/backup"), + OutputKey: buildPBSOverrideOutputKey("/srv/b/backup"), + }, + } + + merged := mergePBSDatastoreDefinitions(cli, config) + if len(merged) != 3 { + t.Fatalf("expected 3 merged entries, got %d: %+v", len(merged), merged) + } + + if merged[0].Origin != pbsDatastoreOriginMerged || merged[0].Path != "/real/backup" { + t.Fatalf("expected real datastore entry first, got %+v", merged[0]) + } + if merged[1].Origin != pbsDatastoreSourceOverride || merged[2].Origin != pbsDatastoreSourceOverride { + t.Fatalf("expected override entries after real datastore, got %+v", merged) + } + if merged[1].OutputKey == merged[2].OutputKey { + t.Fatalf("override output keys should differ, got %+v", merged) + } +} diff --git a/internal/backup/collector_pbs_test.go b/internal/backup/collector_pbs_test.go index 3a96797..3abb34e 100644 --- a/internal/backup/collector_pbs_test.go +++ b/internal/backup/collector_pbs_test.go @@ -63,6 +63,47 @@ func TestGetDatastoreListSuccessWithOverrides(t *testing.T) { if datastores[2].Comment != "configured via PBS_DATASTORE_PATH" { t.Fatalf("expected override comment, got %q", datastores[2].Comment) } + if datastores[0].Source != pbsDatastoreSourceCLI || datastores[0].CLIName != "primary" || datastores[0].OutputKey != "primary" { + t.Fatalf("expected CLI datastore metadata, got %+v", datastores[0]) + } + if datastores[1].Source != pbsDatastoreSourceOverride || datastores[1].CLIName != "" || datastores[1].OutputKey == "" { + t.Fatalf("expected override datastore metadata, got %+v", datastores[1]) + } +} + +func TestGetDatastoreListOverrideCollisionsUseDistinctOutputKeys(t *testing.T) { + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + return []byte(`[{"name":"primary","path":"/data/primary/","comment":"main"}]`), nil + }, + }) + collector.config.PBSDatastorePaths = []string{ + "/mnt/a/backup", + "/srv/b/backup", + "/srv/b/backup/", + "/data/primary", + } + + datastores, err := collector.getDatastoreList(context.Background()) + if err != nil { + t.Fatalf("getDatastoreList failed: %v", err) + } + if len(datastores) != 3 { + t.Fatalf("expected 3 datastores after normalized dedupe, got %d: %+v", len(datastores), datastores) + } + + if datastores[1].Name != "backup" || datastores[2].Name != "backup" { + t.Fatalf("expected colliding override display names, got %+v", datastores) + } + if datastores[1].OutputKey == datastores[2].OutputKey { + t.Fatalf("override output keys should differ, got %q", datastores[1].OutputKey) + } + if datastores[1].NormalizedPath == datastores[2].NormalizedPath { + t.Fatalf("override normalized paths should differ, got %+v", datastores) + } } func TestGetDatastoreListContextCanceled(t *testing.T) { @@ -306,6 +347,50 @@ func TestCollectDatastoreNamespacesError(t *testing.T) { } } +func TestCollectDatastoreNamespacesOverrideUsesFilesystemOnly(t *testing.T) { + origList := listNamespacesFunc + t.Cleanup(func() { listNamespacesFunc = origList }) + listNamespacesFunc = func(context.Context, string, string, time.Duration) ([]pbs.Namespace, bool, error) { + t.Fatal("CLI namespace discovery should not be used for override paths") + return nil, false, nil + } + + collector := newTestCollectorWithDeps(t, CollectorDeps{}) + dsDir := filepath.Join(collector.tempDir, "datastores") + if err := os.MkdirAll(dsDir, 0o755); err != nil { + t.Fatalf("failed to create datastore dir: %v", err) + } + + dsPath := filepath.Join(collector.tempDir, "override") + if err := os.MkdirAll(filepath.Join(dsPath, "local", "vm"), 0o755); err != nil { + t.Fatalf("failed to create override namespace fixture: %v", err) + } + + ds := pbsDatastore{ + Name: "backup", + Path: dsPath, + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(dsPath), + OutputKey: buildPBSOverrideOutputKey(dsPath), + } + if err := collector.collectDatastoreNamespaces(context.Background(), ds, dsDir); err != nil { + t.Fatalf("collectDatastoreNamespaces failed: %v", err) + } + + data, err := os.ReadFile(filepath.Join(dsDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey()))) + if err != nil { + t.Fatalf("namespaces file not created: %v", err) + } + + var namespaces []pbs.Namespace + if err := json.Unmarshal(data, &namespaces); err != nil { + t.Fatalf("failed to decode namespaces: %v", err) + } + if len(namespaces) != 2 || namespaces[1].Ns != "local" { + t.Fatalf("unexpected override namespaces: %+v", namespaces) + } +} + func TestCollectDatastoreConfigsDryRun(t *testing.T) { stubListNamespaces(t, func(context.Context, string, string, time.Duration) ([]pbs.Namespace, bool, error) { return []pbs.Namespace{{Ns: ""}}, false, nil @@ -389,6 +474,54 @@ func TestCollectDatastoreConfigs_UsesPathSafeKeyForUnsafeDatastoreName(t *testin } } +func TestCollectDatastoreConfigsSkipsCLIConfigForOverridePaths(t *testing.T) { + origList := listNamespacesFunc + t.Cleanup(func() { listNamespacesFunc = origList }) + listNamespacesFunc = func(context.Context, string, string, time.Duration) ([]pbs.Namespace, bool, error) { + t.Fatal("CLI namespace discovery should not be used for override paths") + return nil, false, nil + } + + dsPath := t.TempDir() + if err := os.MkdirAll(filepath.Join(dsPath, "vm"), 0o755); err != nil { + t.Fatalf("mkdir vm: %v", err) + } + + var runCalls int + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) { + runCalls++ + return []byte(`{"ok":true}`), nil + }, + }) + + ds := pbsDatastore{ + Name: "backup", + Path: dsPath, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(dsPath), + OutputKey: buildPBSOverrideOutputKey(dsPath), + } + if err := collector.collectDatastoreConfigs(context.Background(), []pbsDatastore{ds}); err != nil { + t.Fatalf("collectDatastoreConfigs failed: %v", err) + } + + datastoreDir := filepath.Join(collector.tempDir, "var", "lib", "proxsave-info", "pbs", "datastores") + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey()))); err != nil { + t.Fatalf("expected override namespaces file, got %v", err) + } + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", ds.pathKey()))); !os.IsNotExist(err) { + t.Fatalf("override config file should not exist, got err=%v", err) + } + if runCalls != 0 { + t.Fatalf("expected no CLI datastore show calls for override, got %d", runCalls) + } +} + func TestCollectPBSPxarMetadata_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) { tmp := t.TempDir() cfg := GetDefaultCollectorConfig() diff --git a/internal/config/templates/backup.env b/internal/config/templates/backup.env index 2dbf810..ad5a9b0 100644 --- a/internal/config/templates/backup.env +++ b/internal/config/templates/backup.env @@ -318,7 +318,7 @@ PVE_CLUSTER_PATH=/var/lib/pve-cluster COROSYNC_CONFIG_PATH=${PVE_CONFIG_PATH}/corosync.conf VZDUMP_CONFIG_PATH=/etc/vzdump.conf PBS_CONFIG_PATH=/etc/proxmox-backup -PBS_DATASTORE_PATH= # Extra PBS paths separated by comma/space (e.g. /mnt/pbs1,/mnt/pbs2). Empty = auto-detect. +PBS_DATASTORE_PATH= # Extra PBS filesystem scan roots separated by comma/space (e.g. /mnt/pbs1,/mnt/pbs2). Not real datastore definitions; output keys may be path-derived. Empty = auto-detect. # System BACKUP_NETWORK_CONFIGS=true diff --git a/internal/orchestrator/pbs_staged_apply.go b/internal/orchestrator/pbs_staged_apply.go index 265910c..2bd91a1 100644 --- a/internal/orchestrator/pbs_staged_apply.go +++ b/internal/orchestrator/pbs_staged_apply.go @@ -359,6 +359,8 @@ type pbsDatastoreInventoryRestoreLite struct { Name string `json:"name"` Path string `json:"path"` Comment string `json:"comment"` + Origin string `json:"origin"` + CLIName string `json:"cli_name"` } `json:"datastores"` } @@ -387,7 +389,14 @@ func loadPBSDatastoreCfgFromInventory(stageRoot string) (string, string, error) // Fallback: generate a minimal datastore.cfg from the inventory's datastore list. var out strings.Builder for _, ds := range report.Datastores { + if strings.TrimSpace(ds.Origin) == "override" { + continue + } + name := strings.TrimSpace(ds.Name) + if name == "" { + name = strings.TrimSpace(ds.CLIName) + } path := strings.TrimSpace(ds.Path) if name == "" || path == "" { continue diff --git a/internal/orchestrator/pbs_staged_apply_additional_test.go b/internal/orchestrator/pbs_staged_apply_additional_test.go index 80cb253..5c0d1d2 100644 --- a/internal/orchestrator/pbs_staged_apply_additional_test.go +++ b/internal/orchestrator/pbs_staged_apply_additional_test.go @@ -280,6 +280,40 @@ func TestLoadPBSDatastoreCfgFromInventory_FallsBackToDatastoreList(t *testing.T) } } +func TestLoadPBSDatastoreCfgFromInventory_IgnoresOverrideEntries(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + inventory := `{"datastores":[{"name":"DS1","path":"/mnt/ds1","comment":"primary","origin":"merged"},{"name":"DS1","path":"/mnt/scan-root","comment":"configured via PBS_DATASTORE_PATH","origin":"override"}]}` + if err := fakeFS.WriteFile(stageRoot+"/var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json", []byte(inventory), 0o640); err != nil { + t.Fatalf("write inventory: %v", err) + } + + content, src, err := loadPBSDatastoreCfgFromInventory(stageRoot) + if err != nil { + t.Fatalf("loadPBSDatastoreCfgFromInventory: %v", err) + } + if src != "pbs_datastore_inventory.json.datastores" { + t.Fatalf("src=%q", src) + } + if strings.Contains(content, "/mnt/scan-root") { + t.Fatalf("override entry should not be present in generated datastore.cfg: %s", content) + } + + blocks, err := parsePBSDatastoreCfgBlocks(content) + if err != nil { + t.Fatalf("parsePBSDatastoreCfgBlocks: %v", err) + } + if len(blocks) != 1 || blocks[0].Name != "DS1" || blocks[0].Path != "/mnt/ds1" { + t.Fatalf("unexpected blocks: %+v", blocks) + } +} + func TestLoadPBSDatastoreCfgFromInventory_PropagatesErrors(t *testing.T) { origFS := restoreFS t.Cleanup(func() { restoreFS = origFS }) diff --git a/internal/pbs/namespaces.go b/internal/pbs/namespaces.go index 345ac56..d168bef 100644 --- a/internal/pbs/namespaces.go +++ b/internal/pbs/namespaces.go @@ -46,6 +46,12 @@ func ListNamespaces(ctx context.Context, datastoreName, datastorePath string, io return namespaces, true, nil } +// DiscoverNamespacesFromFilesystem skips the PBS CLI and infers namespaces +// directly from the datastore filesystem layout. +func DiscoverNamespacesFromFilesystem(ctx context.Context, datastorePath string, ioTimeout time.Duration) ([]Namespace, error) { + return discoverNamespacesFromFilesystem(ctx, datastorePath, ioTimeout) +} + func listNamespacesViaCLI(ctx context.Context, datastore string) ([]Namespace, error) { if err := ctx.Err(); err != nil { return nil, err diff --git a/internal/pbs/namespaces_test.go b/internal/pbs/namespaces_test.go index ac358b9..f151cae 100644 --- a/internal/pbs/namespaces_test.go +++ b/internal/pbs/namespaces_test.go @@ -130,6 +130,21 @@ func TestDiscoverNamespacesFromFilesystem_Errors(t *testing.T) { } } +func TestDiscoverNamespacesFromFilesystemExportedHelper(t *testing.T) { + tmpDir := t.TempDir() + mustMkdirAll(t, filepath.Join(tmpDir, "prod", "vm")) + + namespaces, err := DiscoverNamespacesFromFilesystem(context.Background(), tmpDir, 0) + if err != nil { + t.Fatalf("DiscoverNamespacesFromFilesystem failed: %v", err) + } + + got := namespacesToMap(namespaces) + if _, ok := got["prod"]; !ok { + t.Fatalf("expected exported helper to discover namespace, got %+v", namespaces) + } +} + func TestListNamespaces_CLISuccess(t *testing.T) { setExecCommandStub(t, "cli-success") From 9e0fe52c6672d460ff4ea5d2c002678b7661327b Mon Sep 17 00:00:00 2001 From: tis24dev Date: Tue, 10 Mar 2026 16:57:26 +0100 Subject: [PATCH 15/18] fix(orchestrator): resolve symlinked base dirs before validating relative links resolvePathRelativeToBaseWithinRootFS now canonicalizes baseDir by resolving existing symlinks before validating relative link targets. This closes the bypass where an archive could create an apparently harmless parent symlink and then install a relative symlink that escapes destRoot once materialized by the kernel. --- internal/orchestrator/backup_safety_test.go | 101 ++++++++++++++++++++ internal/orchestrator/path_security.go | 17 +++- internal/orchestrator/path_security_test.go | 64 +++++++++++++ internal/orchestrator/restore_test.go | 66 +++++++++++++ 4 files changed, 247 insertions(+), 1 deletion(-) diff --git a/internal/orchestrator/backup_safety_test.go b/internal/orchestrator/backup_safety_test.go index 68c1eba..be0cc02 100644 --- a/internal/orchestrator/backup_safety_test.go +++ b/internal/orchestrator/backup_safety_test.go @@ -445,6 +445,107 @@ func TestRestoreSafetyBackup_RejectsBrokenIntermediateSymlinkEscape(t *testing.T } } +func TestRestoreSafetyBackup_RejectsEscapeWhenParentPathIsSymlink(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + var logBuf bytes.Buffer + logger.SetOutput(&logBuf) + + tmpDir := t.TempDir() + backupPath := filepath.Join(tmpDir, "symlinked-parent-escape.tar.gz") + restoreDir := filepath.Join(tmpDir, "restore") + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + + if err := tw.WriteHeader(&tar.Header{Name: "linkdir", Typeflag: tar.TypeSymlink, Linkname: "."}); err != nil { + t.Fatalf("write parent symlink header failed: %v", err) + } + if err := tw.WriteHeader(&tar.Header{Name: "linkdir/escape", Typeflag: tar.TypeSymlink, Linkname: "../outside"}); err != nil { + t.Fatalf("write escaping symlink header failed: %v", err) + } + + if err := tw.Close(); err != nil { + t.Fatalf("tar close failed: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip close failed: %v", err) + } + if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive failed: %v", err) + } + + if err := RestoreSafetyBackup(logger, backupPath, restoreDir); err != nil { + t.Fatalf("RestoreSafetyBackup failed: %v", err) + } + + linkTarget, err := os.Readlink(filepath.Join(restoreDir, "linkdir")) + if err != nil { + t.Fatalf("parent symlink should exist: %v", err) + } + if linkTarget != "." { + t.Fatalf("parent symlink target = %q, want %q", linkTarget, ".") + } + + if _, err := os.Lstat(filepath.Join(restoreDir, "escape")); !os.IsNotExist(err) { + t.Fatalf("escaping symlink should not be created, got err=%v", err) + } + if !strings.Contains(logBuf.String(), "Skipping symlink") { + t.Fatalf("expected symlink skip warning, logs=%q", logBuf.String()) + } +} + +func TestRestoreSafetyBackup_AllowsSafeTargetWhenParentPathIsSymlink(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + + tmpDir := t.TempDir() + backupPath := filepath.Join(tmpDir, "symlinked-parent-safe.tar.gz") + restoreDir := filepath.Join(tmpDir, "restore") + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + + if err := tw.WriteHeader(&tar.Header{Name: "subdir/", Mode: 0o755, Typeflag: tar.TypeDir}); err != nil { + t.Fatalf("write subdir header failed: %v", err) + } + content := []byte("ok") + if err := tw.WriteHeader(&tar.Header{Name: "subdir/file.txt", Mode: 0o644, Size: int64(len(content))}); err != nil { + t.Fatalf("write file header failed: %v", err) + } + if _, err := tw.Write(content); err != nil { + t.Fatalf("write file content failed: %v", err) + } + if err := tw.WriteHeader(&tar.Header{Name: "linkdir", Typeflag: tar.TypeSymlink, Linkname: "subdir"}); err != nil { + t.Fatalf("write parent symlink header failed: %v", err) + } + if err := tw.WriteHeader(&tar.Header{Name: "linkdir/ok", Typeflag: tar.TypeSymlink, Linkname: "file.txt"}); err != nil { + t.Fatalf("write safe symlink header failed: %v", err) + } + + if err := tw.Close(); err != nil { + t.Fatalf("tar close failed: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip close failed: %v", err) + } + if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive failed: %v", err) + } + + if err := RestoreSafetyBackup(logger, backupPath, restoreDir); err != nil { + t.Fatalf("RestoreSafetyBackup failed: %v", err) + } + + linkTarget, err := os.Readlink(filepath.Join(restoreDir, "subdir", "ok")) + if err != nil { + t.Fatalf("safe symlink should exist: %v", err) + } + if linkTarget != "file.txt" { + t.Fatalf("safe symlink target = %q, want %q", linkTarget, "file.txt") + } +} + func TestCleanupOldSafetyBackups(t *testing.T) { logger := logging.New(types.LogLevelInfo, false) diff --git a/internal/orchestrator/path_security.go b/internal/orchestrator/path_security.go index 53bd020..7b2a425 100644 --- a/internal/orchestrator/path_security.go +++ b/internal/orchestrator/path_security.go @@ -90,7 +90,22 @@ func resolvePathRelativeToBaseWithinRootFS(fsys FS, destRoot, baseDir, candidate if err != nil { return "", fmt.Errorf("resolve base directory: %w", err) } - canonicalBaseDir, err := normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, baseAbs) + normalizedBaseDir, err := normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, baseAbs) + if err != nil { + return "", err + } + + // Resolve symlinks already present in the base directory before validating + // a relative target against it. The kernel creates the new symlink under the + // resolved parent directory, not the lexical parent path. + canonicalBaseDir, err := resolvePathWithinPreparedRootFS( + fsys, + lexicalRoot, + canonicalRoot, + normalizedBaseDir, + true, + maxPathSecuritySymlinkHops, + ) if err != nil { return "", err } diff --git a/internal/orchestrator/path_security_test.go b/internal/orchestrator/path_security_test.go index a98873e..b3685b4 100644 --- a/internal/orchestrator/path_security_test.go +++ b/internal/orchestrator/path_security_test.go @@ -118,3 +118,67 @@ func TestResolvePathWithinRootFS_AllowsAbsoluteSymlinkTargetViaLexicalRoot(t *te t.Fatalf("resolved path = %q, want %q", resolved, want) } } + +func TestResolvePathRelativeToBaseWithinRootFS_RejectsEscapeViaSymlinkedBaseDir(t *testing.T) { + root := t.TempDir() + if err := os.Symlink(".", filepath.Join(root, "linkdir")); err != nil { + t.Fatalf("create base symlink: %v", err) + } + + _, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), "../outside") + if err == nil || !strings.Contains(err.Error(), "escapes destination") { + t.Fatalf("expected escape rejection, got %v", err) + } +} + +func TestResolvePathRelativeToBaseWithinRootFS_AllowsRelativeTargetAfterSafeBaseResolution(t *testing.T) { + root := t.TempDir() + if err := os.MkdirAll(filepath.Join(root, "subdir"), 0o755); err != nil { + t.Fatalf("mkdir subdir: %v", err) + } + if err := os.Symlink("subdir", filepath.Join(root, "linkdir")); err != nil { + t.Fatalf("create base symlink: %v", err) + } + + resolved, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), "file.txt") + if err != nil { + t.Fatalf("resolvePathRelativeToBaseWithinRootFS returned error: %v", err) + } + + want := filepath.Join(root, "subdir", "file.txt") + if resolved != want { + t.Fatalf("resolved path = %q, want %q", resolved, want) + } +} + +func TestResolvePathRelativeToBaseWithinRootFS_RejectsBaseDirSymlinkOutsideRoot(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(root, "linkdir")); err != nil { + t.Fatalf("create base symlink: %v", err) + } + + _, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), "file.txt") + if err == nil || !strings.Contains(err.Error(), "escapes destination") { + t.Fatalf("expected escape rejection, got %v", err) + } +} + +func TestResolvePathRelativeToBaseWithinRootFS_PreservesAbsoluteCandidateBehavior(t *testing.T) { + root := t.TempDir() + if err := os.MkdirAll(filepath.Join(root, "subdir"), 0o755); err != nil { + t.Fatalf("mkdir subdir: %v", err) + } + if err := os.Symlink("subdir", filepath.Join(root, "linkdir")); err != nil { + t.Fatalf("create base symlink: %v", err) + } + + candidate := filepath.Join(root, "subdir", "file.txt") + resolved, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), candidate) + if err != nil { + t.Fatalf("resolvePathRelativeToBaseWithinRootFS returned error: %v", err) + } + if resolved != candidate { + t.Fatalf("resolved path = %q, want %q", resolved, candidate) + } +} diff --git a/internal/orchestrator/restore_test.go b/internal/orchestrator/restore_test.go index b5c6ae5..d65e449 100644 --- a/internal/orchestrator/restore_test.go +++ b/internal/orchestrator/restore_test.go @@ -417,6 +417,72 @@ func TestExtractSymlink_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { } } +func TestExtractSymlink_RejectsEscapeWhenParentPathIsSymlink(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + orig := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = orig }) + + destRoot := t.TempDir() + if err := os.Symlink(".", filepath.Join(destRoot, "linkdir")); err != nil { + t.Fatalf("create parent symlink: %v", err) + } + + header := &tar.Header{ + Name: "linkdir/escape", + Typeflag: tar.TypeSymlink, + Linkname: "../outside", + } + target := filepath.Join(destRoot, header.Name) + + err := extractSymlink(target, header, destRoot, logger) + if err == nil || !strings.Contains(err.Error(), "escapes root") { + t.Fatalf("expected escapes root error, got %v", err) + } + + if _, err := os.Lstat(filepath.Join(destRoot, "escape")); !os.IsNotExist(err) { + t.Fatalf("escaping symlink should not be created, got err=%v", err) + } +} + +func TestExtractSymlink_AllowsSafeTargetWhenParentPathIsSymlink(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + orig := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = orig }) + + destRoot := t.TempDir() + if err := os.MkdirAll(filepath.Join(destRoot, "subdir"), 0o755); err != nil { + t.Fatalf("mkdir subdir: %v", err) + } + if err := os.Symlink("subdir", filepath.Join(destRoot, "linkdir")); err != nil { + t.Fatalf("create parent symlink: %v", err) + } + + header := &tar.Header{ + Name: "linkdir/ok", + Typeflag: tar.TypeSymlink, + Linkname: "file.txt", + } + target := filepath.Join(destRoot, header.Name) + + if err := extractSymlink(target, header, destRoot, logger); err != nil { + t.Fatalf("safe symlink should succeed: %v", err) + } + + linkTarget, err := os.Readlink(target) + if err != nil { + t.Fatalf("safe symlink should exist: %v", err) + } + if linkTarget != "file.txt" { + t.Fatalf("symlink target = %q, want %q", linkTarget, "file.txt") + } + + if _, err := os.Lstat(filepath.Join(destRoot, "subdir", "ok")); err != nil { + t.Fatalf("expected symlink at resolved parent path: %v", err) + } +} + func TestExtractTarEntry_DoesNotFollowExistingSymlinkTargetPath(t *testing.T) { logger := logging.New(types.LogLevelDebug, false) orig := restoreFS From 86958be07b776716750c5407550287761fbe331d Mon Sep 17 00:00:00 2001 From: tis24dev Date: Tue, 10 Mar 2026 17:50:27 +0100 Subject: [PATCH 16/18] fix(safefs): harden timeout test cleanup and limiter capture Make safefs timeout tests failure-safe by draining blocked workers from cleanup and separating worker cleanup from global state restore. Also capture limiter and fs operation hooks locally in runLimited/Stat/ReadDir/Statfs to avoid races with global resets while async work is still in flight. --- internal/safefs/safefs.go | 14 +++++++----- internal/safefs/safefs_test.go | 39 +++++++++++++++++----------------- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/internal/safefs/safefs.go b/internal/safefs/safefs.go index d64cfa0..7bdd783 100644 --- a/internal/safefs/safefs.go +++ b/internal/safefs/safefs.go @@ -107,7 +107,8 @@ func runLimited[T any](ctx context.Context, timeout time.Duration, timeoutErr *T timer := time.NewTimer(timeout) defer timer.Stop() - if err := fsOpLimiter.acquire(ctx, timer.C); err != nil { + limiter := fsOpLimiter + if err := limiter.acquire(ctx, timer.C); err != nil { if errors.Is(err, ErrTimeout) { return zero, timeoutErr } @@ -120,7 +121,7 @@ func runLimited[T any](ctx context.Context, timeout time.Duration, timeoutErr *T } ch := make(chan result, 1) go func() { - defer fsOpLimiter.release() + defer limiter.release() value, err := run() ch <- result{value: value, err: err} }() @@ -136,21 +137,24 @@ func runLimited[T any](ctx context.Context, timeout time.Duration, timeoutErr *T } func Stat(ctx context.Context, path string, timeout time.Duration) (fs.FileInfo, error) { + stat := osStat return runLimited(ctx, timeout, &TimeoutError{Op: "stat", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() (fs.FileInfo, error) { - return osStat(path) + return stat(path) }) } func ReadDir(ctx context.Context, path string, timeout time.Duration) ([]os.DirEntry, error) { + readDir := osReadDir return runLimited(ctx, timeout, &TimeoutError{Op: "readdir", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() ([]os.DirEntry, error) { - return osReadDir(path) + return readDir(path) }) } func Statfs(ctx context.Context, path string, timeout time.Duration) (syscall.Statfs_t, error) { + statfs := syscallStatfs return runLimited(ctx, timeout, &TimeoutError{Op: "statfs", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() (syscall.Statfs_t, error) { var stat syscall.Statfs_t - err := syscallStatfs(path, &stat) + err := statfs(path, &stat) return stat, err }) } diff --git a/internal/safefs/safefs_test.go b/internal/safefs/safefs_test.go index 19ce1e7..aa1b863 100644 --- a/internal/safefs/safefs_test.go +++ b/internal/safefs/safefs_test.go @@ -19,15 +19,23 @@ func waitForSignal(t *testing.T, ch <-chan struct{}, name string) { } } +func registerBlockedOpCleanup(t *testing.T, name string, unblock chan struct{}, finished <-chan struct{}, restore func()) { + t.Helper() + + t.Cleanup(restore) + t.Cleanup(func() { + close(unblock) + waitForSignal(t, finished, name) + }) +} + func TestStat_ReturnsTimeoutError(t *testing.T) { prev := osStat unblock := make(chan struct{}) finished := make(chan struct{}) - defer func() { - close(unblock) - waitForSignal(t, finished, "stat completion") + registerBlockedOpCleanup(t, "stat completion", unblock, finished, func() { osStat = prev - }() + }) osStat = func(string) (os.FileInfo, error) { <-unblock @@ -49,11 +57,9 @@ func TestReadDir_ReturnsTimeoutError(t *testing.T) { prev := osReadDir unblock := make(chan struct{}) finished := make(chan struct{}) - defer func() { - close(unblock) - waitForSignal(t, finished, "readdir completion") + registerBlockedOpCleanup(t, "readdir completion", unblock, finished, func() { osReadDir = prev - }() + }) osReadDir = func(string) ([]os.DirEntry, error) { <-unblock @@ -75,11 +81,9 @@ func TestStatfs_ReturnsTimeoutError(t *testing.T) { prev := syscallStatfs unblock := make(chan struct{}) finished := make(chan struct{}) - defer func() { - close(unblock) - waitForSignal(t, finished, "statfs completion") + registerBlockedOpCleanup(t, "statfs completion", unblock, finished, func() { syscallStatfs = prev - }() + }) syscallStatfs = func(string, *syscall.Statfs_t) error { <-unblock @@ -110,15 +114,15 @@ func TestStat_PropagatesContextCancellation(t *testing.T) { func TestStat_DoesNotSpawnPastLimiterCapacity(t *testing.T) { prevStat := osStat prevLimiter := fsOpLimiter - defer func() { + unblock := make(chan struct{}) + finished := make(chan struct{}) + registerBlockedOpCleanup(t, "limited stat completion", unblock, finished, func() { osStat = prevStat fsOpLimiter = prevLimiter - }() + }) fsOpLimiter = newOperationLimiter(1) - unblock := make(chan struct{}) - finished := make(chan struct{}) var calls atomic.Int32 osStat = func(string) (os.FileInfo, error) { calls.Add(1) @@ -145,7 +149,4 @@ func TestStat_DoesNotSpawnPastLimiterCapacity(t *testing.T) { if got := calls.Load(); got != 1 { t.Fatalf("calls after limiter saturation = %d; want 1", got) } - - close(unblock) - waitForSignal(t, finished, "limited stat completion") } From 36a84d6479b29f8647a427426886d68f35b76ac4 Mon Sep 17 00:00:00 2001 From: tis24dev Date: Tue, 10 Mar 2026 19:25:38 +0100 Subject: [PATCH 17/18] fix(input): preserve completed inflight reads across retries Keep completed in-flight line and password reads attached to their state until a caller consumes the buffered result, instead of clearing the inflight state from the producer goroutine. This fixes the race where a read could complete after a timeout, be cleaned up before the next retry started, and leave the completed input unreachable. Add deterministic tests for both retry-while-pending and completion-before-retry cases to lock the behavior down. --- internal/input/input.go | 30 ++--- internal/input/input_test.go | 235 ++++++++++++++++++++++++++++++++--- 2 files changed, 237 insertions(+), 28 deletions(-) diff --git a/internal/input/input.go b/internal/input/input.go index 6de043e..3df8280 100644 --- a/internal/input/input.go +++ b/internal/input/input.go @@ -27,7 +27,8 @@ type lineState struct { } type lineInflight struct { - done chan lineResult + done chan lineResult + completed chan struct{} } type passwordResult struct { @@ -41,7 +42,8 @@ type passwordState struct { } type passwordInflight struct { - done chan passwordResult + done chan passwordResult + completed chan struct{} } var ( @@ -107,6 +109,7 @@ func getPasswordState(fd int) *passwordState { // or stdin closure it returns ErrInputAborted. On ctx deadline it returns context.DeadlineExceeded. // Cancellation stops waiting but does not interrupt an already-started reader.ReadString call; // at most one in-flight read is kept per reader to avoid goroutine buildup across retries. +// A completed in-flight read remains attached to the reader until a later caller consumes it. func ReadLineWithContext(ctx context.Context, reader *bufio.Reader) (string, error) { if ctx == nil { ctx = context.Background() @@ -122,16 +125,15 @@ func ReadLineWithContext(ctx context.Context, reader *bufio.Reader) (string, err state.mu.Lock() defer state.mu.Unlock() if state.inflight == nil { - inflight := &lineInflight{done: make(chan lineResult, 1)} + inflight := &lineInflight{ + done: make(chan lineResult, 1), + completed: make(chan struct{}), + } state.inflight = inflight go func() { line, err := reader.ReadString('\n') inflight.done <- lineResult{line: line, err: MapInputError(err)} - state.mu.Lock() - if state.inflight == inflight { - state.inflight = nil - } - state.mu.Unlock() + close(inflight.completed) }() } inflight := state.inflight @@ -155,6 +157,7 @@ func ReadLineWithContext(ctx context.Context, reader *bufio.Reader) (string, err // context.DeadlineExceeded. // Cancellation stops waiting but does not interrupt an already-started password read; // at most one in-flight password read is kept per file descriptor to avoid goroutine buildup. +// A completed in-flight password read remains attached until a later caller consumes it. func ReadPasswordWithContext(ctx context.Context, readPassword func(int) ([]byte, error), fd int) ([]byte, error) { if ctx == nil { ctx = context.Background() @@ -170,16 +173,15 @@ func ReadPasswordWithContext(ctx context.Context, readPassword func(int) ([]byte state.mu.Lock() defer state.mu.Unlock() if state.inflight == nil { - inflight := &passwordInflight{done: make(chan passwordResult, 1)} + inflight := &passwordInflight{ + done: make(chan passwordResult, 1), + completed: make(chan struct{}), + } state.inflight = inflight go func() { b, err := readPassword(fd) inflight.done <- passwordResult{b: b, err: MapInputError(err)} - state.mu.Lock() - if state.inflight == inflight { - state.inflight = nil - } - state.mu.Unlock() + close(inflight.completed) }() } inflight := state.inflight diff --git a/internal/input/input_test.go b/internal/input/input_test.go index 16d4feb..747d3b7 100644 --- a/internal/input/input_test.go +++ b/internal/input/input_test.go @@ -13,22 +13,117 @@ import ( ) type blockingLineReader struct { - release chan struct{} - payload string - calls atomic.Int32 + release chan struct{} + finish chan struct{} + returned chan struct{} + payload string + calls atomic.Int32 } func (r *blockingLineReader) Read(p []byte) (int, error) { r.calls.Add(1) <-r.release + if r.finish != nil { + <-r.finish + } if r.payload == "" { + signalNonBlocking(r.returned) return 0, io.EOF } n := copy(p, r.payload) r.payload = r.payload[n:] + signalNonBlocking(r.returned) return n, nil } +type lineCallResult struct { + line string + err error +} + +type passwordCallResult struct { + b []byte + err error +} + +func signalNonBlocking(ch chan struct{}) { + if ch == nil { + return + } + select { + case ch <- struct{}{}: + default: + } +} + +func waitForSignal(t *testing.T, ch <-chan struct{}, name string) { + t.Helper() + select { + case <-ch: + case <-time.After(500 * time.Millisecond): + t.Fatalf("timed out waiting for %s", name) + } +} + +func waitForCondition(t *testing.T, name string, cond func() bool) { + t.Helper() + deadline := time.After(500 * time.Millisecond) + ticker := time.NewTicker(time.Millisecond) + defer ticker.Stop() + for { + if cond() { + return + } + select { + case <-deadline: + t.Fatalf("timed out waiting for %s", name) + case <-ticker.C: + } + } +} + +func currentLineInflight(t *testing.T, reader *bufio.Reader) *lineInflight { + t.Helper() + state := getLineState(reader) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight == nil { + t.Fatalf("expected line inflight state") + } + return state.inflight +} + +func currentPasswordInflight(t *testing.T, fd int) *passwordInflight { + t.Helper() + state := getPasswordState(fd) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight == nil { + t.Fatalf("expected password inflight state") + } + return state.inflight +} + +func assertSameLineInflight(t *testing.T, reader *bufio.Reader, want *lineInflight) { + t.Helper() + state := getLineState(reader) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight != want { + t.Fatalf("line inflight=%p; want %p", state.inflight, want) + } +} + +func assertSamePasswordInflight(t *testing.T, fd int, want *passwordInflight) { + t.Helper() + state := getPasswordState(fd) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight != want { + t.Fatalf("password inflight=%p; want %p", state.inflight, want) + } +} + func TestMapInputError(t *testing.T) { if MapInputError(nil) != nil { t.Fatalf("expected nil") @@ -227,10 +322,12 @@ func TestReadPasswordWithContext_DeadlineReturnsDeadlineExceeded(t *testing.T) { } } -func TestReadLineWithContext_ReusesInflightReadAfterTimeout(t *testing.T) { +func TestReadLineWithContext_ReusesInflightReadWhilePendingAfterTimeout(t *testing.T) { src := &blockingLineReader{ - release: make(chan struct{}), - payload: "hello\n", + release: make(chan struct{}), + finish: make(chan struct{}), + returned: make(chan struct{}, 1), + payload: "hello\n", } reader := bufio.NewReader(src) @@ -252,26 +349,80 @@ func TestReadLineWithContext_ReusesInflightReadAfterTimeout(t *testing.T) { t.Fatalf("underlying Read calls=%d; want 1", got) } + resultCh := make(chan lineCallResult, 1) + go func() { + line, err := ReadLineWithContext(context.Background(), reader) + resultCh <- lineCallResult{line: line, err: err} + }() + + state := getLineState(reader) + waitForCondition(t, "line retry to block on inflight read", func() bool { + if state.mu.TryLock() { + state.mu.Unlock() + return false + } + return true + }) + close(src.release) + close(src.finish) + waitForSignal(t, src.returned, "underlying line read completion") + + res := <-resultCh + if res.err != nil { + t.Fatalf("retry ReadLineWithContext error: %v", res.err) + } + if res.line != "hello\n" { + t.Fatalf("line=%q; want %q", res.line, "hello\n") + } + if got := src.calls.Load(); got != 1 { + t.Fatalf("underlying Read calls after pending retry=%d; want 1", got) + } +} + +func TestReadLineWithContext_PreservesCompletedReadForNextRetryAfterTimeout(t *testing.T) { + src := &blockingLineReader{ + release: make(chan struct{}), + returned: make(chan struct{}, 1), + payload: "hello\n", + } + reader := bufio.NewReader(src) + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + _, err := ReadLineWithContext(ctx, reader) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + inflight := currentLineInflight(t, reader) + close(src.release) + waitForSignal(t, src.returned, "underlying line read return") + waitForSignal(t, inflight.completed, "line inflight completion") + assertSameLineInflight(t, reader, inflight) line, err := ReadLineWithContext(context.Background(), reader) if err != nil { - t.Fatalf("third ReadLineWithContext error: %v", err) + t.Fatalf("retry ReadLineWithContext error: %v", err) } if line != "hello\n" { t.Fatalf("line=%q; want %q", line, "hello\n") } if got := src.calls.Load(); got != 1 { - t.Fatalf("underlying Read calls after release=%d; want 1", got) + t.Fatalf("underlying Read calls after completed retry=%d; want 1", got) } } -func TestReadPasswordWithContext_ReusesInflightReadAfterTimeout(t *testing.T) { - unblock := make(chan struct{}) +func TestReadPasswordWithContext_ReusesInflightReadWhilePendingAfterTimeout(t *testing.T) { + release := make(chan struct{}) + finish := make(chan struct{}) + returned := make(chan struct{}, 1) var calls atomic.Int32 readPassword := func(fd int) ([]byte, error) { calls.Add(1) - <-unblock + <-release + <-finish + signalNonBlocking(returned) return []byte("secret"), nil } @@ -299,16 +450,72 @@ func TestReadPasswordWithContext_ReusesInflightReadAfterTimeout(t *testing.T) { t.Fatalf("readPassword calls=%d; want 1", gotCalls) } - close(unblock) + resultCh := make(chan passwordCallResult, 1) + go func() { + got, err := ReadPasswordWithContext(context.Background(), readPassword, 42) + resultCh <- passwordCallResult{b: got, err: err} + }() + + state := getPasswordState(42) + waitForCondition(t, "password retry to block on inflight read", func() bool { + if state.mu.TryLock() { + state.mu.Unlock() + return false + } + return true + }) + + close(release) + close(finish) + waitForSignal(t, returned, "underlying password read completion") + + res := <-resultCh + if res.err != nil { + t.Fatalf("retry ReadPasswordWithContext error: %v", res.err) + } + if string(res.b) != "secret" { + t.Fatalf("got=%q; want %q", string(res.b), "secret") + } + if gotCalls := calls.Load(); gotCalls != 1 { + t.Fatalf("readPassword calls after pending retry=%d; want 1", gotCalls) + } +} + +func TestReadPasswordWithContext_PreservesCompletedReadForNextRetryAfterTimeout(t *testing.T) { + release := make(chan struct{}) + returned := make(chan struct{}, 1) + var calls atomic.Int32 + readPassword := func(fd int) ([]byte, error) { + calls.Add(1) + <-release + signalNonBlocking(returned) + return []byte("secret"), nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + got, err := ReadPasswordWithContext(ctx, readPassword, 42) + if got != nil { + t.Fatalf("expected nil bytes on first deadline") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + inflight := currentPasswordInflight(t, 42) + close(release) + waitForSignal(t, returned, "underlying password read return") + waitForSignal(t, inflight.completed, "password inflight completion") + assertSamePasswordInflight(t, 42, inflight) got, err = ReadPasswordWithContext(context.Background(), readPassword, 42) if err != nil { - t.Fatalf("third ReadPasswordWithContext error: %v", err) + t.Fatalf("retry ReadPasswordWithContext error: %v", err) } if string(got) != "secret" { t.Fatalf("got=%q; want %q", string(got), "secret") } if gotCalls := calls.Load(); gotCalls != 1 { - t.Fatalf("readPassword calls after release=%d; want 1", gotCalls) + t.Fatalf("readPassword calls after completed retry=%d; want 1", gotCalls) } } From fbec12fc308167710c098b5d1d5df149085d5d6f Mon Sep 17 00:00:00 2001 From: tis24dev Date: Tue, 10 Mar 2026 20:51:56 +0100 Subject: [PATCH 18/18] fix(pbs): disambiguate datastore output keys across CLI and overrides Add a PBS-specific output-key resolver to guarantee stable, unique datastore filenames and directories across auto-detected datastores and PBS_DATASTORE_PATH overrides. This fixes real collisions between CLI datastore names and path-derived override keys, makes pbsDatastore.pathKey() source-aware when OutputKey is unset, and keeps inventory output_key values aligned with the actual files written by PBS collectors. Includes regression tests for CLI-vs-override collisions, override fallback behavior, and inventory consistency. --- internal/backup/collector_pbs.go | 5 + internal/backup/collector_pbs_datastore.go | 202 +++++++++++++++++- .../collector_pbs_datastore_inventory.go | 17 +- .../collector_pbs_datastore_inventory_test.go | 53 +++++ internal/backup/collector_pbs_test.go | 110 ++++++++++ 5 files changed, 371 insertions(+), 16 deletions(-) diff --git a/internal/backup/collector_pbs.go b/internal/backup/collector_pbs.go index 6be89e7..8b23b7d 100644 --- a/internal/backup/collector_pbs.go +++ b/internal/backup/collector_pbs.go @@ -347,6 +347,11 @@ func (c *Collector) collectPBSDirectories(ctx context.Context, root string) erro // collectPBSCommands collects output from PBS commands func (c *Collector) collectPBSCommands(ctx context.Context, datastores []pbsDatastore) error { + if len(datastores) > 0 { + datastores = clonePBSDatastores(datastores) + assignUniquePBSDatastoreOutputKeys(datastores) + } + commandsDir := c.proxsaveCommandsDir("pbs") if err := c.ensureDir(commandsDir); err != nil { return fmt.Errorf("failed to create commands directory: %w", err) diff --git a/internal/backup/collector_pbs_datastore.go b/internal/backup/collector_pbs_datastore.go index 62199a8..49c4d3f 100644 --- a/internal/backup/collector_pbs_datastore.go +++ b/internal/backup/collector_pbs_datastore.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "regexp" + "sort" "strings" "sync" "time" @@ -72,6 +73,199 @@ func buildPBSOverrideOutputKey(path string) string { return fmt.Sprintf("%s_%s", sanitizeFilename(label), hex.EncodeToString(sum[:4])) } +func pbsOutputKeyDigest(seed string) string { + sum := sha256.Sum256([]byte(seed)) + return hex.EncodeToString(sum[:4]) +} + +func pbsDatastoreIdentityKey(ds pbsDatastore) string { + if ds.isOverride() { + if normalized := ds.normalizedPath(); normalized != "" { + return "override-path:" + normalized + } + if path := strings.TrimSpace(ds.Path); path != "" { + return "override-path:" + path + } + return "" + } + + name := strings.TrimSpace(ds.Name) + if name == "" { + return "" + } + return "name:" + name +} + +func pbsDatastoreDefinitionIdentityKey(def pbsDatastoreDefinition) string { + if strings.TrimSpace(def.Origin) == pbsDatastoreSourceOverride { + if normalized := normalizePBSDatastorePath(def.Path); normalized != "" { + return "override-path:" + normalized + } + if path := strings.TrimSpace(def.Path); path != "" { + return "override-path:" + path + } + return "" + } + + name := strings.TrimSpace(def.Name) + if name == "" { + name = strings.TrimSpace(def.CLIName) + } + if name == "" { + return "" + } + return "name:" + name +} + +func pbsDatastoreCandidateOutputKey(ds pbsDatastore) string { + if ds.isOverride() { + if normalized := ds.normalizedPath(); normalized != "" { + return buildPBSOverrideOutputKey(normalized) + } + return buildPBSOverrideOutputKey(ds.Path) + } + return collectorPathKey(ds.Name) +} + +func pbsDatastoreDefinitionCandidateOutputKey(def pbsDatastoreDefinition) string { + if strings.TrimSpace(def.Origin) == pbsDatastoreSourceOverride { + if normalized := normalizePBSDatastorePath(def.Path); normalized != "" { + return buildPBSOverrideOutputKey(normalized) + } + return buildPBSOverrideOutputKey(def.Path) + } + + name := strings.TrimSpace(def.Name) + if name == "" { + name = strings.TrimSpace(def.CLIName) + } + return collectorPathKey(name) +} + +func pbsOutputKeyPriority(origin string) int { + if strings.TrimSpace(origin) == pbsDatastoreSourceOverride { + return 1 + } + return 0 +} + +type pbsOutputKeyAssignment struct { + Index int + Identity string + BaseKey string + Priority int +} + +func assignUniquePBSOutputKeys[T any](items []T, identityFn func(T) string, baseKeyFn func(T) string, priorityFn func(T) int, assignFn func(*T, string)) { + if len(items) == 0 { + return + } + + grouped := make(map[string][]pbsOutputKeyAssignment, len(items)) + baseKeys := make([]string, 0, len(items)) + for idx, item := range items { + baseKey := strings.TrimSpace(baseKeyFn(item)) + if baseKey == "" { + baseKey = "entry" + } + + identity := strings.TrimSpace(identityFn(item)) + if identity == "" { + identity = fmt.Sprintf("anonymous:%s:%d", baseKey, idx) + } + + if _, ok := grouped[baseKey]; !ok { + baseKeys = append(baseKeys, baseKey) + } + grouped[baseKey] = append(grouped[baseKey], pbsOutputKeyAssignment{ + Index: idx, + Identity: identity, + BaseKey: baseKey, + Priority: priorityFn(item), + }) + } + + sort.Strings(baseKeys) + + usedKeys := make(map[string]string, len(items)) + identityKeys := make(map[string]string, len(items)) + + for _, baseKey := range baseKeys { + assignments := grouped[baseKey] + sort.SliceStable(assignments, func(i, j int) bool { + if assignments[i].Priority != assignments[j].Priority { + return assignments[i].Priority < assignments[j].Priority + } + if assignments[i].Identity != assignments[j].Identity { + return assignments[i].Identity < assignments[j].Identity + } + return assignments[i].Index < assignments[j].Index + }) + + for pos, assignment := range assignments { + if existing := strings.TrimSpace(identityKeys[assignment.Identity]); existing != "" { + assignFn(&items[assignment.Index], existing) + continue + } + + preferBase := pos == 0 + for attempt := 0; ; attempt++ { + candidate := assignment.BaseKey + if !preferBase || attempt > 0 { + seed := assignment.Identity + if attempt > 0 { + seed = fmt.Sprintf("%s#%d", assignment.Identity, attempt) + } + candidate = fmt.Sprintf("%s_%s", assignment.BaseKey, pbsOutputKeyDigest(seed)) + } + + if owner, ok := usedKeys[candidate]; ok && owner != assignment.Identity { + continue + } + + usedKeys[candidate] = assignment.Identity + identityKeys[assignment.Identity] = candidate + assignFn(&items[assignment.Index], candidate) + break + } + } + } +} + +func assignUniquePBSDatastoreOutputKeys(datastores []pbsDatastore) { + assignUniquePBSOutputKeys(datastores, + pbsDatastoreIdentityKey, + pbsDatastoreCandidateOutputKey, + func(ds pbsDatastore) int { + return pbsOutputKeyPriority(ds.Source) + }, + func(ds *pbsDatastore, key string) { + ds.OutputKey = key + }) +} + +func assignUniquePBSDatastoreDefinitionOutputKeys(defs []pbsDatastoreDefinition) { + assignUniquePBSOutputKeys(defs, + pbsDatastoreDefinitionIdentityKey, + pbsDatastoreDefinitionCandidateOutputKey, + func(def pbsDatastoreDefinition) int { + return pbsOutputKeyPriority(def.Origin) + }, + func(def *pbsDatastoreDefinition, key string) { + def.OutputKey = key + }) +} + +func clonePBSDatastores(in []pbsDatastore) []pbsDatastore { + if len(in) == 0 { + return nil + } + + out := make([]pbsDatastore, len(in)) + copy(out, in) + return out +} + func (ds pbsDatastore) normalizedPath() string { if path := strings.TrimSpace(ds.NormalizedPath); path != "" { return path @@ -83,7 +277,7 @@ func (ds pbsDatastore) pathKey() string { if key := strings.TrimSpace(ds.OutputKey); key != "" { return key } - return collectorPathKey(ds.Name) + return pbsDatastoreCandidateOutputKey(ds) } func (ds pbsDatastore) cliName() string { @@ -110,6 +304,8 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb c.logger.Debug("No datastores found") return nil } + datastores = clonePBSDatastores(datastores) + assignUniquePBSDatastoreOutputKeys(datastores) c.logger.Debug("Collecting datastore details for %d datastores", len(datastores)) datastoreDir := c.proxsaveInfoDir("pbs", "datastores") @@ -198,6 +394,8 @@ func (c *Collector) collectPBSPxarMetadata(ctx context.Context, datastores []pbs if len(datastores) == 0 { return nil } + datastores = clonePBSDatastores(datastores) + assignUniquePBSDatastoreOutputKeys(datastores) c.logger.Debug("Collecting PXAR metadata for %d datastores", len(datastores)) dsWorkers := c.config.PxarDatastoreConcurrency if dsWorkers <= 0 { @@ -593,6 +791,8 @@ func (c *Collector) getDatastoreList(ctx context.Context) ([]pbsDatastore, error } } + assignUniquePBSDatastoreOutputKeys(datastores) + c.logger.Debug("Detected %d configured datastores", len(datastores)) return datastores, nil } diff --git a/internal/backup/collector_pbs_datastore_inventory.go b/internal/backup/collector_pbs_datastore_inventory.go index 7e334d0..1b6ad85 100644 --- a/internal/backup/collector_pbs_datastore_inventory.go +++ b/internal/backup/collector_pbs_datastore_inventory.go @@ -466,21 +466,7 @@ func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefi merged := make(map[string]*pbsDatastoreDefinition) defKey := func(ds pbsDatastore) string { - if ds.isOverride() { - if normalized := ds.normalizedPath(); normalized != "" { - return "override-path:" + normalized - } - if path := strings.TrimSpace(ds.Path); path != "" { - return "override-path:" + path - } - return "" - } - - name := strings.TrimSpace(ds.Name) - if name == "" { - return "" - } - return "name:" + name + return pbsDatastoreIdentityKey(ds) } add := func(ds pbsDatastore, source string) { @@ -554,6 +540,7 @@ func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefi v.Sources = uniqueSortedStrings(v.Sources) out = append(out, *v) } + assignUniquePBSDatastoreDefinitionOutputKeys(out) sort.Slice(out, func(i, j int) bool { if out[i].Name != out[j].Name { diff --git a/internal/backup/collector_pbs_datastore_inventory_test.go b/internal/backup/collector_pbs_datastore_inventory_test.go index b2c2345..044c0b4 100644 --- a/internal/backup/collector_pbs_datastore_inventory_test.go +++ b/internal/backup/collector_pbs_datastore_inventory_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "testing" "github.com/tis24dev/proxsave/internal/types" @@ -349,3 +350,55 @@ func TestMergePBSDatastoreDefinitionsKeepsOverridesSeparate(t *testing.T) { t.Fatalf("override output keys should differ, got %+v", merged) } } + +func TestMergePBSDatastoreDefinitionsDisambiguatesCLIAndOverrideOutputKeyCollisions(t *testing.T) { + overridePath := "/mnt/a/backup" + collidingKey := buildPBSOverrideOutputKey(overridePath) + + cli := []pbsDatastore{ + { + Name: collidingKey, + Path: "/real/runtime", + Comment: "runtime", + Source: pbsDatastoreSourceCLI, + CLIName: collidingKey, + NormalizedPath: normalizePBSDatastorePath("/real/runtime"), + OutputKey: collidingKey, + }, + { + Name: "backup", + Path: overridePath, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(overridePath), + OutputKey: buildPBSOverrideOutputKey(overridePath), + }, + } + + merged := mergePBSDatastoreDefinitions(cli, nil) + if len(merged) != 2 { + t.Fatalf("expected 2 merged entries, got %d: %+v", len(merged), merged) + } + + var cliEntry, overrideEntry *pbsDatastoreDefinition + for i := range merged { + switch merged[i].Origin { + case pbsDatastoreSourceCLI: + cliEntry = &merged[i] + case pbsDatastoreSourceOverride: + overrideEntry = &merged[i] + } + } + if cliEntry == nil || overrideEntry == nil { + t.Fatalf("expected one CLI and one override entry, got %+v", merged) + } + if cliEntry.OutputKey != collidingKey { + t.Fatalf("CLI datastore should keep base key %q, got %+v", collidingKey, merged) + } + if overrideEntry.OutputKey == collidingKey { + t.Fatalf("override output key should be disambiguated, got %+v", merged) + } + if !strings.HasPrefix(overrideEntry.OutputKey, collidingKey+"_") { + t.Fatalf("override output key should extend colliding base key, got %+v", merged) + } +} diff --git a/internal/backup/collector_pbs_test.go b/internal/backup/collector_pbs_test.go index 3abb34e..0e2966a 100644 --- a/internal/backup/collector_pbs_test.go +++ b/internal/backup/collector_pbs_test.go @@ -106,6 +106,55 @@ func TestGetDatastoreListOverrideCollisionsUseDistinctOutputKeys(t *testing.T) { } } +func TestGetDatastoreListDisambiguatesCLIAndOverrideOutputKeyCollisions(t *testing.T) { + overridePath := "/mnt/a/backup" + collidingKey := buildPBSOverrideOutputKey(overridePath) + + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + return []byte(fmt.Sprintf(`[{"name":%q,"path":"/data/runtime","comment":"main"}]`, collidingKey)), nil + }, + }) + collector.config.PBSDatastorePaths = []string{overridePath} + + datastores, err := collector.getDatastoreList(context.Background()) + if err != nil { + t.Fatalf("getDatastoreList failed: %v", err) + } + if len(datastores) != 2 { + t.Fatalf("expected 2 datastores, got %d: %+v", len(datastores), datastores) + } + + cli := datastores[0] + override := datastores[1] + if cli.OutputKey != collidingKey { + t.Fatalf("CLI datastore should keep base key %q, got %+v", collidingKey, cli) + } + if override.OutputKey == collidingKey { + t.Fatalf("override key should be disambiguated away from %q, got %+v", collidingKey, override) + } + if override.OutputKey == cli.OutputKey { + t.Fatalf("datastore output keys should differ, got %+v", datastores) + } +} + +func TestPBSDatastorePathKeyUsesOverridePathFallback(t *testing.T) { + dsPath := "/mnt/a/backup" + ds := pbsDatastore{ + Name: "backup", + Path: dsPath, + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(dsPath), + } + + if got, want := ds.pathKey(), buildPBSOverrideOutputKey(dsPath); got != want { + t.Fatalf("override pathKey()=%q want %q", got, want) + } +} + func TestGetDatastoreListContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -522,6 +571,67 @@ func TestCollectDatastoreConfigsSkipsCLIConfigForOverridePaths(t *testing.T) { } } +func TestCollectDatastoreConfigsDisambiguatesManualCLIAndOverrideKeyCollisions(t *testing.T) { + overridePath := t.TempDir() + for _, sub := range []string{"vm", "ct"} { + if err := os.MkdirAll(filepath.Join(overridePath, sub), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", sub, err) + } + } + + collidingKey := buildPBSOverrideOutputKey(overridePath) + stubListNamespaces(t, func(_ context.Context, name, path string, _ time.Duration) ([]pbs.Namespace, bool, error) { + return []pbs.Namespace{{Ns: name, Path: path}}, false, nil + }) + + origDiscover := discoverNamespacesFunc + t.Cleanup(func() { discoverNamespacesFunc = origDiscover }) + discoverNamespacesFunc = func(_ context.Context, path string, _ time.Duration) ([]pbs.Namespace, error) { + return []pbs.Namespace{{Ns: "override", Path: path}}, nil + } + + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) { + return []byte(`{"ok":true}`), nil + }, + }) + + datastores := []pbsDatastore{ + {Name: collidingKey, Path: "/data/runtime"}, + { + Name: "backup", + Path: overridePath, + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(overridePath), + }, + } + if err := collector.collectDatastoreConfigs(context.Background(), datastores); err != nil { + t.Fatalf("collectDatastoreConfigs failed: %v", err) + } + + resolved := clonePBSDatastores(datastores) + assignUniquePBSDatastoreOutputKeys(resolved) + if resolved[0].OutputKey == resolved[1].OutputKey { + t.Fatalf("expected resolved keys to differ, got %+v", resolved) + } + + datastoreDir := filepath.Join(collector.tempDir, "var", "lib", "proxsave-info", "pbs", "datastores") + for _, suffix := range []string{"config.json", "namespaces.json"} { + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_%s", resolved[0].OutputKey, suffix))); err != nil { + t.Fatalf("expected CLI output for %s: %v", suffix, err) + } + } + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", resolved[1].OutputKey))); err != nil { + t.Fatalf("expected override namespaces file: %v", err) + } + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", resolved[1].OutputKey))); !os.IsNotExist(err) { + t.Fatalf("override config file should not exist, got err=%v", err) + } +} + func TestCollectPBSPxarMetadata_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) { tmp := t.TempDir() cfg := GetDefaultCollectorConfig()