diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 73f72f2..d4c1577 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -70,7 +70,7 @@ jobs: # GORELEASER ######################################## - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v6 + uses: goreleaser/goreleaser-action@v7 with: version: latest workdir: ${{ github.workspace }} @@ -82,6 +82,6 @@ jobs: # ATTESTAZIONE PROVENIENZA BUILD ######################################## - name: Attest Build Provenance - uses: actions/attest-build-provenance@v3 + uses: actions/attest-build-provenance@v4 with: subject-path: build/proxsave_* diff --git a/Makefile b/Makefile index e3041d4..6d17a9d 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,8 @@ .PHONY: build test clean run build-release test-coverage lint fmt deps help coverage coverage-check COVERAGE_THRESHOLD ?= 50.0 +TOOLCHAIN_FROM_MOD := $(shell awk '/^toolchain /{print $$2}' go.mod 2>/dev/null) +COVER_GOTOOLCHAIN := $(if $(TOOLCHAIN_FROM_MOD),$(TOOLCHAIN_FROM_MOD)+auto,auto) # Build del progetto build: @@ -60,20 +62,20 @@ test: # Test con coverage test-coverage: - go test -coverprofile=coverage.out ./... - go tool cover -html=coverage.out + GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go test -coverprofile=coverage.out ./... + GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go tool cover -html=coverage.out # Full coverage report (all packages) coverage: @echo "Running coverage across all packages..." - @go test -coverpkg=./... -coverprofile=coverage.out ./... - @go tool cover -func=coverage.out | tail -n 1 + @GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go test -coverpkg=./... -coverprofile=coverage.out ./... + @GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go tool cover -func=coverage.out | tail -n 1 # Enforce minimum coverage threshold coverage-check: @echo "Running coverage check (threshold $(COVERAGE_THRESHOLD)% )..." - @go test -coverpkg=./... -coverprofile=coverage.out ./... - @total=$$(go tool cover -func=coverage.out | grep total: | awk '{print $$3}' | sed 's/%//'); \ + @GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go test -coverpkg=./... -coverprofile=coverage.out ./... + @total=$$(GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go tool cover -func=coverage.out | grep total: | awk '{print $$3}' | sed 's/%//'); \ echo "Total coverage: $$total%"; \ if awk -v total="$$total" -v threshold="$(COVERAGE_THRESHOLD)" 'BEGIN { exit !(total+0 >= threshold+0) }'; then \ echo "Coverage check passed."; \ diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index a3ef784..e0f3d82 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -1053,6 +1053,8 @@ VZDUMP_CONFIG_PATH=/etc/vzdump.conf # PBS datastore paths (comma/space separated) PBS_DATASTORE_PATH= # e.g., "/mnt/pbs1,/mnt/pbs2" +# Extra filesystem scan roots for datastore/PXAR discovery; these do not create +# real PBS datastore definitions and may use path-derived output keys. # System root override (testing/chroot) SYSTEM_ROOT_PREFIX= # Optional alternate root for system collection. Empty or "/" = real root. diff --git a/docs/RESTORE_GUIDE.md b/docs/RESTORE_GUIDE.md index 4c04c4c..53312fb 100644 --- a/docs/RESTORE_GUIDE.md +++ b/docs/RESTORE_GUIDE.md @@ -117,7 +117,7 @@ API apply is automatic for supported PBS staged categories; ProxSave may fall ba |----------|------|-------------|-------| | `pbs_config` | PBS Config Export | **Export-only** copy of /etc/proxmox-backup (never written to system) | `./etc/proxmox-backup/` | | `pbs_host` | PBS Host & Integrations | **Staged** node settings, ACME, proxy, metric servers and traffic control (API/file apply) | `./etc/proxmox-backup/node.cfg`
`./etc/proxmox-backup/proxy.cfg`
`./etc/proxmox-backup/acme/accounts.cfg`
`./etc/proxmox-backup/acme/plugins.cfg`
`./etc/proxmox-backup/metricserver.cfg`
`./etc/proxmox-backup/traffic-control.cfg`
`./var/lib/proxsave-info/commands/pbs/node_config.json`
`./var/lib/proxsave-info/commands/pbs/acme_accounts.json`
`./var/lib/proxsave-info/commands/pbs/acme_plugins.json`
`./var/lib/proxsave-info/commands/pbs/acme_account_*_info.json`
`./var/lib/proxsave-info/commands/pbs/acme_plugin_*_config.json`
`./var/lib/proxsave-info/commands/pbs/traffic_control.json` | -| `datastore_pbs` | PBS Datastore Configuration | **Staged** datastore definitions (incl. S3 endpoints) (API/file apply) | `./etc/proxmox-backup/datastore.cfg`
`./etc/proxmox-backup/s3.cfg`
`./var/lib/proxsave-info/commands/pbs/datastore_list.json`
`./var/lib/proxsave-info/commands/pbs/datastore_*_status.json`
`./var/lib/proxsave-info/commands/pbs/s3_endpoints.json`
`./var/lib/proxsave-info/commands/pbs/s3_endpoint_*_buckets.json`
`./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json` | +| `datastore_pbs` | PBS Datastore Configuration | **Staged** datastore definitions (incl. S3 endpoints) (API/file apply) | `./etc/proxmox-backup/datastore.cfg`
`./etc/proxmox-backup/s3.cfg`
`./var/lib/proxsave-info/commands/pbs/datastore_list.json`
`./var/lib/proxsave-info/commands/pbs/datastore_*_status.json`
`./var/lib/proxsave-info/commands/pbs/s3_endpoints.json`
`./var/lib/proxsave-info/commands/pbs/s3_endpoint_*_buckets.json`
`./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json`
Note: `PBS_DATASTORE_PATH` override scan roots are inventory context only and are not recreated as datastore definitions during restore. | | `maintenance_pbs` | PBS Maintenance | Maintenance settings | `./etc/proxmox-backup/maintenance.cfg` | | `pbs_jobs` | PBS Jobs | **Staged** sync/verify/prune jobs (API/file apply) | `./etc/proxmox-backup/sync.cfg`
`./etc/proxmox-backup/verification.cfg`
`./etc/proxmox-backup/prune.cfg`
`./var/lib/proxsave-info/commands/pbs/sync_jobs.json`
`./var/lib/proxsave-info/commands/pbs/verification_jobs.json`
`./var/lib/proxsave-info/commands/pbs/prune_jobs.json`
`./var/lib/proxsave-info/commands/pbs/gc_jobs.json` | | `pbs_remotes` | PBS Remotes | **Staged** remotes for sync/verify (may include credentials) (API/file apply) | `./etc/proxmox-backup/remote.cfg`
`./var/lib/proxsave-info/commands/pbs/remote_list.json` | @@ -2384,6 +2384,7 @@ systemctl restart proxmox-backup proxmox-backup-proxy **Restore behavior**: - ProxSave detects this condition during staged apply. - If `var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json` is available in the backup, ProxSave will use its embedded snapshot of the original `datastore.cfg` to recover a valid configuration. +- Inventory entries that came only from `PBS_DATASTORE_PATH` scan roots are treated as diagnostic context and are excluded from regenerated `datastore.cfg`. - If recovery is not possible, ProxSave will **leave the existing** `/etc/proxmox-backup/datastore.cfg` unchanged to avoid breaking PBS. **Manual diagnosis**: diff --git a/go.mod b/go.mod index 4a36bb8..f99518e 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/tis24dev/proxsave go 1.25 -toolchain go1.25.7 +toolchain go1.25.8 require ( filippo.io/age v1.3.1 diff --git a/internal/backup/checksum.go b/internal/backup/checksum.go index fc917ad..0e08b39 100644 --- a/internal/backup/checksum.go +++ b/internal/backup/checksum.go @@ -35,6 +35,30 @@ type Manifest struct { ClusterMode string `json:"cluster_mode,omitempty"` } +// NormalizeChecksum validates and normalizes a SHA256 checksum string. +func NormalizeChecksum(value string) (string, error) { + checksum := strings.ToLower(strings.TrimSpace(value)) + if checksum == "" { + return "", fmt.Errorf("checksum is empty") + } + if len(checksum) != sha256.Size*2 { + return "", fmt.Errorf("checksum must be %d hex characters, got %d", sha256.Size*2, len(checksum)) + } + if _, err := hex.DecodeString(checksum); err != nil { + return "", fmt.Errorf("checksum is not valid hex: %w", err) + } + return checksum, nil +} + +// ParseChecksumData extracts a SHA256 checksum from checksum file contents. +func ParseChecksumData(data []byte) (string, error) { + fields := strings.Fields(string(data)) + if len(fields) == 0 { + return "", fmt.Errorf("checksum file is empty") + } + return NormalizeChecksum(fields[0]) +} + // GenerateChecksum calculates SHA256 checksum of a file func GenerateChecksum(ctx context.Context, logger *logging.Logger, filePath string) (string, error) { logger.Debug("Generating SHA256 checksum for: %s", filePath) @@ -105,16 +129,21 @@ func CreateManifest(ctx context.Context, logger *logging.Logger, manifest *Manif func VerifyChecksum(ctx context.Context, logger *logging.Logger, filePath, expectedChecksum string) (bool, error) { logger.Debug("Verifying checksum for: %s", filePath) + normalizedExpected, err := NormalizeChecksum(expectedChecksum) + if err != nil { + return false, fmt.Errorf("invalid expected checksum: %w", err) + } + actualChecksum, err := GenerateChecksum(ctx, logger, filePath) if err != nil { return false, fmt.Errorf("failed to generate checksum: %w", err) } - matches := actualChecksum == expectedChecksum + matches := actualChecksum == normalizedExpected if matches { logger.Debug("Checksum verification passed") } else { - logger.Warning("Checksum mismatch! Expected: %s, Got: %s", expectedChecksum, actualChecksum) + logger.Warning("Checksum mismatch! Expected: %s, Got: %s", normalizedExpected, actualChecksum) } return matches, nil @@ -205,9 +234,8 @@ func parseLegacyMetadata(scanner *bufio.Scanner, legacy *Manifest) { func loadLegacyChecksum(archivePath string, legacy *Manifest) { // Attempt to load checksum from legacy .sha256 file if shaData, err := os.ReadFile(archivePath + ".sha256"); err == nil { - fields := strings.Fields(string(shaData)) - if len(fields) > 0 { - legacy.SHA256 = fields[0] + if checksum, parseErr := ParseChecksumData(shaData); parseErr == nil { + legacy.SHA256 = checksum } } } diff --git a/internal/backup/checksum_legacy_test.go b/internal/backup/checksum_legacy_test.go index 1013ae8..c6738cf 100644 --- a/internal/backup/checksum_legacy_test.go +++ b/internal/backup/checksum_legacy_test.go @@ -28,7 +28,8 @@ func TestLoadLegacyManifestWithShaAndFallbackEncryption(t *testing.T) { t.Fatalf("write metadata: %v", err) } - shaLine := "deadbeef " + filepath.Base(archive) + "\n" + expectedSHA := strings.Repeat("a", 64) + shaLine := expectedSHA + " " + filepath.Base(archive) + "\n" if err := os.WriteFile(archive+".sha256", []byte(shaLine), 0o640); err != nil { t.Fatalf("write sha256: %v", err) } @@ -50,8 +51,8 @@ func TestLoadLegacyManifestWithShaAndFallbackEncryption(t *testing.T) { if m.EncryptionMode != "plain" { t.Fatalf("expected fallback encryption mode plain, got %s", m.EncryptionMode) } - if m.SHA256 != "deadbeef" { - t.Fatalf("expected sha256 deadbeef, got %s", m.SHA256) + if m.SHA256 != expectedSHA { + t.Fatalf("expected sha256 %s, got %s", expectedSHA, m.SHA256) } if time.Since(m.CreatedAt) > time.Minute { t.Fatalf("unexpected CreatedAt too old: %v", m.CreatedAt) diff --git a/internal/backup/checksum_test.go b/internal/backup/checksum_test.go index bc12bd0..676506f 100644 --- a/internal/backup/checksum_test.go +++ b/internal/backup/checksum_test.go @@ -145,8 +145,9 @@ ENCRYPTION_MODE=age if err := os.WriteFile(metadataPath, []byte(metadata), 0644); err != nil { t.Fatalf("failed to write metadata: %v", err) } + expectedSHA := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" shaPath := archivePath + ".sha256" - if err := os.WriteFile(shaPath, []byte("deadbeef "+filepath.Base(archivePath)), 0644); err != nil { + if err := os.WriteFile(shaPath, []byte(expectedSHA+" "+filepath.Base(archivePath)), 0644); err != nil { t.Fatalf("failed to write sha file: %v", err) } @@ -163,7 +164,7 @@ ENCRYPTION_MODE=age if manifest.Hostname != "legacy-host" || manifest.ScriptVersion != "legacy-1.0" { t.Fatalf("legacy metadata not parsed correctly: %+v", manifest) } - if manifest.SHA256 != "deadbeef" { + if manifest.SHA256 != expectedSHA { t.Fatalf("expected SHA256 from sidecar, got %q", manifest.SHA256) } if manifest.EncryptionMode != "age" { diff --git a/internal/backup/collector.go b/internal/backup/collector.go index 466fed2..e3de49f 100644 --- a/internal/backup/collector.go +++ b/internal/backup/collector.go @@ -3,6 +3,8 @@ package backup import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "io" @@ -1180,6 +1182,21 @@ func sanitizeFilename(name string) string { return clean } +func collectorPathKey(name string) string { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return "entry" + } + + safe := sanitizeFilename(trimmed) + if safe == trimmed { + return safe + } + + sum := sha256.Sum256([]byte(trimmed)) + return fmt.Sprintf("%s_%s", safe, hex.EncodeToString(sum[:4])) +} + // GetStats returns current collection statistics func (c *Collector) GetStats() *CollectionStats { c.statsMu.Lock() diff --git a/internal/backup/collector_helpers_extra_test.go b/internal/backup/collector_helpers_extra_test.go index ae58f2a..a0aa38b 100644 --- a/internal/backup/collector_helpers_extra_test.go +++ b/internal/backup/collector_helpers_extra_test.go @@ -1,6 +1,9 @@ package backup -import "testing" +import ( + "strings" + "testing" +) func TestSummarizeCommandOutputText(t *testing.T) { if got := summarizeCommandOutputText(""); got != "(no stdout/stderr)" { @@ -38,3 +41,25 @@ func TestSanitizeFilenameExtra(t *testing.T) { } } } + +func TestCollectorPathKey(t *testing.T) { + if got := collectorPathKey("store1"); got != "store1" { + t.Fatalf("collectorPathKey(store1)=%q want %q", got, "store1") + } + + unsafe := "../evil" + got := collectorPathKey(unsafe) + if got == unsafe { + t.Fatalf("collectorPathKey(%q) should not keep unsafe value", unsafe) + } + if got == sanitizeFilename(unsafe) { + t.Fatalf("collectorPathKey(%q) should add a disambiguating suffix", unsafe) + } + if !strings.HasPrefix(got, "__evil") { + t.Fatalf("collectorPathKey(%q)=%q should start with sanitized prefix", unsafe, got) + } + + if a, b := collectorPathKey("a/b"), collectorPathKey("a_b"); a == b { + t.Fatalf("collectorPathKey should avoid collisions: %q == %q", a, b) + } +} diff --git a/internal/backup/collector_pbs.go b/internal/backup/collector_pbs.go index 98212e9..8b23b7d 100644 --- a/internal/backup/collector_pbs.go +++ b/internal/backup/collector_pbs.go @@ -347,6 +347,11 @@ func (c *Collector) collectPBSDirectories(ctx context.Context, root string) erro // collectPBSCommands collects output from PBS commands func (c *Collector) collectPBSCommands(ctx context.Context, datastores []pbsDatastore) error { + if len(datastores) > 0 { + datastores = clonePBSDatastores(datastores) + assignUniquePBSDatastoreOutputKeys(datastores) + } + commandsDir := c.proxsaveCommandsDir("pbs") if err := c.ensureDir(commandsDir); err != nil { return fmt.Errorf("failed to create commands directory: %w", err) @@ -383,9 +388,19 @@ func (c *Collector) collectPBSCommands(ctx context.Context, datastores []pbsData // Datastore usage details if c.config.BackupDatastoreConfigs && len(datastores) > 0 { for _, ds := range datastores { + if ds.isOverride() { + c.logger.Debug("Skipping datastore status for %s (path=%s): no PBS datastore identity", ds.Name, ds.Path) + continue + } + cliName := ds.cliName() + if cliName == "" { + c.logger.Debug("Skipping datastore status for %s (path=%s): empty PBS datastore identity", ds.Name, ds.Path) + continue + } + dsKey := ds.pathKey() c.safeCmdOutput(ctx, - fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", ds.Name), - filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", ds.Name)), + fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", cliName), + filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", dsKey)), fmt.Sprintf("Datastore %s status", ds.Name), false) } diff --git a/internal/backup/collector_pbs_commands_coverage_test.go b/internal/backup/collector_pbs_commands_coverage_test.go index b96c44f..65c37a1 100644 --- a/internal/backup/collector_pbs_commands_coverage_test.go +++ b/internal/backup/collector_pbs_commands_coverage_test.go @@ -92,6 +92,42 @@ func TestCollectPBSCommandsWritesExpectedOutputs(t *testing.T) { } } +func TestCollectPBSCommandsSkipsStatusForOverrideOnlyEntries(t *testing.T) { + pbsRoot := t.TempDir() + if err := os.WriteFile(filepath.Join(pbsRoot, "tape.cfg"), []byte("ok"), 0o640); err != nil { + t.Fatalf("write tape.cfg: %v", err) + } + + cfg := GetDefaultCollectorConfig() + cfg.PBSConfigPath = pbsRoot + + collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{ + LookPath: func(name string) (string, error) { + return "/bin/" + name, nil + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + return []byte(fmt.Sprintf("%s %s", name, strings.Join(args, " "))), nil + }, + }) + + override := pbsDatastore{ + Name: "backup", + Path: "/mnt/a/backup", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath("/mnt/a/backup"), + OutputKey: buildPBSOverrideOutputKey("/mnt/a/backup"), + } + if err := collector.collectPBSCommands(context.Background(), []pbsDatastore{override}); err != nil { + t.Fatalf("collectPBSCommands error: %v", err) + } + + commandsDir := filepath.Join(collector.tempDir, "var/lib/proxsave-info", "commands", "pbs") + statusPath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", override.pathKey())) + if _, err := os.Stat(statusPath); !os.IsNotExist(err) { + t.Fatalf("override datastore status file should not exist (%s), got err=%v", statusPath, err) + } +} + func TestCollectPBSCommandsReturnsErrorWhenCriticalVersionFails(t *testing.T) { cfg := GetDefaultCollectorConfig() cfg.PBSConfigPath = t.TempDir() @@ -189,6 +225,50 @@ func TestCollectPBSPxarMetadataProcessesMultipleDatastores(t *testing.T) { } } +func TestCollectPBSPxarMetadataSeparatesOverrideBasenameCollisions(t *testing.T) { + tmp := t.TempDir() + cfg := GetDefaultCollectorConfig() + cfg.PxarDatastoreConcurrency = 2 + + collector := NewCollector(newTestLogger(), cfg, tmp, types.ProxmoxBS, false) + + makeOverride := func(path string) pbsDatastore { + for _, sub := range []string{"vm", "ct"} { + if err := os.MkdirAll(filepath.Join(path, sub), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", sub, err) + } + } + if err := os.WriteFile(filepath.Join(path, "vm", "backup.pxar"), []byte("data"), 0o640); err != nil { + t.Fatalf("write pxar: %v", err) + } + return pbsDatastore{ + Name: "backup", + Path: path, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(path), + OutputKey: buildPBSOverrideOutputKey(path), + } + } + + ds1 := makeOverride(filepath.Join(tmp, "mnt", "a", "backup")) + ds2 := makeOverride(filepath.Join(tmp, "srv", "b", "backup")) + if ds1.pathKey() == ds2.pathKey() { + t.Fatalf("expected distinct path keys, got %q", ds1.pathKey()) + } + + if err := collector.collectPBSPxarMetadata(context.Background(), []pbsDatastore{ds1, ds2}); err != nil { + t.Fatalf("collectPBSPxarMetadata error: %v", err) + } + + for _, ds := range []pbsDatastore{ds1, ds2} { + base := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "metadata", ds.pathKey()) + if _, err := os.Stat(filepath.Join(base, "metadata.json")); err != nil { + t.Fatalf("expected metadata for %s: %v", ds.pathKey(), err) + } + } +} + func TestCollectPBSPxarMetadataReturnsErrorWhenTempVarIsFile(t *testing.T) { tmp := t.TempDir() if err := os.WriteFile(filepath.Join(tmp, "var"), []byte("not-a-dir"), 0o640); err != nil { @@ -234,6 +314,42 @@ func TestCollectDatastoreConfigsCreatesConfigAndNamespaceFiles(t *testing.T) { } } +func TestCollectDatastoreConfigsCreatesDistinctNamespaceFilesForOverrideCollisions(t *testing.T) { + cfg := GetDefaultCollectorConfig() + tmp := t.TempDir() + collector := NewCollectorWithDeps(newTestLogger(), cfg, tmp, types.ProxmoxBS, false, CollectorDeps{}) + + makeOverride := func(path string) pbsDatastore { + if err := os.MkdirAll(filepath.Join(path, "local", "vm"), 0o755); err != nil { + t.Fatalf("mkdir override namespace fixture: %v", err) + } + return pbsDatastore{ + Name: "backup", + Path: path, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(path), + OutputKey: buildPBSOverrideOutputKey(path), + } + } + + ds1 := makeOverride(filepath.Join(tmp, "mnt", "a", "backup")) + ds2 := makeOverride(filepath.Join(tmp, "srv", "b", "backup")) + if err := collector.collectDatastoreConfigs(context.Background(), []pbsDatastore{ds1, ds2}); err != nil { + t.Fatalf("collectDatastoreConfigs error: %v", err) + } + + datastoreDir := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "datastores") + for _, ds := range []pbsDatastore{ds1, ds2} { + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey()))); err != nil { + t.Fatalf("expected namespaces file for %s: %v", ds.pathKey(), err) + } + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", ds.pathKey()))); !os.IsNotExist(err) { + t.Fatalf("override config file should not exist for %s, got err=%v", ds.pathKey(), err) + } + } +} + func TestCollectUserTokensSkipsInvalidUserListJSON(t *testing.T) { tmp := t.TempDir() collector := NewCollector(newTestLogger(), GetDefaultCollectorConfig(), tmp, types.ProxmoxBS, false) diff --git a/internal/backup/collector_pbs_datastore.go b/internal/backup/collector_pbs_datastore.go index 0b5c2ba..472e169 100644 --- a/internal/backup/collector_pbs_datastore.go +++ b/internal/backup/collector_pbs_datastore.go @@ -2,12 +2,15 @@ package backup import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" "os" "path/filepath" "regexp" + "sort" "strings" "sync" "time" @@ -16,13 +19,284 @@ import ( "github.com/tis24dev/proxsave/internal/safefs" ) +const ( + pbsDatastoreSourceCLI = "cli" + pbsDatastoreSourceOverride = "override" + pbsDatastoreSourceConfig = "config" + pbsDatastoreOriginMerged = "merged" +) + +var ( + pbsDatastoreNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + listNamespacesFunc = pbs.ListNamespaces + discoverNamespacesFunc = pbs.DiscoverNamespacesFromFilesystem +) + type pbsDatastore struct { - Name string - Path string - Comment string + Name string + Path string + Comment string + Source string + CLIName string + NormalizedPath string + OutputKey string +} + +func normalizePBSDatastorePath(path string) string { + trimmed := strings.TrimSpace(path) + if trimmed == "" { + return "" + } + return filepath.Clean(trimmed) +} + +func buildPBSOverrideDisplayName(path string, idx int) string { + name := filepath.Base(normalizePBSDatastorePath(path)) + if name == "" || name == "." || name == string(os.PathSeparator) || !pbsDatastoreNamePattern.MatchString(name) { + return fmt.Sprintf("datastore_%d", idx+1) + } + return name +} + +func buildPBSOverrideOutputKey(path string) string { + normalized := normalizePBSDatastorePath(path) + if normalized == "" { + return "entry" + } + + label := filepath.Base(normalized) + if label == "" || label == "." || label == string(os.PathSeparator) || !pbsDatastoreNamePattern.MatchString(label) { + label = "datastore" + } + + sum := sha256.Sum256([]byte(normalized)) + return fmt.Sprintf("%s_%s", sanitizeFilename(label), hex.EncodeToString(sum[:4])) +} + +func pbsOutputKeyDigest(seed string) string { + sum := sha256.Sum256([]byte(seed)) + return hex.EncodeToString(sum[:4]) +} + +func pbsDatastoreIdentityKey(ds pbsDatastore) string { + if ds.isOverride() { + if normalized := ds.normalizedPath(); normalized != "" { + return "override-path:" + normalized + } + if path := strings.TrimSpace(ds.Path); path != "" { + return "override-path:" + path + } + return "" + } + + name := strings.TrimSpace(ds.Name) + if name == "" { + return "" + } + return "name:" + name +} + +func pbsDatastoreDefinitionIdentityKey(def pbsDatastoreDefinition) string { + if strings.TrimSpace(def.Origin) == pbsDatastoreSourceOverride { + if normalized := normalizePBSDatastorePath(def.Path); normalized != "" { + return "override-path:" + normalized + } + if path := strings.TrimSpace(def.Path); path != "" { + return "override-path:" + path + } + return "" + } + + name := strings.TrimSpace(def.Name) + if name == "" { + name = strings.TrimSpace(def.CLIName) + } + if name == "" { + return "" + } + return "name:" + name +} + +func pbsDatastoreCandidateOutputKey(ds pbsDatastore) string { + if ds.isOverride() { + if normalized := ds.normalizedPath(); normalized != "" { + return buildPBSOverrideOutputKey(normalized) + } + return buildPBSOverrideOutputKey(ds.Path) + } + return collectorPathKey(ds.Name) +} + +func pbsDatastoreDefinitionCandidateOutputKey(def pbsDatastoreDefinition) string { + if strings.TrimSpace(def.Origin) == pbsDatastoreSourceOverride { + if normalized := normalizePBSDatastorePath(def.Path); normalized != "" { + return buildPBSOverrideOutputKey(normalized) + } + return buildPBSOverrideOutputKey(def.Path) + } + + name := strings.TrimSpace(def.Name) + if name == "" { + name = strings.TrimSpace(def.CLIName) + } + return collectorPathKey(name) +} + +func pbsOutputKeyPriority(origin string) int { + if strings.TrimSpace(origin) == pbsDatastoreSourceOverride { + return 1 + } + return 0 +} + +type pbsOutputKeyAssignment struct { + Index int + Identity string + BaseKey string + Priority int +} + +func assignUniquePBSOutputKeys[T any](items []T, identityFn func(T) string, baseKeyFn func(T) string, priorityFn func(T) int, assignFn func(*T, string)) { + if len(items) == 0 { + return + } + + grouped := make(map[string][]pbsOutputKeyAssignment, len(items)) + baseKeys := make([]string, 0, len(items)) + for idx, item := range items { + baseKey := strings.TrimSpace(baseKeyFn(item)) + if baseKey == "" { + baseKey = "entry" + } + + identity := strings.TrimSpace(identityFn(item)) + if identity == "" { + identity = fmt.Sprintf("anonymous:%s:%d", baseKey, idx) + } + + if _, ok := grouped[baseKey]; !ok { + baseKeys = append(baseKeys, baseKey) + } + grouped[baseKey] = append(grouped[baseKey], pbsOutputKeyAssignment{ + Index: idx, + Identity: identity, + BaseKey: baseKey, + Priority: priorityFn(item), + }) + } + + sort.Strings(baseKeys) + + usedKeys := make(map[string]string, len(items)) + identityKeys := make(map[string]string, len(items)) + + for _, baseKey := range baseKeys { + assignments := grouped[baseKey] + sort.SliceStable(assignments, func(i, j int) bool { + if assignments[i].Priority != assignments[j].Priority { + return assignments[i].Priority < assignments[j].Priority + } + if assignments[i].Identity != assignments[j].Identity { + return assignments[i].Identity < assignments[j].Identity + } + return assignments[i].Index < assignments[j].Index + }) + + for pos, assignment := range assignments { + if existing := strings.TrimSpace(identityKeys[assignment.Identity]); existing != "" { + assignFn(&items[assignment.Index], existing) + continue + } + + preferBase := pos == 0 + for attempt := 0; ; attempt++ { + candidate := assignment.BaseKey + if !preferBase || attempt > 0 { + seed := assignment.Identity + if attempt > 0 { + seed = fmt.Sprintf("%s#%d", assignment.Identity, attempt) + } + candidate = fmt.Sprintf("%s_%s", assignment.BaseKey, pbsOutputKeyDigest(seed)) + } + + if owner, ok := usedKeys[candidate]; ok && owner != assignment.Identity { + continue + } + + usedKeys[candidate] = assignment.Identity + identityKeys[assignment.Identity] = candidate + assignFn(&items[assignment.Index], candidate) + break + } + } + } +} + +func assignUniquePBSDatastoreOutputKeys(datastores []pbsDatastore) { + assignUniquePBSOutputKeys(datastores, + pbsDatastoreIdentityKey, + pbsDatastoreCandidateOutputKey, + func(ds pbsDatastore) int { + return pbsOutputKeyPriority(ds.Source) + }, + func(ds *pbsDatastore, key string) { + ds.OutputKey = key + }) } -var listNamespacesFunc = pbs.ListNamespaces +func assignUniquePBSDatastoreDefinitionOutputKeys(defs []pbsDatastoreDefinition) { + assignUniquePBSOutputKeys(defs, + pbsDatastoreDefinitionIdentityKey, + pbsDatastoreDefinitionCandidateOutputKey, + func(def pbsDatastoreDefinition) int { + return pbsOutputKeyPriority(def.Origin) + }, + func(def *pbsDatastoreDefinition, key string) { + def.OutputKey = key + }) +} + +func clonePBSDatastores(in []pbsDatastore) []pbsDatastore { + if len(in) == 0 { + return nil + } + + out := make([]pbsDatastore, len(in)) + copy(out, in) + return out +} + +func (ds pbsDatastore) normalizedPath() string { + if path := strings.TrimSpace(ds.NormalizedPath); path != "" { + return path + } + return normalizePBSDatastorePath(ds.Path) +} + +func (ds pbsDatastore) pathKey() string { + if key := strings.TrimSpace(ds.OutputKey); key != "" { + return key + } + return pbsDatastoreCandidateOutputKey(ds) +} + +func (ds pbsDatastore) cliName() string { + if name := strings.TrimSpace(ds.CLIName); name != "" { + return name + } + return strings.TrimSpace(ds.Name) +} + +func (ds pbsDatastore) isOverride() bool { + return strings.TrimSpace(ds.Source) == pbsDatastoreSourceOverride +} + +func (ds pbsDatastore) inventoryOrigin() string { + if origin := strings.TrimSpace(ds.Source); origin != "" { + return origin + } + return pbsDatastoreSourceCLI +} // collectDatastoreConfigs collects detailed datastore configurations func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pbsDatastore) error { @@ -30,6 +304,8 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb c.logger.Debug("No datastores found") return nil } + datastores = clonePBSDatastores(datastores) + assignUniquePBSDatastoreOutputKeys(datastores) c.logger.Debug("Collecting datastore details for %d datastores", len(datastores)) datastoreDir := c.proxsaveInfoDir("pbs", "datastores") @@ -38,12 +314,18 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb } for _, ds := range datastores { - // Get datastore configuration details - c.safeCmdOutput(ctx, - fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", ds.Name), - filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", ds.Name)), - fmt.Sprintf("Datastore %s configuration", ds.Name), - false) + dsKey := ds.pathKey() + + if cliName := ds.cliName(); cliName != "" && !ds.isOverride() { + // Get datastore configuration details for CLI-backed datastores only. + c.safeCmdOutput(ctx, + fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", cliName), + filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", dsKey)), + fmt.Sprintf("Datastore %s configuration", ds.Name), + false) + } else { + c.logger.Debug("Skipping datastore CLI config for %s (path=%s): no PBS datastore identity", ds.Name, ds.Path) + } // Get namespace list using CLI/Filesystem fallback if err := c.collectDatastoreNamespaces(ctx, ds, datastoreDir); err != nil { @@ -60,7 +342,7 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb func (c *Collector) collectDatastoreNamespaces(ctx context.Context, ds pbsDatastore, datastoreDir string) error { c.logger.Debug("Collecting namespaces for datastore %s (path: %s)", ds.Name, ds.Path) // Write location is deterministic; if excluded, skip the whole operation. - outputPath := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.Name)) + outputPath := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey())) if c.shouldExclude(outputPath) { c.incFilesSkipped() return nil @@ -71,7 +353,17 @@ func (c *Collector) collectDatastoreNamespaces(ctx context.Context, ds pbsDatast ioTimeout = time.Duration(c.config.FsIoTimeoutSeconds) * time.Second } - namespaces, fromFallback, err := listNamespacesFunc(ctx, ds.Name, ds.Path, ioTimeout) + var ( + namespaces []pbs.Namespace + fromFallback bool + err error + ) + if ds.isOverride() { + namespaces, err = discoverNamespacesFunc(ctx, ds.normalizedPath(), ioTimeout) + fromFallback = true + } else { + namespaces, fromFallback, err = listNamespacesFunc(ctx, ds.cliName(), ds.Path, ioTimeout) + } if err != nil { return err } @@ -102,6 +394,8 @@ func (c *Collector) collectPBSPxarMetadata(ctx context.Context, datastores []pbs if len(datastores) == 0 { return nil } + datastores = clonePBSDatastores(datastores) + assignUniquePBSDatastoreOutputKeys(datastores) c.logger.Debug("Collecting PXAR metadata for %d datastores", len(datastores)) dsWorkers := c.config.PxarDatastoreConcurrency if dsWorkers <= 0 { @@ -214,16 +508,17 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m start := time.Now() c.logger.Debug("PXAR: scanning datastore %s at %s", ds.Name, ds.Path) - dsDir := filepath.Join(metaRoot, ds.Name) + dsKey := ds.pathKey() + dsDir := filepath.Join(metaRoot, dsKey) if err := c.ensureDir(dsDir); err != nil { return fmt.Errorf("failed to create PXAR metadata directory for %s: %w", ds.Name, err) } for _, base := range []string{ - filepath.Join(selectedRoot, ds.Name, "vm"), - filepath.Join(selectedRoot, ds.Name, "ct"), - filepath.Join(smallRoot, ds.Name, "vm"), - filepath.Join(smallRoot, ds.Name, "ct"), + filepath.Join(selectedRoot, dsKey, "vm"), + filepath.Join(selectedRoot, dsKey, "ct"), + filepath.Join(smallRoot, dsKey, "vm"), + filepath.Join(smallRoot, dsKey, "ct"), } { if err := c.ensureDir(base); err != nil { c.logger.Debug("Failed to prepare PXAR directory %s: %v", base, err) @@ -278,7 +573,7 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m return err } - if err := c.writePxarSubdirReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_subdirs.txt", ds.Name)), ds, ioTimeout); err != nil { + if err := c.writePxarSubdirReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_subdirs.txt", dsKey)), ds, ioTimeout); err != nil { if errors.Is(err, safefs.ErrTimeout) { c.logger.Warning("Skipping PXAR metadata for datastore %s (path=%s): subdir report timed out (%v)", ds.Name, ds.Path, err) return nil @@ -286,7 +581,7 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m return err } - if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_vm_pxar_list.txt", ds.Name)), ds, "vm", ioTimeout); err != nil { + if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_vm_pxar_list.txt", dsKey)), ds, "vm", ioTimeout); err != nil { if errors.Is(err, safefs.ErrTimeout) { c.logger.Warning("Skipping PXAR metadata for datastore %s (path=%s): VM list report timed out (%v)", ds.Name, ds.Path, err) return nil @@ -294,7 +589,7 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m return err } - if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_ct_pxar_list.txt", ds.Name)), ds, "ct", ioTimeout); err != nil { + if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_ct_pxar_list.txt", dsKey)), ds, "ct", ioTimeout); err != nil { if errors.Is(err, safefs.ErrTimeout) { c.logger.Warning("Skipping PXAR metadata for datastore %s (path=%s): CT list report timed out (%v)", ds.Name, ds.Path, err) return nil @@ -427,66 +722,86 @@ func (c *Collector) getDatastoreList(ctx context.Context) ([]pbsDatastore, error } c.logger.Debug("Enumerating PBS datastores via proxmox-backup-manager") - if _, err := c.depLookPath("proxmox-backup-manager"); err != nil { - return nil, nil - } - - output, err := c.depRunCommand(ctx, "proxmox-backup-manager", "datastore", "list", "--output-format=json") - if err != nil { - return nil, fmt.Errorf("proxmox-backup-manager datastore list failed: %w", err) - } - type datastoreEntry struct { Name string `json:"name"` Path string `json:"path"` Comment string `json:"comment"` } - var entries []datastoreEntry - if err := json.Unmarshal(output, &entries); err != nil { - return nil, fmt.Errorf("failed to parse datastore list JSON: %w", err) - } - - datastores := make([]pbsDatastore, 0, len(entries)) - for _, entry := range entries { - name := strings.TrimSpace(entry.Name) - if name != "" { - datastores = append(datastores, pbsDatastore{ - Name: name, - Path: strings.TrimSpace(entry.Path), - Comment: strings.TrimSpace(entry.Comment), - }) + datastores := make([]pbsDatastore, 0, len(c.config.PBSDatastorePaths)) + if _, err := c.depLookPath("proxmox-backup-manager"); err != nil { + c.logger.Debug("Skipping PBS datastore CLI enumeration: proxmox-backup-manager not available: %v", err) + } else { + output, err := c.depRunCommand(ctx, "proxmox-backup-manager", "datastore", "list", "--output-format=json") + if err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, ctxErr + } + c.logger.Debug("PBS datastore CLI enumeration failed: %v", err) + } else { + var entries []datastoreEntry + if err := json.Unmarshal(output, &entries); err != nil { + c.logger.Debug("Failed to parse PBS datastore list JSON: %v", err) + } else { + datastores = make([]pbsDatastore, 0, len(entries)+len(c.config.PBSDatastorePaths)) + for _, entry := range entries { + name := strings.TrimSpace(entry.Name) + if name == "" { + continue + } + path := strings.TrimSpace(entry.Path) + datastores = append(datastores, pbsDatastore{ + Name: name, + Path: path, + Comment: strings.TrimSpace(entry.Comment), + Source: pbsDatastoreSourceCLI, + CLIName: name, + NormalizedPath: normalizePBSDatastorePath(path), + OutputKey: collectorPathKey(name), + }) + } + } } } if len(c.config.PBSDatastorePaths) > 0 { existing := make(map[string]struct{}, len(datastores)) for _, ds := range datastores { - if ds.Path != "" { - existing[ds.Path] = struct{}{} + if normalized := ds.normalizedPath(); normalized != "" { + existing[normalized] = struct{}{} } } - validName := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) for idx, override := range c.config.PBSDatastorePaths { override = strings.TrimSpace(override) if override == "" { continue } - if _, ok := existing[override]; ok { + normalized := normalizePBSDatastorePath(override) + if normalized == "" { continue } - name := filepath.Base(filepath.Clean(override)) - if name == "" || name == "." || name == string(os.PathSeparator) || !validName.MatchString(name) { - name = fmt.Sprintf("datastore_%d", idx+1) + if !filepath.IsAbs(normalized) { + c.logger.Warning("Skipping PBS_DATASTORE_PATH override %q: path must be absolute", override) + continue } + if _, ok := existing[normalized]; ok { + continue + } + existing[normalized] = struct{}{} + name := buildPBSOverrideDisplayName(normalized, idx) datastores = append(datastores, pbsDatastore{ - Name: name, - Path: override, - Comment: "configured via PBS_DATASTORE_PATH", + Name: name, + Path: override, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalized, + OutputKey: buildPBSOverrideOutputKey(normalized), }) } } + assignUniquePBSDatastoreOutputKeys(datastores) + c.logger.Debug("Detected %d configured datastores", len(datastores)) return datastores, nil } diff --git a/internal/backup/collector_pbs_datastore_inventory.go b/internal/backup/collector_pbs_datastore_inventory.go index 72bbb02..b04b5fc 100644 --- a/internal/backup/collector_pbs_datastore_inventory.go +++ b/internal/backup/collector_pbs_datastore_inventory.go @@ -66,6 +66,9 @@ type pbsDatastoreInventoryEntry struct { Path string `json:"path,omitempty"` Comment string `json:"comment,omitempty"` Sources []string `json:"sources,omitempty"` + Origin string `json:"origin,omitempty"` + CLIName string `json:"cli_name,omitempty"` + OutputKey string `json:"output_key,omitempty"` StatPath string `json:"stat_path,omitempty"` PathOK bool `json:"path_ok,omitempty"` PathIsDir bool `json:"path_is_dir,omitempty"` @@ -216,10 +219,13 @@ func (c *Collector) collectPBSDatastoreInventory(ctx context.Context, cliDatasto for _, def := range merged { entry := pbsDatastoreInventoryEntry{ - Name: def.Name, - Path: def.Path, - Comment: def.Comment, - Sources: append([]string(nil), def.Sources...), + Name: def.Name, + Path: def.Path, + Comment: def.Comment, + Sources: append([]string(nil), def.Sources...), + Origin: def.Origin, + CLIName: def.CLIName, + OutputKey: def.OutputKey, } statPath := def.Path @@ -447,34 +453,71 @@ func (c *Collector) captureInventoryCommand(ctx context.Context, pretty string, } type pbsDatastoreDefinition struct { - Name string - Path string - Comment string - Sources []string + Name string + Path string + Comment string + Sources []string + Origin string + CLIName string + OutputKey string } func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefinition { merged := make(map[string]*pbsDatastoreDefinition) + defKey := func(ds pbsDatastore) string { + return pbsDatastoreIdentityKey(ds) + } + add := func(ds pbsDatastore, source string) { - name := strings.TrimSpace(ds.Name) - if name == "" { + key := defKey(ds) + if key == "" { return } - entry := merged[name] + name := strings.TrimSpace(ds.Name) + path := strings.TrimSpace(ds.Path) + comment := strings.TrimSpace(ds.Comment) + origin := ds.inventoryOrigin() + cliName := strings.TrimSpace(ds.cliName()) + entry := merged[key] if entry == nil { - entry = &pbsDatastoreDefinition{Name: name} - merged[name] = entry + entry = &pbsDatastoreDefinition{ + Name: name, + Origin: origin, + CLIName: cliName, + OutputKey: strings.TrimSpace(ds.pathKey()), + } + merged[key] = entry } entry.Sources = append(entry.Sources, source) - if entry.Path == "" && strings.TrimSpace(ds.Path) != "" { - entry.Path = strings.TrimSpace(ds.Path) + if entry.Name == "" && name != "" { + entry.Name = name + } + if path != "" && (entry.Path == "" || origin == pbsDatastoreSourceCLI) { + entry.Path = path } - if entry.Comment == "" && strings.TrimSpace(ds.Comment) != "" { - entry.Comment = strings.TrimSpace(ds.Comment) + if entry.Comment == "" && comment != "" { + entry.Comment = comment + } + if entry.CLIName == "" && !ds.isOverride() && cliName != "" { + entry.CLIName = cliName + } + if entry.OutputKey == "" { + entry.OutputKey = strings.TrimSpace(ds.pathKey()) + } + + switch { + case entry.Origin == "": + entry.Origin = origin + case entry.Origin == pbsDatastoreSourceOverride || origin == pbsDatastoreSourceOverride: + if entry.Origin == "" { + entry.Origin = pbsDatastoreSourceOverride + } + case entry.Origin != origin: + entry.Origin = pbsDatastoreOriginMerged } } @@ -482,7 +525,11 @@ func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefi add(ds, "datastore.cfg") } for _, ds := range cli { - add(ds, "cli") + source := "cli" + if ds.isOverride() { + source = "override" + } + add(ds, source) } out := make([]pbsDatastoreDefinition, 0, len(merged)) @@ -493,13 +540,41 @@ func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefi v.Sources = uniqueSortedStrings(v.Sources) out = append(out, *v) } + assignUniquePBSDatastoreDefinitionOutputKeys(out) sort.Slice(out, func(i, j int) bool { - return out[i].Name < out[j].Name + if out[i].Name != out[j].Name { + return out[i].Name < out[j].Name + } + if p1, p2 := pbsDatastoreOriginSortPriority(out[i].Origin), pbsDatastoreOriginSortPriority(out[j].Origin); p1 != p2 { + return p1 < p2 + } + if out[i].Path != out[j].Path { + return out[i].Path < out[j].Path + } + if out[i].OutputKey != out[j].OutputKey { + return out[i].OutputKey < out[j].OutputKey + } + return out[i].CLIName < out[j].CLIName }) return out } +func pbsDatastoreOriginSortPriority(origin string) int { + switch strings.TrimSpace(origin) { + case pbsDatastoreOriginMerged: + return 0 + case pbsDatastoreSourceConfig: + return 1 + case pbsDatastoreSourceCLI: + return 2 + case pbsDatastoreSourceOverride: + return 3 + default: + return 4 + } +} + func parsePBSDatastoreCfg(contents string) []pbsDatastore { contents = strings.TrimSpace(contents) if contents == "" { @@ -536,7 +611,12 @@ func parsePBSDatastoreCfg(contents string) []pbsDatastore { if name == "" { continue } - current = &pbsDatastore{Name: name} + current = &pbsDatastore{ + Name: name, + Source: pbsDatastoreSourceConfig, + CLIName: name, + OutputKey: collectorPathKey(name), + } continue } @@ -560,6 +640,10 @@ func parsePBSDatastoreCfg(contents string) []pbsDatastore { } flush() + for i := range out { + out[i].NormalizedPath = normalizePBSDatastorePath(out[i].Path) + } + return out } diff --git a/internal/backup/collector_pbs_datastore_inventory_test.go b/internal/backup/collector_pbs_datastore_inventory_test.go index ca11128..6b58724 100644 --- a/internal/backup/collector_pbs_datastore_inventory_test.go +++ b/internal/backup/collector_pbs_datastore_inventory_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "testing" "github.com/tis24dev/proxsave/internal/types" @@ -295,3 +296,142 @@ func TestCollectPBSDatastoreInventoryCapturesHostCommands(t *testing.T) { t.Fatalf("expected findmnt output to be captured") } } + +func TestMergePBSDatastoreDefinitionsKeepsOverridesSeparate(t *testing.T) { + config := []pbsDatastore{{ + Name: "backup", + Path: "/real/backup", + Comment: "primary", + Source: pbsDatastoreSourceConfig, + CLIName: "backup", + NormalizedPath: normalizePBSDatastorePath("/real/backup"), + OutputKey: collectorPathKey("backup"), + }} + cli := []pbsDatastore{ + { + Name: "backup", + Path: "/real/backup", + Comment: "runtime", + Source: pbsDatastoreSourceCLI, + CLIName: "backup", + NormalizedPath: normalizePBSDatastorePath("/real/backup"), + OutputKey: collectorPathKey("backup"), + }, + { + Name: "backup", + Path: "/mnt/a/backup", + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath("/mnt/a/backup"), + OutputKey: buildPBSOverrideOutputKey("/mnt/a/backup"), + }, + { + Name: "backup", + Path: "/srv/b/backup", + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath("/srv/b/backup"), + OutputKey: buildPBSOverrideOutputKey("/srv/b/backup"), + }, + } + + merged := mergePBSDatastoreDefinitions(cli, config) + if len(merged) != 3 { + t.Fatalf("expected 3 merged entries, got %d: %+v", len(merged), merged) + } + + if merged[0].Origin != pbsDatastoreOriginMerged || merged[0].Path != "/real/backup" { + t.Fatalf("expected real datastore entry first, got %+v", merged[0]) + } + if merged[1].Origin != pbsDatastoreSourceOverride || merged[2].Origin != pbsDatastoreSourceOverride { + t.Fatalf("expected override entries after real datastore, got %+v", merged) + } + if merged[1].OutputKey == merged[2].OutputKey { + t.Fatalf("override output keys should differ, got %+v", merged) + } +} + +func TestMergePBSDatastoreDefinitionsPrefersCLIPathOverConfigPath(t *testing.T) { + config := []pbsDatastore{{ + Name: "backup", + Path: "/config/backup", + Comment: "from config", + Source: pbsDatastoreSourceConfig, + CLIName: "backup", + NormalizedPath: normalizePBSDatastorePath("/config/backup"), + OutputKey: collectorPathKey("backup"), + }} + cli := []pbsDatastore{{ + Name: "backup", + Path: "/runtime/backup", + Comment: "from cli", + Source: pbsDatastoreSourceCLI, + CLIName: "backup", + NormalizedPath: normalizePBSDatastorePath("/runtime/backup"), + OutputKey: collectorPathKey("backup"), + }} + + merged := mergePBSDatastoreDefinitions(cli, config) + if len(merged) != 1 { + t.Fatalf("expected 1 merged entry, got %d: %+v", len(merged), merged) + } + + if merged[0].Origin != pbsDatastoreOriginMerged { + t.Fatalf("expected merged origin, got %+v", merged[0]) + } + if merged[0].Path != "/runtime/backup" { + t.Fatalf("expected CLI path to win, got %+v", merged[0]) + } +} + +func TestMergePBSDatastoreDefinitionsDisambiguatesCLIAndOverrideOutputKeyCollisions(t *testing.T) { + overridePath := "/mnt/a/backup" + collidingKey := buildPBSOverrideOutputKey(overridePath) + + cli := []pbsDatastore{ + { + Name: collidingKey, + Path: "/real/runtime", + Comment: "runtime", + Source: pbsDatastoreSourceCLI, + CLIName: collidingKey, + NormalizedPath: normalizePBSDatastorePath("/real/runtime"), + OutputKey: collidingKey, + }, + { + Name: "backup", + Path: overridePath, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(overridePath), + OutputKey: buildPBSOverrideOutputKey(overridePath), + }, + } + + merged := mergePBSDatastoreDefinitions(cli, nil) + if len(merged) != 2 { + t.Fatalf("expected 2 merged entries, got %d: %+v", len(merged), merged) + } + + var cliEntry, overrideEntry *pbsDatastoreDefinition + for i := range merged { + switch merged[i].Origin { + case pbsDatastoreSourceCLI: + cliEntry = &merged[i] + case pbsDatastoreSourceOverride: + overrideEntry = &merged[i] + } + } + if cliEntry == nil || overrideEntry == nil { + t.Fatalf("expected one CLI and one override entry, got %+v", merged) + } + if cliEntry.OutputKey != collidingKey { + t.Fatalf("CLI datastore should keep base key %q, got %+v", collidingKey, merged) + } + if overrideEntry.OutputKey == collidingKey { + t.Fatalf("override output key should be disambiguated, got %+v", merged) + } + if !strings.HasPrefix(overrideEntry.OutputKey, collidingKey+"_") { + t.Fatalf("override output key should extend colliding base key, got %+v", merged) + } +} diff --git a/internal/backup/collector_pbs_extra_test.go b/internal/backup/collector_pbs_extra_test.go index bdb822f..da250ac 100644 --- a/internal/backup/collector_pbs_extra_test.go +++ b/internal/backup/collector_pbs_extra_test.go @@ -13,20 +13,29 @@ import ( "github.com/tis24dev/proxsave/internal/types" ) -func TestGetDatastoreListNoBinary(t *testing.T) { - collector := NewCollectorWithDeps(newTestLogger(), GetDefaultCollectorConfig(), t.TempDir(), types.ProxmoxBS, false, CollectorDeps{ +func TestGetDatastoreListNoBinaryStillIncludesOverrides(t *testing.T) { + cfg := GetDefaultCollectorConfig() + cfg.PBSDatastorePaths = []string{"/override/no-cli"} + + collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{ LookPath: func(string) (string, error) { return "", errors.New("not found") }, }) ds, err := collector.getDatastoreList(context.Background()) if err != nil { t.Fatalf("expected nil error, got %v", err) } - if len(ds) != 0 { - t.Fatalf("expected empty datastores when binary missing") + if len(ds) != 1 { + t.Fatalf("expected override datastore when binary missing, got %+v", ds) + } + if ds[0].Name != "no-cli" || ds[0].Path != "/override/no-cli" || ds[0].Source != pbsDatastoreSourceOverride { + t.Fatalf("unexpected override datastore when binary missing: %+v", ds[0]) } } -func TestGetDatastoreListCommandErrorAndParseError(t *testing.T) { +func TestGetDatastoreListCommandErrorAndParseErrorStillIncludeOverrides(t *testing.T) { + cfg := GetDefaultCollectorConfig() + cfg.PBSDatastorePaths = []string{"/override/fallback"} + deps := CollectorDeps{ LookPath: func(string) (string, error) { return "/bin/true", nil }, RunCommand: func(context.Context, string, ...string) ([]byte, error) { @@ -34,17 +43,25 @@ func TestGetDatastoreListCommandErrorAndParseError(t *testing.T) { }, } - c := NewCollectorWithDeps(newTestLogger(), GetDefaultCollectorConfig(), t.TempDir(), types.ProxmoxBS, false, deps) - if _, err := c.getDatastoreList(context.Background()); err == nil { - t.Fatalf("expected error when command fails") + c := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, deps) + ds, err := c.getDatastoreList(context.Background()) + if err != nil { + t.Fatalf("expected nil error on command failure, got %v", err) + } + if len(ds) != 1 || ds[0].Path != "/override/fallback" || ds[0].Source != pbsDatastoreSourceOverride { + t.Fatalf("expected override fallback on command failure, got %+v", ds) } // Now simulate parse error c.deps.RunCommand = func(context.Context, string, ...string) ([]byte, error) { return []byte("{invalid"), nil } - if _, err := c.getDatastoreList(context.Background()); err == nil { - t.Fatalf("expected parse error for invalid JSON") + ds, err = c.getDatastoreList(context.Background()) + if err != nil { + t.Fatalf("expected nil error on parse failure, got %v", err) + } + if len(ds) != 1 || ds[0].Path != "/override/fallback" || ds[0].Source != pbsDatastoreSourceOverride { + t.Fatalf("expected override fallback on parse failure, got %+v", ds) } } diff --git a/internal/backup/collector_pbs_test.go b/internal/backup/collector_pbs_test.go index e44fddf..f7afb73 100644 --- a/internal/backup/collector_pbs_test.go +++ b/internal/backup/collector_pbs_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/tis24dev/proxsave/internal/pbs" + "github.com/tis24dev/proxsave/internal/types" ) func TestGetDatastoreListSuccessWithOverrides(t *testing.T) { @@ -62,6 +63,122 @@ func TestGetDatastoreListSuccessWithOverrides(t *testing.T) { if datastores[2].Comment != "configured via PBS_DATASTORE_PATH" { t.Fatalf("expected override comment, got %q", datastores[2].Comment) } + if datastores[0].Source != pbsDatastoreSourceCLI || datastores[0].CLIName != "primary" || datastores[0].OutputKey != "primary" { + t.Fatalf("expected CLI datastore metadata, got %+v", datastores[0]) + } + if datastores[1].Source != pbsDatastoreSourceOverride || datastores[1].CLIName != "" || datastores[1].OutputKey == "" { + t.Fatalf("expected override datastore metadata, got %+v", datastores[1]) + } +} + +func TestGetDatastoreListSkipsRelativeOverrides(t *testing.T) { + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + return []byte(`[]`), nil + }, + }) + collector.config.PBSDatastorePaths = []string{"relative/store", "./local/store", "/valid/store"} + + datastores, err := collector.getDatastoreList(context.Background()) + if err != nil { + t.Fatalf("getDatastoreList failed: %v", err) + } + if len(datastores) != 1 { + t.Fatalf("expected only absolute overrides, got %d: %+v", len(datastores), datastores) + } + if datastores[0].Path != "/valid/store" || datastores[0].NormalizedPath != "/valid/store" { + t.Fatalf("unexpected absolute override retained: %+v", datastores[0]) + } + if datastores[0].Source != pbsDatastoreSourceOverride { + t.Fatalf("expected override source, got %+v", datastores[0]) + } +} + +func TestGetDatastoreListOverrideCollisionsUseDistinctOutputKeys(t *testing.T) { + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + return []byte(`[{"name":"primary","path":"/data/primary/","comment":"main"}]`), nil + }, + }) + collector.config.PBSDatastorePaths = []string{ + "/mnt/a/backup", + "/srv/b/backup", + "/srv/b/backup/", + "/data/primary", + } + + datastores, err := collector.getDatastoreList(context.Background()) + if err != nil { + t.Fatalf("getDatastoreList failed: %v", err) + } + if len(datastores) != 3 { + t.Fatalf("expected 3 datastores after normalized dedupe, got %d: %+v", len(datastores), datastores) + } + + if datastores[1].Name != "backup" || datastores[2].Name != "backup" { + t.Fatalf("expected colliding override display names, got %+v", datastores) + } + if datastores[1].OutputKey == datastores[2].OutputKey { + t.Fatalf("override output keys should differ, got %q", datastores[1].OutputKey) + } + if datastores[1].NormalizedPath == datastores[2].NormalizedPath { + t.Fatalf("override normalized paths should differ, got %+v", datastores) + } +} + +func TestGetDatastoreListDisambiguatesCLIAndOverrideOutputKeyCollisions(t *testing.T) { + overridePath := "/mnt/a/backup" + collidingKey := buildPBSOverrideOutputKey(overridePath) + + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) { + return []byte(fmt.Sprintf(`[{"name":%q,"path":"/data/runtime","comment":"main"}]`, collidingKey)), nil + }, + }) + collector.config.PBSDatastorePaths = []string{overridePath} + + datastores, err := collector.getDatastoreList(context.Background()) + if err != nil { + t.Fatalf("getDatastoreList failed: %v", err) + } + if len(datastores) != 2 { + t.Fatalf("expected 2 datastores, got %d: %+v", len(datastores), datastores) + } + + cli := datastores[0] + override := datastores[1] + if cli.OutputKey != collidingKey { + t.Fatalf("CLI datastore should keep base key %q, got %+v", collidingKey, cli) + } + if override.OutputKey == collidingKey { + t.Fatalf("override key should be disambiguated away from %q, got %+v", collidingKey, override) + } + if override.OutputKey == cli.OutputKey { + t.Fatalf("datastore output keys should differ, got %+v", datastores) + } +} + +func TestPBSDatastorePathKeyUsesOverridePathFallback(t *testing.T) { + dsPath := "/mnt/a/backup" + ds := pbsDatastore{ + Name: "backup", + Path: dsPath, + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(dsPath), + } + + if got, want := ds.pathKey(), buildPBSOverrideOutputKey(dsPath); got != want { + t.Fatalf("override pathKey()=%q want %q", got, want) + } } func TestGetDatastoreListContextCanceled(t *testing.T) { @@ -98,10 +215,17 @@ func TestGetDatastoreListCommandError(t *testing.T) { return nil, fmt.Errorf("command failed") }, }) + collector.config.PBSDatastorePaths = []string{"/override/from-error"} - _, err := collector.getDatastoreList(context.Background()) - if err == nil || !strings.Contains(err.Error(), "datastore list failed") { - t.Fatalf("expected datastore list error, got %v", err) + datastores, err := collector.getDatastoreList(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(datastores) != 1 { + t.Fatalf("expected only override datastore, got %+v", datastores) + } + if datastores[0].Name != "from-error" || datastores[0].Path != "/override/from-error" || datastores[0].Source != pbsDatastoreSourceOverride { + t.Fatalf("unexpected override datastore after command failure: %+v", datastores[0]) } } @@ -114,10 +238,17 @@ func TestGetDatastoreListBadJSON(t *testing.T) { return []byte("not-json"), nil }, }) + collector.config.PBSDatastorePaths = []string{"/override/from-parse"} - _, err := collector.getDatastoreList(context.Background()) - if err == nil || !strings.Contains(err.Error(), "failed to parse datastore list JSON") { - t.Fatalf("expected parse error, got %v", err) + datastores, err := collector.getDatastoreList(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(datastores) != 1 { + t.Fatalf("expected only override datastore, got %+v", datastores) + } + if datastores[0].Name != "from-parse" || datastores[0].Path != "/override/from-parse" || datastores[0].Source != pbsDatastoreSourceOverride { + t.Fatalf("unexpected override datastore after parse failure: %+v", datastores[0]) } } @@ -305,6 +436,50 @@ func TestCollectDatastoreNamespacesError(t *testing.T) { } } +func TestCollectDatastoreNamespacesOverrideUsesFilesystemOnly(t *testing.T) { + origList := listNamespacesFunc + t.Cleanup(func() { listNamespacesFunc = origList }) + listNamespacesFunc = func(context.Context, string, string, time.Duration) ([]pbs.Namespace, bool, error) { + t.Fatal("CLI namespace discovery should not be used for override paths") + return nil, false, nil + } + + collector := newTestCollectorWithDeps(t, CollectorDeps{}) + dsDir := filepath.Join(collector.tempDir, "datastores") + if err := os.MkdirAll(dsDir, 0o755); err != nil { + t.Fatalf("failed to create datastore dir: %v", err) + } + + dsPath := filepath.Join(collector.tempDir, "override") + if err := os.MkdirAll(filepath.Join(dsPath, "local", "vm"), 0o755); err != nil { + t.Fatalf("failed to create override namespace fixture: %v", err) + } + + ds := pbsDatastore{ + Name: "backup", + Path: dsPath, + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(dsPath), + OutputKey: buildPBSOverrideOutputKey(dsPath), + } + if err := collector.collectDatastoreNamespaces(context.Background(), ds, dsDir); err != nil { + t.Fatalf("collectDatastoreNamespaces failed: %v", err) + } + + data, err := os.ReadFile(filepath.Join(dsDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey()))) + if err != nil { + t.Fatalf("namespaces file not created: %v", err) + } + + var namespaces []pbs.Namespace + if err := json.Unmarshal(data, &namespaces); err != nil { + t.Fatalf("failed to decode namespaces: %v", err) + } + if len(namespaces) != 2 || namespaces[1].Ns != "local" { + t.Fatalf("unexpected override namespaces: %+v", namespaces) + } +} + func TestCollectDatastoreConfigsDryRun(t *testing.T) { stubListNamespaces(t, func(context.Context, string, string, time.Duration) ([]pbs.Namespace, bool, error) { return []pbs.Namespace{{Ns: ""}}, false, nil @@ -333,6 +508,348 @@ func TestCollectDatastoreConfigsDryRun(t *testing.T) { } } +func TestCollectDatastoreConfigs_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) { + unsafeName := "../escape" + expectedKey := collectorPathKey(unsafeName) + + stubListNamespaces(t, func(_ context.Context, name, path string, _ time.Duration) ([]pbs.Namespace, bool, error) { + if name != unsafeName || path != "/fake" { + t.Fatalf("unexpected datastore args name=%q path=%q", name, path) + } + return []pbs.Namespace{{Ns: ""}}, false, nil + }) + + var seenArgs []string + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) { + if name != "proxmox-backup-manager" { + return nil, fmt.Errorf("unexpected command %s", name) + } + seenArgs = append([]string(nil), args...) + return []byte(`{"ok":true}`), nil + }, + }) + + ds := pbsDatastore{Name: unsafeName, Path: "/fake"} + if err := collector.collectDatastoreConfigs(context.Background(), []pbsDatastore{ds}); err != nil { + t.Fatalf("collectDatastoreConfigs failed: %v", err) + } + + if len(seenArgs) < 3 || seenArgs[0] != "datastore" || seenArgs[1] != "show" || seenArgs[2] != unsafeName { + t.Fatalf("expected raw datastore name in command args, got %v", seenArgs) + } + + datastoreDir := filepath.Join(collector.tempDir, "var", "lib", "proxsave-info", "pbs", "datastores") + safeConfig := filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", expectedKey)) + safeNamespaces := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", expectedKey)) + for _, path := range []string{safeConfig, safeNamespaces} { + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected safe output %s: %v", path, err) + } + } + + rawConfig := filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", unsafeName)) + rawNamespaces := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", unsafeName)) + for _, path := range []string{rawConfig, rawNamespaces} { + if path == safeConfig || path == safeNamespaces { + continue + } + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Fatalf("raw output path should not exist (%s), got err=%v", path, err) + } + } +} + +func TestCollectDatastoreConfigsSkipsCLIConfigForOverridePaths(t *testing.T) { + origList := listNamespacesFunc + t.Cleanup(func() { listNamespacesFunc = origList }) + listNamespacesFunc = func(context.Context, string, string, time.Duration) ([]pbs.Namespace, bool, error) { + t.Fatal("CLI namespace discovery should not be used for override paths") + return nil, false, nil + } + + dsPath := t.TempDir() + if err := os.MkdirAll(filepath.Join(dsPath, "vm"), 0o755); err != nil { + t.Fatalf("mkdir vm: %v", err) + } + + var runCalls int + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) { + runCalls++ + return []byte(`{"ok":true}`), nil + }, + }) + + ds := pbsDatastore{ + Name: "backup", + Path: dsPath, + Comment: "configured via PBS_DATASTORE_PATH", + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(dsPath), + OutputKey: buildPBSOverrideOutputKey(dsPath), + } + if err := collector.collectDatastoreConfigs(context.Background(), []pbsDatastore{ds}); err != nil { + t.Fatalf("collectDatastoreConfigs failed: %v", err) + } + + datastoreDir := filepath.Join(collector.tempDir, "var", "lib", "proxsave-info", "pbs", "datastores") + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey()))); err != nil { + t.Fatalf("expected override namespaces file, got %v", err) + } + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", ds.pathKey()))); !os.IsNotExist(err) { + t.Fatalf("override config file should not exist, got err=%v", err) + } + if runCalls != 0 { + t.Fatalf("expected no CLI datastore show calls for override, got %d", runCalls) + } +} + +func TestCollectDatastoreConfigsDisambiguatesManualCLIAndOverrideKeyCollisions(t *testing.T) { + overridePath := t.TempDir() + for _, sub := range []string{"vm", "ct"} { + if err := os.MkdirAll(filepath.Join(overridePath, sub), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", sub, err) + } + } + + collidingKey := buildPBSOverrideOutputKey(overridePath) + stubListNamespaces(t, func(_ context.Context, name, path string, _ time.Duration) ([]pbs.Namespace, bool, error) { + return []pbs.Namespace{{Ns: name, Path: path}}, false, nil + }) + + origDiscover := discoverNamespacesFunc + t.Cleanup(func() { discoverNamespacesFunc = origDiscover }) + discoverNamespacesFunc = func(_ context.Context, path string, _ time.Duration) ([]pbs.Namespace, error) { + return []pbs.Namespace{{Ns: "override", Path: path}}, nil + } + + collector := newTestCollectorWithDeps(t, CollectorDeps{ + LookPath: func(cmd string) (string, error) { + return "/usr/bin/" + cmd, nil + }, + RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) { + return []byte(`{"ok":true}`), nil + }, + }) + + datastores := []pbsDatastore{ + {Name: collidingKey, Path: "/data/runtime"}, + { + Name: "backup", + Path: overridePath, + Source: pbsDatastoreSourceOverride, + NormalizedPath: normalizePBSDatastorePath(overridePath), + }, + } + if err := collector.collectDatastoreConfigs(context.Background(), datastores); err != nil { + t.Fatalf("collectDatastoreConfigs failed: %v", err) + } + + resolved := clonePBSDatastores(datastores) + assignUniquePBSDatastoreOutputKeys(resolved) + if resolved[0].OutputKey == resolved[1].OutputKey { + t.Fatalf("expected resolved keys to differ, got %+v", resolved) + } + + datastoreDir := filepath.Join(collector.tempDir, "var", "lib", "proxsave-info", "pbs", "datastores") + for _, suffix := range []string{"config.json", "namespaces.json"} { + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_%s", resolved[0].OutputKey, suffix))); err != nil { + t.Fatalf("expected CLI output for %s: %v", suffix, err) + } + } + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", resolved[1].OutputKey))); err != nil { + t.Fatalf("expected override namespaces file: %v", err) + } + if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", resolved[1].OutputKey))); !os.IsNotExist(err) { + t.Fatalf("override config file should not exist, got err=%v", err) + } +} + +func TestCollectPBSPxarMetadata_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) { + tmp := t.TempDir() + cfg := GetDefaultCollectorConfig() + collector := NewCollector(newTestLogger(), cfg, tmp, types.ProxmoxBS, false) + + dsPath := filepath.Join(tmp, "datastore") + for _, sub := range []string{"vm", "ct"} { + if err := os.MkdirAll(filepath.Join(dsPath, sub), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", sub, err) + } + } + if err := os.WriteFile(filepath.Join(dsPath, "vm", "backup1.pxar"), []byte("data"), 0o640); err != nil { + t.Fatalf("write vm pxar: %v", err) + } + if err := os.WriteFile(filepath.Join(dsPath, "ct", "backup2.pxar"), []byte("data"), 0o640); err != nil { + t.Fatalf("write ct pxar: %v", err) + } + + ds := pbsDatastore{Name: "../escape", Path: dsPath, Comment: "unsafe"} + if err := collector.collectPBSPxarMetadata(context.Background(), []pbsDatastore{ds}); err != nil { + t.Fatalf("collectPBSPxarMetadata failed: %v", err) + } + + dsKey := collectorPathKey(ds.Name) + base := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "metadata", dsKey) + for _, path := range []string{ + filepath.Join(base, "metadata.json"), + filepath.Join(base, fmt.Sprintf("%s_subdirs.txt", dsKey)), + filepath.Join(base, fmt.Sprintf("%s_vm_pxar_list.txt", dsKey)), + filepath.Join(base, fmt.Sprintf("%s_ct_pxar_list.txt", dsKey)), + } { + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected safe PXAR output %s: %v", path, err) + } + } + + metaBytes, err := os.ReadFile(filepath.Join(base, "metadata.json")) + if err != nil { + t.Fatalf("read metadata.json: %v", err) + } + if !strings.Contains(string(metaBytes), ds.Name) { + t.Fatalf("metadata should keep raw datastore name, got %s", string(metaBytes)) + } + + selectedVM := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "selected", dsKey, "vm") + smallVM := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "small", dsKey, "vm") + for _, path := range []string{selectedVM, smallVM} { + if info, err := os.Stat(path); err != nil || !info.IsDir() { + t.Fatalf("expected safe PXAR directory %s, got err=%v", path, err) + } + } + + rawBase := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "metadata", ds.Name) + if rawBase != base { + if _, err := os.Stat(rawBase); !os.IsNotExist(err) { + t.Fatalf("raw PXAR directory should not exist (%s), got err=%v", rawBase, err) + } + } +} + +func TestCollectPBSCommands_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) { + pbsRoot := t.TempDir() + if err := os.WriteFile(filepath.Join(pbsRoot, "tape.cfg"), []byte("ok"), 0o640); err != nil { + t.Fatalf("write tape.cfg: %v", err) + } + + cfg := GetDefaultCollectorConfig() + cfg.PBSConfigPath = pbsRoot + + collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{ + LookPath: func(name string) (string, error) { + return "/bin/" + name, nil + }, + RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) { + return []byte(fmt.Sprintf("%s %s", name, strings.Join(args, " "))), nil + }, + }) + + ds := pbsDatastore{Name: "../escape", Path: "/data/escape"} + if err := collector.collectPBSCommands(context.Background(), []pbsDatastore{ds}); err != nil { + t.Fatalf("collectPBSCommands error: %v", err) + } + + key := collectorPathKey(ds.Name) + commandsDir := filepath.Join(collector.tempDir, "var/lib/proxsave-info", "commands", "pbs") + safePath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", key)) + if _, err := os.Stat(safePath); err != nil { + t.Fatalf("expected safe datastore status file: %v", err) + } + data, err := os.ReadFile(safePath) + if err != nil { + t.Fatalf("read datastore status file: %v", err) + } + if !strings.Contains(string(data), ds.Name) { + t.Fatalf("status file should reflect raw datastore name in command output, got %s", string(data)) + } + + rawPath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", ds.Name)) + if rawPath != safePath { + if _, err := os.Stat(rawPath); !os.IsNotExist(err) { + t.Fatalf("raw datastore status path should not exist (%s), got err=%v", rawPath, err) + } + } +} + +func TestCollectPBSCommands_DisambiguatesStatusFilesForCollidingDatastoreKeys(t *testing.T) { + pbsRoot := t.TempDir() + if err := os.WriteFile(filepath.Join(pbsRoot, "tape.cfg"), []byte("ok"), 0o640); err != nil { + t.Fatalf("write tape.cfg: %v", err) + } + + cfg := GetDefaultCollectorConfig() + cfg.PBSConfigPath = pbsRoot + + collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{ + LookPath: func(name string) (string, error) { + return "/bin/" + name, nil + }, + RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) { + return []byte(fmt.Sprintf("%s %s", name, strings.Join(args, " "))), nil + }, + }) + + unsafeName := "../escape" + baseKey := collectorPathKey(unsafeName) + datastores := []pbsDatastore{ + {Name: unsafeName, Path: "/data/unsafe"}, + {Name: baseKey, Path: "/data/colliding"}, + } + if got := pbsDatastoreCandidateOutputKey(datastores[1]); got != baseKey { + t.Fatalf("expected second datastore to preserve colliding base key %q, got %q", baseKey, got) + } + + resolved := clonePBSDatastores(datastores) + assignUniquePBSDatastoreOutputKeys(resolved) + if resolved[0].OutputKey == resolved[1].OutputKey { + t.Fatalf("expected distinct output keys for colliding datastores, got %+v", resolved) + } + + baseCount := 0 + suffixedCount := 0 + for _, ds := range resolved { + switch { + case ds.OutputKey == baseKey: + baseCount++ + case strings.HasPrefix(ds.OutputKey, baseKey+"_"): + suffixedCount++ + } + } + if baseCount != 1 || suffixedCount != 1 { + t.Fatalf("expected one base key and one suffixed key from collision, got %+v", resolved) + } + + if err := collector.collectPBSCommands(context.Background(), datastores); err != nil { + t.Fatalf("collectPBSCommands error: %v", err) + } + + commandsDir := filepath.Join(collector.tempDir, "var/lib/proxsave-info", "commands", "pbs") + statusFiles, err := filepath.Glob(filepath.Join(commandsDir, "datastore_*_status.json")) + if err != nil { + t.Fatalf("glob status files: %v", err) + } + if len(statusFiles) != 2 { + t.Fatalf("expected 2 datastore status files, got %d: %v", len(statusFiles), statusFiles) + } + + for _, ds := range resolved { + statusPath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", ds.OutputKey)) + data, err := os.ReadFile(statusPath) + if err != nil { + t.Fatalf("read datastore status file %s: %v", statusPath, err) + } + if !strings.Contains(string(data), ds.Name) { + t.Fatalf("status file %s should contain datastore name %q, got %s", statusPath, ds.Name, string(data)) + } + } +} + func TestCollectUserConfigsWithTokens(t *testing.T) { collector := newTestCollectorWithDeps(t, CollectorDeps{ LookPath: func(cmd string) (string, error) { diff --git a/internal/backup/collector_pve.go b/internal/backup/collector_pve.go index ed0a8a6..8fe16eb 100644 --- a/internal/backup/collector_pve.go +++ b/internal/backup/collector_pve.go @@ -30,6 +30,10 @@ type pveStorageEntry struct { Status string } +func (s pveStorageEntry) pathKey() string { + return collectorPathKey(s.Name) +} + type pveRuntimeInfo struct { Nodes []string Storages []pveStorageEntry @@ -1029,7 +1033,7 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv var summary strings.Builder summary.WriteString("# PVE datastores detected on ") summary.WriteString(time.Now().Format(time.RFC3339)) - summary.WriteString("\n# Format: NAME|PATH|TYPE|CONTENT\n\n") + summary.WriteString("\n# Format: TYPE|NAME|PATH|CONTENT\n\n") ioTimeout := time.Duration(0) if c.config != nil && c.config.FsIoTimeoutSeconds > 0 { @@ -1073,6 +1077,45 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv return "" } + describeOptionalBool := func(v *bool) string { + if v == nil { + return "unknown" + } + if *v { + return "true" + } + return "false" + } + + logSkipDetails := func(storage pveStorageEntry, reason string, err error) { + if err != nil { + c.logger.Debug( + "PVE datastore skip details: name=%s path=%s type=%s content=%s active=%s enabled=%s status=%s reason=%s err=%v", + storage.Name, + storage.Path, + storage.Type, + storage.Content, + describeOptionalBool(storage.Active), + describeOptionalBool(storage.Enabled), + strings.TrimSpace(storage.Status), + reason, + err, + ) + return + } + c.logger.Debug( + "PVE datastore skip details: name=%s path=%s type=%s content=%s active=%s enabled=%s status=%s reason=%s", + storage.Name, + storage.Path, + storage.Type, + storage.Content, + describeOptionalBool(storage.Active), + describeOptionalBool(storage.Enabled), + strings.TrimSpace(storage.Status), + reason, + ) + } + processed := 0 for _, storage := range storages { if storage.Path == "" { @@ -1083,7 +1126,32 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv } if reason := unavailableReason(storage); reason != "" { - c.logger.Warning("Skipping datastore %s (path=%s)%s: not available (%s)", storage.Name, storage.Path, formatRuntime(storage), reason) + status := strings.TrimPrefix(reason, "status=") + switch { + case reason == "enabled=false" || status == "disabled": + c.logger.Skip( + "PVE datastore %s ignored: disabled in Proxmox. Not scanning %s for vzdump backup files (PVE config backup continues).", + storage.Name, storage.Path, + ) + case reason == "active=false" || status == "inactive" || status == "unavailable": + c.logger.Warning( + "PVE datastore %s skipped: storage is offline. Not scanning %s for vzdump backup files. Mount/activate it, or disable it in Proxmox if unused.", + storage.Name, storage.Path, + ) + default: + if status != reason { + c.logger.Warning( + "PVE datastore %s skipped: Proxmox reports storage status %q. Not scanning %s for vzdump backup files. Fix storage status, or disable it in Proxmox if unused.", + storage.Name, status, storage.Path, + ) + } else { + c.logger.Warning( + "PVE datastore %s skipped: Proxmox reports the storage is not available. Not scanning %s for vzdump backup files. Fix storage status, or disable it in Proxmox if unused.", + storage.Name, storage.Path, + ) + } + } + logSkipDetails(storage, reason, nil) continue } @@ -1091,14 +1159,17 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv stat, err := safefs.Stat(ctx, storage.Path, ioTimeout) if err != nil { if errors.Is(err, safefs.ErrTimeout) { - c.logger.Warning("Skipping datastore %s (path=%s)%s: filesystem probe timed out (%v)", storage.Name, storage.Path, formatRuntime(storage), err) + c.logger.Warning("PVE datastore %s skipped: filesystem probe timed out for %s (%v). Not scanning for vzdump backup files.", storage.Name, storage.Path, err) + logSkipDetails(storage, "filesystem_probe_timeout", err) } else { - c.logger.Debug("Skipping datastore %s (path not accessible: %s): %v", storage.Name, storage.Path, err) + c.logger.Warning("PVE datastore %s skipped: path %s not accessible (%v). Not scanning for vzdump backup files.", storage.Name, storage.Path, err) + logSkipDetails(storage, "path_not_accessible", err) } continue } if !stat.IsDir() { - c.logger.Debug("Skipping datastore %s (path not a directory: %s)", storage.Name, storage.Path) + c.logger.Warning("PVE datastore %s skipped: path %s is not a directory. Not scanning for vzdump backup files.", storage.Name, storage.Path) + logSkipDetails(storage, "path_not_directory", nil) continue } @@ -1109,7 +1180,7 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv storage.Path, storage.Content)) - metaDir := filepath.Join(baseDir, storage.Name) + metaDir := filepath.Join(baseDir, storage.pathKey()) if err := c.ensureDir(metaDir); err != nil { c.logger.Warning("Failed to create metadata directory for %s: %v", storage.Name, err) continue @@ -1251,9 +1322,11 @@ func (c *Collector) collectDetailedPVEBackups(ctx context.Context, storage pveSt var totalFiles int64 var totalSize int64 + storageKey := storage.pathKey() + var smallDir string if c.config.BackupSmallPVEBackups && c.config.MaxPVEBackupSizeBytes > 0 { - smallDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/small_backups", storage.Name) + smallDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/small_backups", storageKey) if err := c.ensureDir(smallDir); err != nil { c.logger.Warning("Cannot create small backups directory %s: %v", smallDir, err) smallDir = "" @@ -1263,7 +1336,7 @@ func (c *Collector) collectDetailedPVEBackups(ctx context.Context, storage pveSt includePattern := strings.TrimSpace(c.config.PVEBackupIncludePattern) var includeDir string if includePattern != "" { - includeDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/selected_backups", storage.Name) + includeDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/selected_backups", storageKey) if err := c.ensureDir(includeDir); err != nil { c.logger.Warning("Cannot create selected backups directory %s: %v", includeDir, err) includeDir = "" @@ -1400,7 +1473,7 @@ type patternWriter struct { func newPatternWriter(storageName, storagePath, analysisDir, pattern string, dryRun bool) (*patternWriter, error) { clean := cleanPatternName(pattern) - filename := fmt.Sprintf("%s_%s_list.txt", storageName, clean) + filename := fmt.Sprintf("%s_%s_list.txt", collectorPathKey(storageName), clean) filePath := filepath.Join(analysisDir, filename) // In dry-run mode, create a writer without an actual file @@ -1526,7 +1599,7 @@ func (c *Collector) writePatternSummary(storage pveStorageEntry, analysisDir str return nil } - summaryPath := filepath.Join(analysisDir, fmt.Sprintf("%s_backup_summary.txt", storage.Name)) + summaryPath := filepath.Join(analysisDir, fmt.Sprintf("%s_backup_summary.txt", storage.pathKey())) file, err := os.OpenFile(summaryPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0640) if err != nil { return err diff --git a/internal/backup/collector_pve_parse_test.go b/internal/backup/collector_pve_parse_test.go index aaea755..5f957e3 100644 --- a/internal/backup/collector_pve_parse_test.go +++ b/internal/backup/collector_pve_parse_test.go @@ -8,6 +8,11 @@ import ( // TestParseNodeStorageList tests parsing PVE storage entries from JSON func TestParseNodeStorageList(t *testing.T) { + boolPtr := func(v bool) *bool { + b := v + return &b + } + tests := []struct { name string input string @@ -26,6 +31,24 @@ func TestParseNodeStorageList(t *testing.T) { {Name: "local-lvm", Path: "", Type: "lvmthin", Content: "images,rootdir"}, }, }, + { + name: "storage with runtime flags", + input: `[ + {"storage": "HDD1", "path": "/mnt/pve/HDD1", "type": "dir", "content": "backup", "active": 0, "enabled": 0, "status": "disabled"}, + {"storage": "HDD2", "path": "/mnt/pve/HDD2", "type": "nfs", "content": "backup", "active": 0, "enabled": 1, "status": "unavailable"}, + {"storage": "HDD3", "path": "/mnt/pve/HDD3", "type": "dir", "content": "backup", "active": true, "enabled": true, "status": "ok"}, + {"storage": "HDD4", "path": "/mnt/pve/HDD4", "type": "dir", "content": "backup", "active": "0", "enabled": "1", "status": "unknown"}, + {"storage": "HDD5", "path": "/mnt/pve/HDD5", "type": "dir", "content": "backup", "active": null, "enabled": null, "status": ""} + ]`, + expectError: false, + expected: []pveStorageEntry{ + {Name: "HDD1", Path: "/mnt/pve/HDD1", Type: "dir", Content: "backup", Active: boolPtr(false), Enabled: boolPtr(false), Status: "disabled"}, + {Name: "HDD2", Path: "/mnt/pve/HDD2", Type: "nfs", Content: "backup", Active: boolPtr(false), Enabled: boolPtr(true), Status: "unavailable"}, + {Name: "HDD3", Path: "/mnt/pve/HDD3", Type: "dir", Content: "backup", Active: boolPtr(true), Enabled: boolPtr(true), Status: "ok"}, + {Name: "HDD4", Path: "/mnt/pve/HDD4", Type: "dir", Content: "backup", Active: boolPtr(false), Enabled: boolPtr(true), Status: "unknown"}, + {Name: "HDD5", Path: "/mnt/pve/HDD5", Type: "dir", Content: "backup", Active: nil, Enabled: nil, Status: ""}, + }, + }, { name: "storage with name field instead of storage", input: `[ @@ -145,6 +168,19 @@ func TestParseNodeStorageList(t *testing.T) { if entry.Content != tt.expected[i].Content { t.Errorf("entry[%d].Content = %q, want %q", i, entry.Content, tt.expected[i].Content) } + if (entry.Active == nil) != (tt.expected[i].Active == nil) { + t.Errorf("entry[%d].Active nilness = %v, want %v", i, entry.Active == nil, tt.expected[i].Active == nil) + } else if entry.Active != nil && tt.expected[i].Active != nil && *entry.Active != *tt.expected[i].Active { + t.Errorf("entry[%d].Active = %v, want %v", i, *entry.Active, *tt.expected[i].Active) + } + if (entry.Enabled == nil) != (tt.expected[i].Enabled == nil) { + t.Errorf("entry[%d].Enabled nilness = %v, want %v", i, entry.Enabled == nil, tt.expected[i].Enabled == nil) + } else if entry.Enabled != nil && tt.expected[i].Enabled != nil && *entry.Enabled != *tt.expected[i].Enabled { + t.Errorf("entry[%d].Enabled = %v, want %v", i, *entry.Enabled, *tt.expected[i].Enabled) + } + if entry.Status != tt.expected[i].Status { + t.Errorf("entry[%d].Status = %q, want %q", i, entry.Status, tt.expected[i].Status) + } } }) } diff --git a/internal/backup/collector_pve_test.go b/internal/backup/collector_pve_test.go index 6d4a808..1a5e127 100644 --- a/internal/backup/collector_pve_test.go +++ b/internal/backup/collector_pve_test.go @@ -1,11 +1,13 @@ package backup import ( + "bytes" "context" "errors" "fmt" "os" "path/filepath" + "strings" "testing" "github.com/tis24dev/proxsave/internal/logging" @@ -814,15 +816,148 @@ func TestCollectPVEStorageMetadata(t *testing.T) { collector := newPVECollector(t) ctx := context.Background() + storageDir := t.TempDir() storages := []pveStorageEntry{ - {Name: "local", Path: "/var/lib/vz", Type: "dir"}, + {Name: "local", Path: storageDir, Type: "dir"}, } - err := collector.collectPVEStorageMetadata(ctx, storages) + if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + metaPath := filepath.Join(collector.tempDir, "var/lib/pve-cluster/info/datastores", "local", "metadata.json") + if _, err := os.Stat(metaPath); err != nil { + t.Fatalf("expected %s created, got %v", metaPath, err) + } +} + +func TestCollectPVEStorageMetadata_UsesPathSafeKeyForUnsafeStorageName(t *testing.T) { + collector := newPVECollector(t) + collector.config.BackupSmallPVEBackups = true + collector.config.MaxPVEBackupSizeBytes = 1024 * 1024 + collector.config.PVEBackupIncludePattern = "vm-100" + + ctx := context.Background() + storageDir := t.TempDir() + if err := os.WriteFile(filepath.Join(storageDir, "vm-100-backup.vma"), []byte("content"), 0o644); err != nil { + t.Fatalf("write backup file: %v", err) + } + + storage := pveStorageEntry{Name: "../escape", Path: storageDir, Type: "dir"} + if err := collector.collectPVEStorageMetadata(ctx, []pveStorageEntry{storage}); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + key := collectorPathKey(storage.Name) + baseDir := filepath.Join(collector.tempDir, "var/lib/pve-cluster/info/datastores", key) + for _, path := range []string{ + filepath.Join(baseDir, "metadata.json"), + filepath.Join(baseDir, "backup_analysis", fmt.Sprintf("%s_backup_summary.txt", key)), + filepath.Join(baseDir, "backup_analysis", fmt.Sprintf("%s__vma_list.txt", key)), + filepath.Join(collector.tempDir, "var/lib/pve-cluster/small_backups", key, "vm-100-backup.vma"), + filepath.Join(collector.tempDir, "var/lib/pve-cluster/selected_backups", key, "vm-100-backup.vma"), + } { + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected safe output %s: %v", path, err) + } + } + + metaPath := filepath.Join(baseDir, "metadata.json") + metaBytes, err := os.ReadFile(metaPath) if err != nil { - t.Logf("collectPVEStorageMetadata returned error: %v", err) + t.Fatalf("read metadata.json: %v", err) + } + if !strings.Contains(string(metaBytes), storage.Name) { + t.Fatalf("metadata should keep raw storage name, got %s", string(metaBytes)) + } + + rawMetaPath := filepath.Join(collector.tempDir, "var/lib/pve-cluster/info/datastores", storage.Name, "metadata.json") + if rawMetaPath != metaPath { + if _, err := os.Stat(rawMetaPath); !os.IsNotExist(err) { + t.Fatalf("raw storage metadata path should not exist (%s), got err=%v", rawMetaPath, err) + } } } +func TestCollectPVEStorageMetadata_SkipReasonsAreUserFriendly(t *testing.T) { + boolPtr := func(v bool) *bool { + b := v + return &b + } + ctx := context.Background() + + t.Run("disabled storage uses SKIP with debug details", func(t *testing.T) { + collector := newPVECollector(t) + var out bytes.Buffer + collector.logger.SetOutput(&out) + + storages := []pveStorageEntry{ + {Name: "HDD1", Path: "/mnt/pve/HDD1", Active: boolPtr(false), Enabled: boolPtr(false)}, + } + if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + logText := out.String() + if collector.logger.WarningCount() != 0 { + t.Fatalf("expected 0 warnings, got %d", collector.logger.WarningCount()) + } + if !strings.Contains(logText, "SKIP") || !strings.Contains(logText, "disabled in Proxmox") { + t.Fatalf("expected user-friendly SKIP message for disabled storage, got logs:\n%s", logText) + } + if !strings.Contains(logText, "PVE datastore skip details:") { + t.Fatalf("expected debug details for skipped storage, got logs:\n%s", logText) + } + }) + + t.Run("offline storage uses WARNING with debug details", func(t *testing.T) { + collector := newPVECollector(t) + var out bytes.Buffer + collector.logger.SetOutput(&out) + + storages := []pveStorageEntry{ + {Name: "HDD2", Path: "/mnt/pve/HDD2", Active: boolPtr(false), Enabled: boolPtr(true)}, + } + if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + logText := out.String() + if collector.logger.WarningCount() != 1 { + t.Fatalf("expected 1 warning, got %d", collector.logger.WarningCount()) + } + if !strings.Contains(logText, "WARNING") || !strings.Contains(logText, "storage is offline") { + t.Fatalf("expected warning message for offline storage, got logs:\n%s", logText) + } + if !strings.Contains(logText, "PVE datastore skip details:") { + t.Fatalf("expected debug details for skipped storage, got logs:\n%s", logText) + } + }) + + t.Run("non-existent path uses WARNING with debug details", func(t *testing.T) { + collector := newPVECollector(t) + var out bytes.Buffer + collector.logger.SetOutput(&out) + + storages := []pveStorageEntry{ + {Name: "MISSING", Path: filepath.Join(t.TempDir(), "does-not-exist"), Type: "dir"}, + } + if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil { + t.Fatalf("collectPVEStorageMetadata returned error: %v", err) + } + + logText := out.String() + if collector.logger.WarningCount() != 1 { + t.Fatalf("expected 1 warning, got %d", collector.logger.WarningCount()) + } + if !strings.Contains(logText, "not accessible") { + t.Fatalf("expected warning for missing path, got logs:\n%s", logText) + } + if !strings.Contains(logText, "reason=path_not_accessible") { + t.Fatalf("expected debug details including reason, got logs:\n%s", logText) + } + }) +} + // Test collectPVECephInfo function func TestCollectPVECephInfo(t *testing.T) { t.Run("no ceph configured", func(t *testing.T) { diff --git a/internal/config/templates/backup.env b/internal/config/templates/backup.env index 2dbf810..ad5a9b0 100644 --- a/internal/config/templates/backup.env +++ b/internal/config/templates/backup.env @@ -318,7 +318,7 @@ PVE_CLUSTER_PATH=/var/lib/pve-cluster COROSYNC_CONFIG_PATH=${PVE_CONFIG_PATH}/corosync.conf VZDUMP_CONFIG_PATH=/etc/vzdump.conf PBS_CONFIG_PATH=/etc/proxmox-backup -PBS_DATASTORE_PATH= # Extra PBS paths separated by comma/space (e.g. /mnt/pbs1,/mnt/pbs2). Empty = auto-detect. +PBS_DATASTORE_PATH= # Extra PBS filesystem scan roots separated by comma/space (e.g. /mnt/pbs1,/mnt/pbs2). Not real datastore definitions; output keys may be path-derived. Empty = auto-detect. # System BACKUP_NETWORK_CONFIGS=true diff --git a/internal/input/input.go b/internal/input/input.go index 15c87b1..9776279 100644 --- a/internal/input/input.go +++ b/internal/input/input.go @@ -7,6 +7,7 @@ import ( "io" "os" "strings" + "sync" ) // ErrInputAborted signals that interactive input was interrupted (typically via Ctrl+C @@ -15,6 +16,41 @@ import ( // Callers should translate this into the appropriate workflow-level abort error. var ErrInputAborted = errors.New("input aborted") +type lineResult struct { + line string + err error +} + +type lineState struct { + mu sync.Mutex + inflight *lineInflight +} + +type lineInflight struct { + done chan struct{} + result lineResult +} + +type passwordResult struct { + b []byte + err error +} + +type passwordState struct { + mu sync.Mutex + inflight *passwordInflight +} + +type passwordInflight struct { + done chan struct{} + result passwordResult +} + +var ( + lineStates sync.Map + passwordStates sync.Map +) + // IsAborted reports whether an operation was aborted by the user (typically via Ctrl+C), // by checking for ErrInputAborted and context cancellation. func IsAborted(err error) bool { @@ -41,28 +77,85 @@ func MapInputError(err error) error { return err } +func mapContextInputError(ctx context.Context) error { + if ctx == nil || ctx.Err() == nil { + return nil + } + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return context.DeadlineExceeded + } + return ErrInputAborted +} + +func getLineState(reader *bufio.Reader) *lineState { + if state, ok := lineStates.Load(reader); ok { + return state.(*lineState) + } + state := &lineState{} + actual, _ := lineStates.LoadOrStore(reader, state) + return actual.(*lineState) +} + +func getPasswordState(fd int) *passwordState { + if state, ok := passwordStates.Load(fd); ok { + return state.(*passwordState) + } + state := &passwordState{} + actual, _ := passwordStates.LoadOrStore(fd, state) + return actual.(*passwordState) +} + // ReadLineWithContext reads a single line and supports cancellation. On ctx cancellation // or stdin closure it returns ErrInputAborted. On ctx deadline it returns context.DeadlineExceeded. +// Cancellation stops waiting but does not interrupt an already-started reader.ReadString call; +// at most one in-flight read is kept per reader to avoid goroutine buildup across retries. +// A completed in-flight read remains attached to the reader until a later caller consumes it. func ReadLineWithContext(ctx context.Context, reader *bufio.Reader) (string, error) { if ctx == nil { ctx = context.Background() } - type result struct { - line string - err error + if reader == nil { + return "", errors.New("reader is nil") } - ch := make(chan result, 1) - go func() { - line, err := reader.ReadString('\n') - ch <- result{line: line, err: MapInputError(err)} - }() - select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - return "", context.DeadlineExceeded + state := getLineState(reader) + + for { + if err := mapContextInputError(ctx); err != nil { + return "", err + } + + state.mu.Lock() + inflight := state.inflight + if inflight == nil { + inflight = &lineInflight{ + done: make(chan struct{}), + } + state.inflight = inflight + go func(inflight *lineInflight) { + line, err := reader.ReadString('\n') + inflight.result = lineResult{line: line, err: MapInputError(err)} + close(inflight.done) + }(inflight) + } + state.mu.Unlock() + + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return "", context.DeadlineExceeded + } + return "", ErrInputAborted + case <-inflight.done: } - return "", ErrInputAborted - case res := <-ch: + + state.mu.Lock() + if state.inflight != inflight { + state.mu.Unlock() + continue + } + state.inflight = nil + res := inflight.result + state.mu.Unlock() return res.line, res.err } } @@ -70,6 +163,9 @@ func ReadLineWithContext(ctx context.Context, reader *bufio.Reader) (string, err // ReadPasswordWithContext reads a password (no echo) and supports cancellation. On ctx // cancellation or stdin closure it returns ErrInputAborted. On ctx deadline it returns // context.DeadlineExceeded. +// Cancellation stops waiting but does not interrupt an already-started password read; +// at most one in-flight password read is kept per file descriptor to avoid goroutine buildup. +// A completed in-flight password read remains attached until a later caller consumes it. func ReadPasswordWithContext(ctx context.Context, readPassword func(int) ([]byte, error), fd int) ([]byte, error) { if ctx == nil { ctx = context.Background() @@ -77,22 +173,45 @@ func ReadPasswordWithContext(ctx context.Context, readPassword func(int) ([]byte if readPassword == nil { return nil, errors.New("readPassword function is nil") } - type result struct { - b []byte - err error - } - ch := make(chan result, 1) - go func() { - b, err := readPassword(fd) - ch <- result{b: b, err: MapInputError(err)} - }() - select { - case <-ctx.Done(): - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - return nil, context.DeadlineExceeded + state := getPasswordState(fd) + + for { + if err := mapContextInputError(ctx); err != nil { + return nil, err + } + + state.mu.Lock() + inflight := state.inflight + if inflight == nil { + inflight = &passwordInflight{ + done: make(chan struct{}), + } + state.inflight = inflight + go func(inflight *passwordInflight) { + b, err := readPassword(fd) + inflight.result = passwordResult{b: b, err: MapInputError(err)} + close(inflight.done) + }(inflight) + } + state.mu.Unlock() + + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return nil, context.DeadlineExceeded + } + return nil, ErrInputAborted + case <-inflight.done: + } + + state.mu.Lock() + if state.inflight != inflight { + state.mu.Unlock() + continue } - return nil, ErrInputAborted - case res := <-ch: + state.inflight = nil + res := inflight.result + state.mu.Unlock() return res.b, res.err } } diff --git a/internal/input/input_test.go b/internal/input/input_test.go index 4113024..535f048 100644 --- a/internal/input/input_test.go +++ b/internal/input/input_test.go @@ -7,10 +7,161 @@ import ( "io" "os" "strings" + "sync/atomic" "testing" "time" ) +type blockingLineReader struct { + release chan struct{} + finish chan struct{} + returned chan struct{} + payload string + calls atomic.Int32 +} + +func (r *blockingLineReader) Read(p []byte) (int, error) { + r.calls.Add(1) + <-r.release + if r.finish != nil { + <-r.finish + } + if r.payload == "" { + signalNonBlocking(r.returned) + return 0, io.EOF + } + n := copy(p, r.payload) + r.payload = r.payload[n:] + signalNonBlocking(r.returned) + return n, nil +} + +type lineCallResult struct { + line string + err error +} + +type passwordCallResult struct { + b []byte + err error +} + +func assertNoLineResultYet(t *testing.T, ch <-chan lineCallResult) { + t.Helper() + select { + case res := <-ch: + t.Fatalf("unexpected line result before release: %+v", res) + default: + } +} + +func assertNoPasswordResultYet(t *testing.T, ch <-chan passwordCallResult) { + t.Helper() + select { + case res := <-ch: + t.Fatalf("unexpected password result before release: %+v", res) + default: + } +} + +func signalNonBlocking(ch chan struct{}) { + if ch == nil { + return + } + select { + case ch <- struct{}{}: + default: + } +} + +func waitForSignal(t *testing.T, ch <-chan struct{}, name string) { + t.Helper() + select { + case <-ch: + case <-time.After(500 * time.Millisecond): + t.Fatalf("timed out waiting for %s", name) + } +} + +func waitForCondition(t *testing.T, name string, cond func() bool) { + t.Helper() + deadline := time.After(500 * time.Millisecond) + ticker := time.NewTicker(time.Millisecond) + defer ticker.Stop() + for { + if cond() { + return + } + select { + case <-deadline: + t.Fatalf("timed out waiting for %s", name) + case <-ticker.C: + } + } +} + +func currentLineInflight(t *testing.T, reader *bufio.Reader) *lineInflight { + t.Helper() + state := getLineState(reader) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight == nil { + t.Fatalf("expected line inflight state") + } + return state.inflight +} + +func currentPasswordInflight(t *testing.T, fd int) *passwordInflight { + t.Helper() + state := getPasswordState(fd) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight == nil { + t.Fatalf("expected password inflight state") + } + return state.inflight +} + +func assertSameLineInflight(t *testing.T, reader *bufio.Reader, want *lineInflight) { + t.Helper() + state := getLineState(reader) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight != want { + t.Fatalf("line inflight=%p; want %p", state.inflight, want) + } +} + +func assertSamePasswordInflight(t *testing.T, fd int, want *passwordInflight) { + t.Helper() + state := getPasswordState(fd) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight != want { + t.Fatalf("password inflight=%p; want %p", state.inflight, want) + } +} + +func assertLineInflightCleared(t *testing.T, reader *bufio.Reader) { + t.Helper() + state := getLineState(reader) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight != nil { + t.Fatalf("line inflight=%p; want nil", state.inflight) + } +} + +func assertPasswordInflightCleared(t *testing.T, fd int) { + t.Helper() + state := getPasswordState(fd) + state.mu.Lock() + defer state.mu.Unlock() + if state.inflight != nil { + t.Fatalf("password inflight=%p; want nil", state.inflight) + } +} + func TestMapInputError(t *testing.T) { if MapInputError(nil) != nil { t.Fatalf("expected nil") @@ -133,6 +284,59 @@ func TestReadLineWithContext_DeadlineReturnsDeadlineExceeded(t *testing.T) { _ = pw.Close() } +func TestReadLineWithContext_ConcurrentCancelledCallerReturnsWhileReadPending(t *testing.T) { + src := &blockingLineReader{ + release: make(chan struct{}), + returned: make(chan struct{}, 1), + payload: "hello\n", + } + reader := bufio.NewReader(src) + + firstResultCh := make(chan lineCallResult, 1) + go func() { + line, err := ReadLineWithContext(context.Background(), reader) + firstResultCh <- lineCallResult{line: line, err: err} + }() + + waitForCondition(t, "underlying line read to start", func() bool { + return src.calls.Load() == 1 + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + secondResultCh := make(chan lineCallResult, 1) + go func() { + line, err := ReadLineWithContext(ctx, reader) + secondResultCh <- lineCallResult{line: line, err: err} + }() + + select { + case res := <-secondResultCh: + if res.line != "" { + t.Fatalf("cancelled line=%q; want empty", res.line) + } + if !errors.Is(res.err, ErrInputAborted) { + t.Fatalf("cancelled err=%v; want %v", res.err, ErrInputAborted) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("concurrent cancelled line read did not return while inflight read was pending") + } + + assertNoLineResultYet(t, firstResultCh) + + close(src.release) + waitForSignal(t, src.returned, "underlying line read completion") + + res := <-firstResultCh + if res.err != nil { + t.Fatalf("first ReadLineWithContext error: %v", res.err) + } + if res.line != "hello\n" { + t.Fatalf("first line=%q; want %q", res.line, "hello\n") + } +} + func TestReadPasswordWithContext_NilReadPasswordErrors(t *testing.T) { _, err := ReadPasswordWithContext(context.Background(), nil, 0) if err == nil { @@ -208,3 +412,314 @@ func TestReadPasswordWithContext_DeadlineReturnsDeadlineExceeded(t *testing.T) { t.Fatalf("err=%v; want %v", err, context.DeadlineExceeded) } } + +func TestReadPasswordWithContext_ConcurrentCancelledCallerReturnsWhileReadPending(t *testing.T) { + release := make(chan struct{}) + var calls atomic.Int32 + readPassword := func(fd int) ([]byte, error) { + calls.Add(1) + <-release + return []byte("secret"), nil + } + + firstResultCh := make(chan passwordCallResult, 1) + go func() { + got, err := ReadPasswordWithContext(context.Background(), readPassword, 42) + firstResultCh <- passwordCallResult{b: got, err: err} + }() + + waitForCondition(t, "underlying password read to start", func() bool { + return calls.Load() == 1 + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + secondResultCh := make(chan passwordCallResult, 1) + go func() { + got, err := ReadPasswordWithContext(ctx, readPassword, 42) + secondResultCh <- passwordCallResult{b: got, err: err} + }() + + select { + case res := <-secondResultCh: + if res.b != nil { + t.Fatalf("cancelled password=%q; want nil", string(res.b)) + } + if !errors.Is(res.err, ErrInputAborted) { + t.Fatalf("cancelled err=%v; want %v", res.err, ErrInputAborted) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("concurrent cancelled password read did not return while inflight read was pending") + } + + assertNoPasswordResultYet(t, firstResultCh) + + close(release) + + res := <-firstResultCh + if res.err != nil { + t.Fatalf("first ReadPasswordWithContext error: %v", res.err) + } + if string(res.b) != "secret" { + t.Fatalf("first password=%q; want %q", string(res.b), "secret") + } +} + +func TestReadLineWithContext_ReusesInflightReadWhilePendingAfterTimeout(t *testing.T) { + src := &blockingLineReader{ + release: make(chan struct{}), + finish: make(chan struct{}), + returned: make(chan struct{}, 1), + payload: "hello\n", + } + reader := bufio.NewReader(src) + + ctx1, cancel1 := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel1() + _, err := ReadLineWithContext(ctx1, reader) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + ctx2, cancel2 := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel2() + _, err = ReadLineWithContext(ctx2, reader) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("second err=%v; want %v", err, context.DeadlineExceeded) + } + + if got := src.calls.Load(); got != 1 { + t.Fatalf("underlying Read calls=%d; want 1", got) + } + + resultCh := make(chan lineCallResult, 1) + started := make(chan struct{}) + go func() { + close(started) + line, err := ReadLineWithContext(context.Background(), reader) + resultCh <- lineCallResult{line: line, err: err} + }() + waitForSignal(t, started, "line retry start") + assertNoLineResultYet(t, resultCh) + + close(src.release) + close(src.finish) + waitForSignal(t, src.returned, "underlying line read completion") + + res := <-resultCh + if res.err != nil { + t.Fatalf("retry ReadLineWithContext error: %v", res.err) + } + if res.line != "hello\n" { + t.Fatalf("line=%q; want %q", res.line, "hello\n") + } + if got := src.calls.Load(); got != 1 { + t.Fatalf("underlying Read calls after pending retry=%d; want 1", got) + } +} + +func TestReadLineWithContext_PreservesCompletedReadForNextRetryAfterTimeout(t *testing.T) { + src := &blockingLineReader{ + release: make(chan struct{}), + returned: make(chan struct{}, 1), + payload: "hello\n", + } + reader := bufio.NewReader(src) + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + _, err := ReadLineWithContext(ctx, reader) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + inflight := currentLineInflight(t, reader) + close(src.release) + waitForSignal(t, src.returned, "underlying line read return") + waitForSignal(t, inflight.done, "line inflight completion") + assertSameLineInflight(t, reader, inflight) + + line, err := ReadLineWithContext(context.Background(), reader) + if err != nil { + t.Fatalf("retry ReadLineWithContext error: %v", err) + } + if line != "hello\n" { + t.Fatalf("line=%q; want %q", line, "hello\n") + } + if got := src.calls.Load(); got != 1 { + t.Fatalf("underlying Read calls after completed retry=%d; want 1", got) + } +} + +func TestReadPasswordWithContext_ReusesInflightReadWhilePendingAfterTimeout(t *testing.T) { + release := make(chan struct{}) + finish := make(chan struct{}) + returned := make(chan struct{}, 1) + var calls atomic.Int32 + readPassword := func(fd int) ([]byte, error) { + calls.Add(1) + <-release + <-finish + signalNonBlocking(returned) + return []byte("secret"), nil + } + + ctx1, cancel1 := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel1() + got, err := ReadPasswordWithContext(ctx1, readPassword, 42) + if got != nil { + t.Fatalf("expected nil bytes on first deadline") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + ctx2, cancel2 := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel2() + got, err = ReadPasswordWithContext(ctx2, readPassword, 42) + if got != nil { + t.Fatalf("expected nil bytes on second deadline") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("second err=%v; want %v", err, context.DeadlineExceeded) + } + + if gotCalls := calls.Load(); gotCalls != 1 { + t.Fatalf("readPassword calls=%d; want 1", gotCalls) + } + + resultCh := make(chan passwordCallResult, 1) + started := make(chan struct{}) + go func() { + close(started) + got, err := ReadPasswordWithContext(context.Background(), readPassword, 42) + resultCh <- passwordCallResult{b: got, err: err} + }() + waitForSignal(t, started, "password retry start") + assertNoPasswordResultYet(t, resultCh) + + close(release) + close(finish) + waitForSignal(t, returned, "underlying password read completion") + + res := <-resultCh + if res.err != nil { + t.Fatalf("retry ReadPasswordWithContext error: %v", res.err) + } + if string(res.b) != "secret" { + t.Fatalf("got=%q; want %q", string(res.b), "secret") + } + if gotCalls := calls.Load(); gotCalls != 1 { + t.Fatalf("readPassword calls after pending retry=%d; want 1", gotCalls) + } +} + +func TestReadPasswordWithContext_PreservesCompletedReadForNextRetryAfterTimeout(t *testing.T) { + release := make(chan struct{}) + returned := make(chan struct{}, 1) + var calls atomic.Int32 + readPassword := func(fd int) ([]byte, error) { + calls.Add(1) + <-release + signalNonBlocking(returned) + return []byte("secret"), nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + got, err := ReadPasswordWithContext(ctx, readPassword, 42) + if got != nil { + t.Fatalf("expected nil bytes on first deadline") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + inflight := currentPasswordInflight(t, 42) + close(release) + waitForSignal(t, returned, "underlying password read return") + waitForSignal(t, inflight.done, "password inflight completion") + assertSamePasswordInflight(t, 42, inflight) + + got, err = ReadPasswordWithContext(context.Background(), readPassword, 42) + if err != nil { + t.Fatalf("retry ReadPasswordWithContext error: %v", err) + } + if string(got) != "secret" { + t.Fatalf("got=%q; want %q", string(got), "secret") + } + if gotCalls := calls.Load(); gotCalls != 1 { + t.Fatalf("readPassword calls after completed retry=%d; want 1", gotCalls) + } +} + +func TestReadLineWithContext_ClearsInflightAfterCompletedRetryConsumesResult(t *testing.T) { + src := &blockingLineReader{ + release: make(chan struct{}), + returned: make(chan struct{}, 1), + payload: "hello\n", + } + reader := bufio.NewReader(src) + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + _, err := ReadLineWithContext(ctx, reader) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + inflight := currentLineInflight(t, reader) + close(src.release) + waitForSignal(t, src.returned, "underlying line read return") + waitForSignal(t, inflight.done, "line inflight completion") + assertSameLineInflight(t, reader, inflight) + + line, err := ReadLineWithContext(context.Background(), reader) + if err != nil { + t.Fatalf("retry ReadLineWithContext error: %v", err) + } + if line != "hello\n" { + t.Fatalf("line=%q; want %q", line, "hello\n") + } + assertLineInflightCleared(t, reader) +} + +func TestReadPasswordWithContext_ClearsInflightAfterCompletedRetryConsumesResult(t *testing.T) { + release := make(chan struct{}) + returned := make(chan struct{}, 1) + var calls atomic.Int32 + readPassword := func(fd int) ([]byte, error) { + calls.Add(1) + <-release + signalNonBlocking(returned) + return []byte("secret"), nil + } + + const fd = 43 + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + got, err := ReadPasswordWithContext(ctx, readPassword, fd) + if got != nil { + t.Fatalf("expected nil bytes on first deadline") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded) + } + + inflight := currentPasswordInflight(t, fd) + close(release) + waitForSignal(t, returned, "underlying password read return") + waitForSignal(t, inflight.done, "password inflight completion") + assertSamePasswordInflight(t, fd, inflight) + + got, err = ReadPasswordWithContext(context.Background(), readPassword, fd) + if err != nil { + t.Fatalf("retry ReadPasswordWithContext error: %v", err) + } + if string(got) != "secret" { + t.Fatalf("got=%q; want %q", string(got), "secret") + } + assertPasswordInflightCleared(t, fd) +} diff --git a/internal/orchestrator/backup_safety.go b/internal/orchestrator/backup_safety.go index 5402897..59ae25c 100644 --- a/internal/orchestrator/backup_safety.go +++ b/internal/orchestrator/backup_safety.go @@ -23,38 +23,8 @@ type safetyBackupSpec struct { WriteLocationFile bool } -// resolveAndCheckPath cleans and resolves symlinks for candidate extraction paths -// and verifies the resolved path is still within destRoot. func resolveAndCheckPath(destRoot, candidate string) (string, error) { - joined := candidate - if !filepath.IsAbs(candidate) { - joined = filepath.Join(destRoot, candidate) - } - - resolved, err := filepath.EvalSymlinks(joined) - if err != nil { - // If the path doesn't exist yet, EvalSymlinks will fail; fallback to the cleaned path. - resolved = filepath.Clean(joined) - } - - absDestRoot, err := filepath.Abs(destRoot) - if err != nil { - return "", fmt.Errorf("cannot resolve destination root: %w", err) - } - absResolved, err := filepath.Abs(resolved) - if err != nil { - return "", fmt.Errorf("cannot resolve candidate path: %w", err) - } - - rel, err := filepath.Rel(absDestRoot, absResolved) - if err != nil { - return "", fmt.Errorf("cannot compute relative path: %w", err) - } - if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." || filepath.IsAbs(rel) { - return "", fmt.Errorf("resolved path escapes destination: %s", absResolved) - } - - return absResolved, nil + return resolvePathWithinRootFS(safetyFS, destRoot, candidate) } // SafetyBackupResult contains information about the safety backup @@ -391,7 +361,7 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str return fmt.Errorf("read tar entry: %w", err) } - target, _, err := sanitizeRestoreEntryTarget(absDestRoot, header.Name) + target, _, err := sanitizeRestoreEntryTargetWithFS(safetyFS, absDestRoot, header.Name) if err != nil { logger.Warning("Skipping archive entry %s: %v", header.Name, err) continue @@ -425,16 +395,12 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str if header.Typeflag == tar.TypeSymlink { linkTarget := header.Linkname - // Resolve intended target relative to the sanitized symlink directory inside the archive - sanitizedDir := filepath.Dir(relTarget) - resolvedLinkPath := linkTarget - if !filepath.IsAbs(linkTarget) { - resolvedLinkPath = filepath.Join(sanitizedDir, linkTarget) - } - - if _, pathErr := resolveAndCheckPath(destRoot, resolvedLinkPath); pathErr != nil { - logger.Warning("Skipping symlink %s -> %s: target escapes root: %v", target, linkTarget, pathErr) - continue + if _, pathErr := resolvePathRelativeToBaseWithinRootFS(safetyFS, absDestRoot, filepath.Dir(target), linkTarget); pathErr != nil { + if isPathSecurityError(pathErr) { + logger.Warning("Skipping symlink %s -> %s: target escapes root: %v", target, linkTarget, pathErr) + continue + } + return fmt.Errorf("validate symlink %s -> %s before creation: %w", target, linkTarget, pathErr) } // Remove existing file/symlink before creating new one @@ -454,35 +420,14 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str continue } - // Resolve the symlink target relative to the symlink's directory - symlinkTargetDir := filepath.Dir(target) - resolvedTarget := actualTarget - if !filepath.IsAbs(actualTarget) { - resolvedTarget = filepath.Join(symlinkTargetDir, actualTarget) - } - - // Validate the resolved target stays within destRoot - absDestRoot, err := filepath.Abs(destRoot) - if err != nil { - logger.Warning("Cannot resolve destination root: %v", err) + if _, err := resolvePathRelativeToBaseWithinRootFS(safetyFS, absDestRoot, filepath.Dir(target), actualTarget); err != nil { safetyFS.Remove(target) - continue - } - - absResolvedTarget, err := filepath.Abs(resolvedTarget) - if err != nil { - logger.Warning("Cannot resolve symlink target: %v", err) - safetyFS.Remove(target) - continue - } - - // Check if resolved target is within destRoot - rel, err := filepath.Rel(absDestRoot, absResolvedTarget) - if err != nil || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." { - logger.Warning("Removing symlink %s -> %s: target escapes root after creation (resolves to %s)", - target, actualTarget, absResolvedTarget) - safetyFS.Remove(target) - continue + if isPathSecurityError(err) { + logger.Warning("Removing symlink %s -> %s: target escapes root after creation: %v", + target, actualTarget, err) + continue + } + return fmt.Errorf("validate symlink %s -> %s after creation: %w", target, actualTarget, err) } logger.Debug("Created safe symlink: %s -> %s", header.Name, linkTarget) diff --git a/internal/orchestrator/backup_safety_test.go b/internal/orchestrator/backup_safety_test.go index 2529e2e..6e41abd 100644 --- a/internal/orchestrator/backup_safety_test.go +++ b/internal/orchestrator/backup_safety_test.go @@ -136,6 +136,36 @@ func (r *readDirMockFS) ReadDir(path string) ([]os.DirEntry, error) { return r.osFS.ReadDir(path) } +type restorePathValidationMockFS struct { + osFS + lstatErr map[string]error + symlinkHook func(oldname, newname string) +} + +func newRestorePathValidationMockFS() *restorePathValidationMockFS { + return &restorePathValidationMockFS{ + lstatErr: make(map[string]error), + } +} + +func (r *restorePathValidationMockFS) Lstat(path string) (os.FileInfo, error) { + cleanPath := filepath.Clean(path) + if err, ok := r.lstatErr[cleanPath]; ok { + return nil, err + } + return r.osFS.Lstat(path) +} + +func (r *restorePathValidationMockFS) Symlink(oldname, newname string) error { + if err := r.osFS.Symlink(oldname, newname); err != nil { + return err + } + if r.symlinkHook != nil { + r.symlinkHook(oldname, newname) + } + return nil +} + // fakeDirEntry implements os.DirEntry for testing. type fakeDirEntry struct { name string @@ -400,6 +430,152 @@ func TestRestoreSafetyBackup_DoesNotFollowExistingSymlinkTargetPath(t *testing.T } } +func TestRestoreSafetyBackup_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + var logBuf bytes.Buffer + logger.SetOutput(&logBuf) + + tmpDir := t.TempDir() + backupPath := filepath.Join(tmpDir, "broken-escape.tar.gz") + restoreDir := filepath.Join(tmpDir, "restore") + outside := t.TempDir() + + if err := os.MkdirAll(restoreDir, 0o755); err != nil { + t.Fatalf("mkdir restore dir: %v", err) + } + if err := os.Symlink(outside, filepath.Join(restoreDir, "escape-link")); err != nil { + t.Fatalf("create escape symlink: %v", err) + } + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + if err := tw.WriteHeader(&tar.Header{Name: "escape-link/missing/", Mode: 0o755, Typeflag: tar.TypeDir}); err != nil { + t.Fatalf("write dir header failed: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("tar close failed: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip close failed: %v", err) + } + if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive failed: %v", err) + } + + if err := RestoreSafetyBackup(logger, backupPath, restoreDir); err != nil { + t.Fatalf("RestoreSafetyBackup failed: %v", err) + } + + if _, err := os.Stat(filepath.Join(outside, "missing")); !os.IsNotExist(err) { + t.Fatalf("outside path should not be created, got err=%v", err) + } + if !strings.Contains(logBuf.String(), "Skipping archive entry") { + t.Fatalf("expected archive entry skip warning, logs=%q", logBuf.String()) + } +} + +func TestRestoreSafetyBackup_RejectsEscapeWhenParentPathIsSymlink(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + var logBuf bytes.Buffer + logger.SetOutput(&logBuf) + + tmpDir := t.TempDir() + backupPath := filepath.Join(tmpDir, "symlinked-parent-escape.tar.gz") + restoreDir := filepath.Join(tmpDir, "restore") + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + + if err := tw.WriteHeader(&tar.Header{Name: "linkdir", Typeflag: tar.TypeSymlink, Linkname: "."}); err != nil { + t.Fatalf("write parent symlink header failed: %v", err) + } + if err := tw.WriteHeader(&tar.Header{Name: "linkdir/escape", Typeflag: tar.TypeSymlink, Linkname: "../outside"}); err != nil { + t.Fatalf("write escaping symlink header failed: %v", err) + } + + if err := tw.Close(); err != nil { + t.Fatalf("tar close failed: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip close failed: %v", err) + } + if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive failed: %v", err) + } + + if err := RestoreSafetyBackup(logger, backupPath, restoreDir); err != nil { + t.Fatalf("RestoreSafetyBackup failed: %v", err) + } + + linkTarget, err := os.Readlink(filepath.Join(restoreDir, "linkdir")) + if err != nil { + t.Fatalf("parent symlink should exist: %v", err) + } + if linkTarget != "." { + t.Fatalf("parent symlink target = %q, want %q", linkTarget, ".") + } + + if _, err := os.Lstat(filepath.Join(restoreDir, "escape")); !os.IsNotExist(err) { + t.Fatalf("escaping symlink should not be created, got err=%v", err) + } + if !strings.Contains(logBuf.String(), "Skipping symlink") { + t.Fatalf("expected symlink skip warning, logs=%q", logBuf.String()) + } +} + +func TestRestoreSafetyBackup_AllowsSafeTargetWhenParentPathIsSymlink(t *testing.T) { + logger := logging.New(types.LogLevelInfo, false) + + tmpDir := t.TempDir() + backupPath := filepath.Join(tmpDir, "symlinked-parent-safe.tar.gz") + restoreDir := filepath.Join(tmpDir, "restore") + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + + if err := tw.WriteHeader(&tar.Header{Name: "subdir/", Mode: 0o755, Typeflag: tar.TypeDir}); err != nil { + t.Fatalf("write subdir header failed: %v", err) + } + content := []byte("ok") + if err := tw.WriteHeader(&tar.Header{Name: "subdir/file.txt", Mode: 0o644, Size: int64(len(content))}); err != nil { + t.Fatalf("write file header failed: %v", err) + } + if _, err := tw.Write(content); err != nil { + t.Fatalf("write file content failed: %v", err) + } + if err := tw.WriteHeader(&tar.Header{Name: "linkdir", Typeflag: tar.TypeSymlink, Linkname: "subdir"}); err != nil { + t.Fatalf("write parent symlink header failed: %v", err) + } + if err := tw.WriteHeader(&tar.Header{Name: "linkdir/ok", Typeflag: tar.TypeSymlink, Linkname: "file.txt"}); err != nil { + t.Fatalf("write safe symlink header failed: %v", err) + } + + if err := tw.Close(); err != nil { + t.Fatalf("tar close failed: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip close failed: %v", err) + } + if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive failed: %v", err) + } + + if err := RestoreSafetyBackup(logger, backupPath, restoreDir); err != nil { + t.Fatalf("RestoreSafetyBackup failed: %v", err) + } + + linkTarget, err := os.Readlink(filepath.Join(restoreDir, "subdir", "ok")) + if err != nil { + t.Fatalf("safe symlink should exist: %v", err) + } + if linkTarget != "file.txt" { + t.Fatalf("safe symlink target = %q, want %q", linkTarget, "file.txt") + } +} + func TestCleanupOldSafetyBackups(t *testing.T) { logger := logging.New(types.LogLevelInfo, false) @@ -667,6 +843,18 @@ func TestResolveAndCheckPathRejectsSymlinkEscape(t *testing.T) { } } +func TestResolveAndCheckPathRejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(root, "escape-link")); err != nil { + t.Fatalf("create symlink: %v", err) + } + + if _, err := resolveAndCheckPath(root, filepath.Join("escape-link", "missing", "data.txt")); err == nil { + t.Fatalf("expected broken symlink escape to be rejected") + } +} + // ===================================== // walkFS / walkFSRecursive tests // ===================================== @@ -1410,6 +1598,108 @@ func TestRestoreSafetyBackup_SymlinkExistsRemoveAndCreate(t *testing.T) { } } +func TestRestoreSafetyBackup_SymlinkValidationOperationalErrorBeforeCreationFails(t *testing.T) { + fs := newRestorePathValidationMockFS() + origFS := safetyFS + safetyFS = fs + t.Cleanup(func() { safetyFS = origFS }) + + logger := logging.New(types.LogLevelInfo, false) + + tmpDir := t.TempDir() + backupPath := filepath.Join(tmpDir, "backup.tar.gz") + restoreDir := filepath.Join(tmpDir, "restore") + operationalErr := errors.New("permission denied") + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + + if err := tw.WriteHeader(&tar.Header{ + Name: "nested/link", + Typeflag: tar.TypeSymlink, + Linkname: "target.txt", + }); err != nil { + t.Fatalf("write header: %v", err) + } + + if err := tw.Close(); err != nil { + t.Fatalf("tar close: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive: %v", err) + } + + fs.lstatErr[filepath.Join(restoreDir, "nested")] = operationalErr + + err := RestoreSafetyBackup(logger, backupPath, restoreDir) + if err == nil { + t.Fatal("expected restore failure") + } + if !strings.Contains(err.Error(), "validate symlink") || !strings.Contains(err.Error(), "permission denied") { + t.Fatalf("unexpected error: %v", err) + } + + if _, statErr := os.Lstat(filepath.Join(restoreDir, "nested", "link")); !os.IsNotExist(statErr) { + t.Fatalf("symlink should not exist, got err=%v", statErr) + } +} + +func TestRestoreSafetyBackup_SymlinkValidationOperationalErrorAfterCreationFails(t *testing.T) { + fs := newRestorePathValidationMockFS() + origFS := safetyFS + safetyFS = fs + t.Cleanup(func() { safetyFS = origFS }) + + logger := logging.New(types.LogLevelInfo, false) + + tmpDir := t.TempDir() + backupPath := filepath.Join(tmpDir, "backup.tar.gz") + restoreDir := filepath.Join(tmpDir, "restore") + operationalErr := errors.New("permission denied") + + var buf bytes.Buffer + gzw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gzw) + + if err := tw.WriteHeader(&tar.Header{ + Name: "link", + Typeflag: tar.TypeSymlink, + Linkname: "target.txt", + }); err != nil { + t.Fatalf("write header: %v", err) + } + + if err := tw.Close(); err != nil { + t.Fatalf("tar close: %v", err) + } + if err := gzw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil { + t.Fatalf("write archive: %v", err) + } + + fs.symlinkHook = func(_, _ string) { + fs.lstatErr[filepath.Join(restoreDir, "target.txt")] = operationalErr + } + + err := RestoreSafetyBackup(logger, backupPath, restoreDir) + if err == nil { + t.Fatal("expected restore failure") + } + if !strings.Contains(err.Error(), "validate symlink") || !strings.Contains(err.Error(), "permission denied") { + t.Fatalf("unexpected error: %v", err) + } + + if _, statErr := os.Lstat(filepath.Join(restoreDir, "link")); !os.IsNotExist(statErr) { + t.Fatalf("symlink should have been removed, got err=%v", statErr) + } +} + // ===================================== // backupDirectory additional tests // ===================================== diff --git a/internal/orchestrator/backup_sources.go b/internal/orchestrator/backup_sources.go index aac895e..c5def50 100644 --- a/internal/orchestrator/backup_sources.go +++ b/internal/orchestrator/backup_sources.go @@ -1,9 +1,12 @@ package orchestrator import ( + "bufio" + "bytes" "context" "errors" "fmt" + "io" "os/exec" "path" "path/filepath" @@ -130,6 +133,7 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s emptyEntries := 0 nonCandidateEntries := 0 manifestErrors := 0 + integrityMissing := 0 logDebug(logger, "Cloud (rclone): scanned %d entries from rclone lsf output", totalEntries) snapshot := make(map[string]struct{}, len(lines)) @@ -211,10 +215,10 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s report(fmt.Sprintf("Inspecting %d/%d: %s", idx+1, len(items), item.filename)) } - itemCtx, cancel := context.WithTimeout(ctx, timeout) switch item.kind { case sourceBundle: - manifest, perr := inspectRcloneBundleManifest(itemCtx, item.remoteBundle, logger) + bundleCtx, cancel := context.WithTimeout(ctx, timeout) + manifest, perr := inspectRcloneBundleManifest(bundleCtx, item.remoteBundle, logger) cancel() if perr != nil { if errors.Is(perr, context.DeadlineExceeded) { @@ -242,8 +246,9 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s logDebug(logger, "Cloud (rclone): accepted backup bundle: %s", item.filename) case sourceRaw: - manifest, perr := inspectRcloneMetadataManifest(itemCtx, item.remoteMetadata, item.remoteArchive, logger) - cancel() + manifestCtx, manifestCancel := context.WithTimeout(ctx, timeout) + manifest, perr := inspectRcloneMetadataManifest(manifestCtx, item.remoteMetadata, item.remoteArchive, logger) + manifestCancel() if perr != nil { if errors.Is(perr, context.DeadlineExceeded) { return nil, fmt.Errorf("timed out while inspecting %s (timeout=%s). Increase RCLONE_TIMEOUT_CONNECTION if needed: %w", item.filename, timeout, perr) @@ -255,6 +260,35 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s logWarning(logger, "Skipping rclone metadata %s: %v", item.filename, perr) continue } + manifestChecksum, perr := normalizeCandidateManifestChecksum(manifest) + if perr != nil { + integrityMissing++ + logWarning(logger, "Skipping rclone backup %s: invalid manifest checksum: %v", item.filename, perr) + continue + } + checksumFromFile := "" + if item.remoteChecksum != "" { + checksumCtx, checksumCancel := context.WithTimeout(ctx, timeout) + checksumFromFile, perr = inspectRcloneChecksumFile(checksumCtx, item.remoteChecksum, logger) + checksumCancel() + if perr != nil { + if errors.Is(perr, context.DeadlineExceeded) { + return nil, fmt.Errorf("timed out while inspecting %s checksum (timeout=%s). Increase RCLONE_TIMEOUT_CONNECTION if needed: %w", item.filename, timeout, perr) + } + if errors.Is(perr, context.Canceled) { + return nil, perr + } + integrityMissing++ + logWarning(logger, "Skipping rclone backup %s: invalid checksum file: %v", item.filename, perr) + continue + } + } + expectation, perr := resolveIntegrityExpectationValues(checksumFromFile, manifestChecksum) + if perr != nil { + integrityMissing++ + logWarning(logger, "Skipping rclone backup %s: %v", item.filename, perr) + continue + } displayBase := filepath.Base(manifest.ArchivePath) if strings.TrimSpace(displayBase) == "" { displayBase = filepath.Base(baseNameFromRemoteRef(item.remoteArchive)) @@ -265,11 +299,11 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s RawArchivePath: item.remoteArchive, RawMetadataPath: item.remoteMetadata, RawChecksumPath: item.remoteChecksum, + Integrity: expectation, DisplayBase: displayBase, IsRclone: true, }) default: - cancel() continue } } @@ -292,11 +326,12 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s logging.DebugStep( logger, "discover rclone backups", - "summary entries=%d empty=%d non_candidate=%d manifest_errors=%d accepted=%d elapsed=%s", + "summary entries=%d empty=%d non_candidate=%d manifest_errors=%d integrity_missing=%d accepted=%d elapsed=%s", totalEntries, emptyEntries, nonCandidateEntries, manifestErrors, + integrityMissing, len(candidates), time.Since(start), ) @@ -336,6 +371,7 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [ metadataMissingArchive := 0 metadataManifestErrors := 0 checksumMissing := 0 + integrityUnavailable := 0 for _, entry := range entries { if entry.IsDir() { @@ -392,11 +428,26 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [ logWarning(logger, "Skipping metadata %s: %v", name, err) continue } - logging.DebugStep(logger, "discover backup candidates", "raw candidate accepted: %s created_at=%s", name, manifest.CreatedAt.Format(time.RFC3339)) - - // If checksum is missing from both file and manifest, warn user - if !hasChecksum && manifest.SHA256 == "" { - logWarning(logger, "Backup %s has no checksum verification available", baseName) + manifestChecksum, err := normalizeCandidateManifestChecksum(manifest) + if err != nil { + integrityUnavailable++ + logWarning(logger, "Skipping backup %s: invalid manifest checksum: %v", baseName, err) + continue + } + checksumFromFile := "" + if hasChecksum { + checksumFromFile, err = parseLocalChecksumFile(checksumPath) + if err != nil { + integrityUnavailable++ + logWarning(logger, "Skipping backup %s: invalid checksum file: %v", baseName, err) + continue + } + } + expectation, err := resolveIntegrityExpectationValues(checksumFromFile, manifestChecksum) + if err != nil { + integrityUnavailable++ + logWarning(logger, "Skipping backup %s: %v", baseName, err) + continue } rawBases[baseName] = struct{}{} @@ -406,8 +457,10 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [ RawArchivePath: archivePath, RawMetadataPath: fullPath, RawChecksumPath: checksumPath, + Integrity: expectation, DisplayBase: filepath.Base(manifest.ArchivePath), }) + logging.DebugStep(logger, "discover backup candidates", "raw candidate accepted: %s created_at=%s", name, manifest.CreatedAt.Format(time.RFC3339)) } } @@ -418,7 +471,7 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [ logging.DebugStep( logger, "discover backup candidates", - "summary entries=%d files=%d dirs=%d bundles=%d bundle_manifest_errors=%d metadata=%d metadata_duplicate=%d metadata_missing_archive=%d metadata_manifest_errors=%d checksum_missing=%d candidates=%d", + "summary entries=%d files=%d dirs=%d bundles=%d bundle_manifest_errors=%d metadata=%d metadata_duplicate=%d metadata_missing_archive=%d metadata_manifest_errors=%d checksum_missing=%d integrity_unavailable=%d candidates=%d", len(entries), filesSeen, dirsSkipped, @@ -429,11 +482,101 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [ metadataMissingArchive, metadataManifestErrors, checksumMissing, + integrityUnavailable, len(candidates), ) return candidates, nil } +func normalizeCandidateManifestChecksum(manifest *backup.Manifest) (string, error) { + if manifest == nil || strings.TrimSpace(manifest.SHA256) == "" { + return "", nil + } + normalized, err := backup.NormalizeChecksum(manifest.SHA256) + if err != nil { + return "", err + } + manifest.SHA256 = normalized + return normalized, nil +} + +const checksumFileReadLimit = 4 * 1024 + +func readBoundedChecksumLine(reader io.Reader) ([]byte, bool, error) { + limited := io.LimitReader(reader, checksumFileReadLimit+1) + line, err := bufio.NewReaderSize(limited, checksumFileReadLimit+1).ReadSlice('\n') + if err == nil { + return append([]byte(nil), line...), true, nil + } + if errors.Is(err, bufio.ErrBufferFull) || len(line) > checksumFileReadLimit { + return nil, true, fmt.Errorf("checksum file exceeds %d bytes before newline", checksumFileReadLimit) + } + if errors.Is(err, io.EOF) { + if len(line) == 0 { + return nil, false, fmt.Errorf("checksum file is empty") + } + return append([]byte(nil), line...), false, nil + } + return nil, false, err +} + +func parseLocalChecksumFile(checksumPath string) (string, error) { + file, err := restoreFS.Open(checksumPath) + if err != nil { + return "", fmt.Errorf("read checksum file %s: %w", checksumPath, err) + } + defer file.Close() + + data, _, err := readBoundedChecksumLine(file) + if err != nil { + return "", fmt.Errorf("read checksum file %s: %w", checksumPath, err) + } + checksum, err := backup.ParseChecksumData(data) + if err != nil { + return "", fmt.Errorf("parse checksum file %s: %w", checksumPath, err) + } + return checksum, nil +} + +func inspectRcloneChecksumFile(ctx context.Context, remotePath string, logger *logging.Logger) (checksum string, err error) { + done := logging.DebugStart(logger, "inspect rclone checksum", "remote=%s", remotePath) + defer func() { done(err) }() + logging.DebugStep(logger, "inspect rclone checksum", "executing: rclone cat %s", remotePath) + + cmd := exec.CommandContext(ctx, "rclone", "cat", remotePath) + stdout, err := cmd.StdoutPipe() + if err != nil { + return "", fmt.Errorf("start rclone cat %s: %w", remotePath, err) + } + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Start(); err != nil { + return "", fmt.Errorf("start rclone cat %s: %w", remotePath, err) + } + + data, closedEarly, readErr := readBoundedChecksumLine(stdout) + if closedEarly { + _ = stdout.Close() + } + waitErr := cmd.Wait() + stderrOutput := strings.TrimSpace(stderr.String()) + ignoreWaitErr := closedEarly && stderrOutput == "" + if readErr != nil { + if waitErr != nil && !ignoreWaitErr { + return "", fmt.Errorf("rclone cat %s failed: %w (output: %s)", remotePath, waitErr, stderrOutput) + } + return "", fmt.Errorf("read checksum file %s: %w", remotePath, readErr) + } + if waitErr != nil && !ignoreWaitErr { + return "", fmt.Errorf("rclone cat %s failed: %w (output: %s)", remotePath, waitErr, stderrOutput) + } + checksum, err = backup.ParseChecksumData(data) + if err != nil { + return "", fmt.Errorf("parse checksum file %s: %w", remotePath, err) + } + return checksum, nil +} + // isLocalFilesystemPath returns true if the given value represents an absolute // local filesystem path (and not an rclone-style "remote:path" reference). func isLocalFilesystemPath(path string) bool { diff --git a/internal/orchestrator/backup_sources_test.go b/internal/orchestrator/backup_sources_test.go index 30b9f0d..ad27869 100644 --- a/internal/orchestrator/backup_sources_test.go +++ b/internal/orchestrator/backup_sources_test.go @@ -235,6 +235,7 @@ func TestDiscoverRcloneBackups_IncludesRawMetadata(t *testing.T) { ProxmoxVersion: "8.1", CreatedAt: time.Date(2025, 12, 5, 12, 0, 0, 0, time.UTC), EncryptionMode: "none", + SHA256: checksumHexForBytes([]byte("node-backup-20251205")), } metaBytes, err := json.Marshal(&manifest) if err != nil { @@ -317,6 +318,7 @@ func TestDiscoverRcloneBackups_MixedCandidatesSortedByCreatedAt(t *testing.T) { CreatedAt: time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), EncryptionMode: "none", ProxmoxType: "pve", + SHA256: checksumHexForBytes([]byte("x")), } rawNewestData, _ := json.Marshal(&rawNewest) if err := os.WriteFile(rawNewestMeta, rawNewestData, 0o600); err != nil { @@ -354,6 +356,7 @@ func TestDiscoverRcloneBackups_MixedCandidatesSortedByCreatedAt(t *testing.T) { CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), EncryptionMode: "none", ProxmoxType: "pve", + SHA256: checksumHexForBytes([]byte("x")), } rawOldData, _ := json.Marshal(&rawOld) if err := os.WriteFile(rawOldMeta, rawOldData, 0o600); err != nil { @@ -629,7 +632,7 @@ esac return manifest, cleanup } -func TestDiscoverBackupCandidates_NoLoggerStillCollectsRawArtifacts(t *testing.T) { +func TestDiscoverBackupCandidates_NoLoggerSkipsRawArtifactsWithoutChecksumVerification(t *testing.T) { tmpDir := t.TempDir() archivePath := filepath.Join(tmpDir, "config.tar.xz") if err := os.WriteFile(archivePath, []byte("dummy"), 0o600); err != nil { @@ -652,6 +655,44 @@ func TestDiscoverBackupCandidates_NoLoggerStillCollectsRawArtifacts(t *testing.T } // Intentionally skip checksum to exercise warning path with nil logger. + candidates, err := discoverBackupCandidates(nil, tmpDir) + if err != nil { + t.Fatalf("discoverBackupCandidates() error = %v", err) + } + if len(candidates) != 0 { + t.Fatalf("discoverBackupCandidates() returned %d candidates; want 0", len(candidates)) + } +} + +func TestDiscoverBackupCandidates_NormalizesAndStoresIntegrityExpectation(t *testing.T) { + tmpDir := t.TempDir() + archiveData := []byte("archive") + archivePath := filepath.Join(tmpDir, "config.tar.xz") + if err := os.WriteFile(archivePath, archiveData, 0o600); err != nil { + t.Fatalf("write archive: %v", err) + } + + checksum := checksumHexForBytes(archiveData) + manifest := backup.Manifest{ + ArchivePath: "/etc/pve/config.tar.xz", + ProxmoxType: "pve", + CreatedAt: time.Now().UTC(), + EncryptionMode: "none", + SHA256: " " + strings.ToUpper(checksum) + " ", + } + metaPath := archivePath + ".metadata" + data, err := json.Marshal(&manifest) + if err != nil { + t.Fatalf("marshal manifest: %v", err) + } + if err := os.WriteFile(metaPath, data, 0o600); err != nil { + t.Fatalf("write metadata: %v", err) + } + checksumLine := strings.ToUpper(checksum) + " " + filepath.Base(archivePath) + "\n" + if err := os.WriteFile(archivePath+".sha256", []byte(checksumLine), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + candidates, err := discoverBackupCandidates(nil, tmpDir) if err != nil { t.Fatalf("discoverBackupCandidates() error = %v", err) @@ -659,10 +700,504 @@ func TestDiscoverBackupCandidates_NoLoggerStillCollectsRawArtifacts(t *testing.T if len(candidates) != 1 { t.Fatalf("discoverBackupCandidates() returned %d candidates; want 1", len(candidates)) } - if candidates[0].RawArchivePath != archivePath { - t.Fatalf("RawArchivePath = %q; want %q", candidates[0].RawArchivePath, archivePath) + cand := candidates[0] + if cand.Manifest == nil { + t.Fatal("candidate Manifest is nil") + } + if cand.Manifest.SHA256 != checksum { + t.Fatalf("Manifest.SHA256 = %q; want %q", cand.Manifest.SHA256, checksum) + } + if cand.Integrity == nil { + t.Fatal("candidate Integrity is nil") + } + if cand.Integrity.Checksum != checksum { + t.Fatalf("Integrity.Checksum = %q; want %q", cand.Integrity.Checksum, checksum) + } + if cand.Integrity.Source != "checksum file and manifest" { + t.Fatalf("Integrity.Source = %q; want %q", cand.Integrity.Source, "checksum file and manifest") + } +} + +func TestDiscoverBackupCandidates_RejectsMalformedOrConflictingChecksums(t *testing.T) { + tests := []struct { + name string + manifestSHA256 string + checksumData string + }{ + { + name: "invalid manifest checksum", + manifestSHA256: "not-a-checksum", + checksumData: string(checksumLineForBytes("config.tar.xz", []byte("archive"))), + }, + { + name: "invalid checksum file", + manifestSHA256: checksumHexForBytes([]byte("archive")), + checksumData: "not-a-checksum config.tar.xz\n", + }, + { + name: "conflicting valid checksums", + manifestSHA256: checksumHexForBytes([]byte("archive")), + checksumData: string(checksumLineForBytes("config.tar.xz", []byte("different"))), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "config.tar.xz") + if err := os.WriteFile(archivePath, []byte("archive"), 0o600); err != nil { + t.Fatalf("write archive: %v", err) + } + + manifest := backup.Manifest{ + ArchivePath: "/etc/pve/config.tar.xz", + ProxmoxType: "pve", + CreatedAt: time.Now().UTC(), + EncryptionMode: "none", + SHA256: tt.manifestSHA256, + } + metaPath := archivePath + ".metadata" + data, err := json.Marshal(&manifest) + if err != nil { + t.Fatalf("marshal manifest: %v", err) + } + if err := os.WriteFile(metaPath, data, 0o600); err != nil { + t.Fatalf("write metadata: %v", err) + } + if err := os.WriteFile(archivePath+".sha256", []byte(tt.checksumData), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + candidates, err := discoverBackupCandidates(nil, tmpDir) + if err != nil { + t.Fatalf("discoverBackupCandidates() error = %v", err) + } + if len(candidates) != 0 { + t.Fatalf("discoverBackupCandidates() returned %d candidates; want 0", len(candidates)) + } + }) + } +} + +func TestParseLocalChecksumFile_RejectsOversizedInput(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + checksumPath := filepath.Join(t.TempDir(), "config.tar.xz.sha256") + oversized := strings.Repeat("a", checksumFileReadLimit+1) + "\n" + if err := os.WriteFile(checksumPath, []byte(oversized), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + _, err := parseLocalChecksumFile(checksumPath) + if err == nil { + t.Fatal("parseLocalChecksumFile() error = nil; want oversize error") + } + if !strings.Contains(err.Error(), "exceeds") { + t.Fatalf("parseLocalChecksumFile() error = %v; want oversize error", err) + } +} + +func TestParseLocalChecksumFile_AcceptsBoundedInputWithoutTrailingNewline(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + checksumPath := filepath.Join(t.TempDir(), "config.tar.xz.sha256") + want := checksumHexForBytes([]byte("archive")) + if err := os.WriteFile(checksumPath, checksumLineForBytes("config.tar.xz", []byte("archive")), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + got, err := parseLocalChecksumFile(checksumPath) + if err != nil { + t.Fatalf("parseLocalChecksumFile() error = %v", err) + } + if got != want { + t.Fatalf("parseLocalChecksumFile() = %q; want %q", got, want) + } +} + +func TestDiscoverRcloneBackups_NormalizesAndStoresIntegrityExpectation(t *testing.T) { + tmpDir := t.TempDir() + + archiveData := []byte("archive") + checksum := checksumHexForBytes(archiveData) + manifest := backup.Manifest{ + ArchivePath: "/var/backups/node-backup.tar.xz", + ProxmoxType: "pve", + CreatedAt: time.Date(2025, 12, 5, 12, 0, 0, 0, time.UTC), + EncryptionMode: "none", + SHA256: " " + strings.ToUpper(checksum) + " ", + } + metaBytes, err := json.Marshal(&manifest) + if err != nil { + t.Fatalf("marshal manifest: %v", err) + } + metadataPath := filepath.Join(tmpDir, "node-backup.tar.xz.metadata") + if err := os.WriteFile(metadataPath, metaBytes, 0o600); err != nil { + t.Fatalf("write metadata: %v", err) + } + checksumPath := filepath.Join(tmpDir, "node-backup.tar.xz.sha256") + checksumLine := strings.ToUpper(checksum) + " node-backup.tar.xz\n" + if err := os.WriteFile(checksumPath, []byte(checksumLine), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + scriptPath := filepath.Join(tmpDir, "rclone") + script := `#!/bin/sh +subcmd="$1" +target="$2" +case "$subcmd" in + lsf) + printf 'node-backup.tar.xz\n' + printf 'node-backup.tar.xz.metadata\n' + printf 'node-backup.tar.xz.sha256\n' + ;; + cat) + case "$target" in + *node-backup.tar.xz.metadata) cat "$METADATA_PATH" ;; + *node-backup.tar.xz.sha256) cat "$CHECKSUM_PATH" ;; + *) echo "unexpected cat target: $target" >&2; exit 1 ;; + esac + ;; + *) + echo "unexpected subcommand: $subcmd" >&2 + exit 1 + ;; +esac +` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write fake rclone: %v", err) + } + + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { + t.Fatalf("set PATH: %v", err) + } + defer os.Setenv("PATH", oldPath) + if err := os.Setenv("METADATA_PATH", metadataPath); err != nil { + t.Fatalf("set METADATA_PATH: %v", err) + } + defer os.Unsetenv("METADATA_PATH") + if err := os.Setenv("CHECKSUM_PATH", checksumPath); err != nil { + t.Fatalf("set CHECKSUM_PATH: %v", err) + } + defer os.Unsetenv("CHECKSUM_PATH") + + candidates, err := discoverRcloneBackups(context.Background(), nil, "gdrive:pbs-backups/server1", nil, nil) + if err != nil { + t.Fatalf("discoverRcloneBackups() error = %v", err) + } + if len(candidates) != 1 { + t.Fatalf("discoverRcloneBackups() returned %d candidates; want 1", len(candidates)) + } + cand := candidates[0] + if cand.Manifest == nil { + t.Fatal("candidate Manifest is nil") + } + if cand.Manifest.SHA256 != checksum { + t.Fatalf("Manifest.SHA256 = %q; want %q", cand.Manifest.SHA256, checksum) + } + if cand.Integrity == nil { + t.Fatal("candidate Integrity is nil") + } + if cand.Integrity.Checksum != checksum { + t.Fatalf("Integrity.Checksum = %q; want %q", cand.Integrity.Checksum, checksum) + } + if cand.Integrity.Source != "checksum file and manifest" { + t.Fatalf("Integrity.Source = %q; want %q", cand.Integrity.Source, "checksum file and manifest") + } +} + +func TestDiscoverRcloneBackups_UsesFreshTimeoutForChecksumFetch(t *testing.T) { + tmpDir := t.TempDir() + + archiveData := []byte("archive") + checksum := checksumHexForBytes(archiveData) + manifest := backup.Manifest{ + ArchivePath: "/var/backups/node-backup.tar.xz", + ProxmoxType: "pve", + CreatedAt: time.Date(2025, 12, 5, 12, 0, 0, 0, time.UTC), + EncryptionMode: "none", + SHA256: checksum, + } + metaBytes, err := json.Marshal(&manifest) + if err != nil { + t.Fatalf("marshal manifest: %v", err) + } + metadataPath := filepath.Join(tmpDir, "node-backup.tar.xz.metadata") + if err := os.WriteFile(metadataPath, metaBytes, 0o600); err != nil { + t.Fatalf("write metadata: %v", err) + } + checksumPath := filepath.Join(tmpDir, "node-backup.tar.xz.sha256") + checksumLine := checksum + " node-backup.tar.xz\n" + if err := os.WriteFile(checksumPath, []byte(checksumLine), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + scriptPath := filepath.Join(tmpDir, "rclone") + script := `#!/bin/sh +subcmd="$1" +target="$2" +case "$subcmd" in + lsf) + printf 'node-backup.tar.xz\n' + printf 'node-backup.tar.xz.metadata\n' + printf 'node-backup.tar.xz.sha256\n' + ;; + cat) + case "$target" in + *node-backup.tar.xz.metadata) + sleep 2 + cat "$METADATA_PATH" + ;; + *node-backup.tar.xz.sha256) + sleep 2 + cat "$CHECKSUM_PATH" + ;; + *) + echo "unexpected cat target: $target" >&2 + exit 1 + ;; + esac + ;; + *) + echo "unexpected subcommand: $subcmd" >&2 + exit 1 + ;; +esac +` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write fake rclone: %v", err) + } + + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { + t.Fatalf("set PATH: %v", err) + } + defer os.Setenv("PATH", oldPath) + if err := os.Setenv("METADATA_PATH", metadataPath); err != nil { + t.Fatalf("set METADATA_PATH: %v", err) + } + defer os.Unsetenv("METADATA_PATH") + if err := os.Setenv("CHECKSUM_PATH", checksumPath); err != nil { + t.Fatalf("set CHECKSUM_PATH: %v", err) + } + defer os.Unsetenv("CHECKSUM_PATH") + + cfg := &config.Config{RcloneTimeoutConnection: 3} + candidates, err := discoverRcloneBackups(context.Background(), cfg, "gdrive:pbs-backups/server1", nil, nil) + if err != nil { + t.Fatalf("discoverRcloneBackups() error = %v", err) + } + if len(candidates) != 1 { + t.Fatalf("discoverRcloneBackups() returned %d candidates; want 1", len(candidates)) + } + cand := candidates[0] + if cand.Integrity == nil { + t.Fatal("candidate Integrity is nil") + } + if cand.Integrity.Checksum != checksum { + t.Fatalf("Integrity.Checksum = %q; want %q", cand.Integrity.Checksum, checksum) + } +} + +func TestDiscoverRcloneBackups_RejectsMalformedOrConflictingChecksums(t *testing.T) { + tests := []struct { + name string + manifestSHA256 string + checksumData string + }{ + { + name: "invalid manifest checksum", + manifestSHA256: "not-a-checksum", + checksumData: string(checksumLineForBytes("node-backup.tar.xz", []byte("archive"))), + }, + { + name: "invalid checksum file", + manifestSHA256: checksumHexForBytes([]byte("archive")), + checksumData: "not-a-checksum node-backup.tar.xz\n", + }, + { + name: "conflicting valid checksums", + manifestSHA256: checksumHexForBytes([]byte("archive")), + checksumData: string(checksumLineForBytes("node-backup.tar.xz", []byte("different"))), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + manifest := backup.Manifest{ + ArchivePath: "/var/backups/node-backup.tar.xz", + ProxmoxType: "pve", + CreatedAt: time.Date(2025, 12, 5, 12, 0, 0, 0, time.UTC), + EncryptionMode: "none", + SHA256: tt.manifestSHA256, + } + metaBytes, err := json.Marshal(&manifest) + if err != nil { + t.Fatalf("marshal manifest: %v", err) + } + metadataPath := filepath.Join(tmpDir, "node-backup.tar.xz.metadata") + if err := os.WriteFile(metadataPath, metaBytes, 0o600); err != nil { + t.Fatalf("write metadata: %v", err) + } + checksumPath := filepath.Join(tmpDir, "node-backup.tar.xz.sha256") + if err := os.WriteFile(checksumPath, []byte(tt.checksumData), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + scriptPath := filepath.Join(tmpDir, "rclone") + script := `#!/bin/sh +subcmd="$1" +target="$2" +case "$subcmd" in + lsf) + printf 'node-backup.tar.xz\n' + printf 'node-backup.tar.xz.metadata\n' + printf 'node-backup.tar.xz.sha256\n' + ;; + cat) + case "$target" in + *node-backup.tar.xz.metadata) cat "$METADATA_PATH" ;; + *node-backup.tar.xz.sha256) cat "$CHECKSUM_PATH" ;; + *) echo "unexpected cat target: $target" >&2; exit 1 ;; + esac + ;; + *) + echo "unexpected subcommand: $subcmd" >&2 + exit 1 + ;; +esac +` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write fake rclone: %v", err) + } + + oldPath := os.Getenv("PATH") + if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil { + t.Fatalf("set PATH: %v", err) + } + defer os.Setenv("PATH", oldPath) + if err := os.Setenv("METADATA_PATH", metadataPath); err != nil { + t.Fatalf("set METADATA_PATH: %v", err) + } + defer os.Unsetenv("METADATA_PATH") + if err := os.Setenv("CHECKSUM_PATH", checksumPath); err != nil { + t.Fatalf("set CHECKSUM_PATH: %v", err) + } + defer os.Unsetenv("CHECKSUM_PATH") + + candidates, err := discoverRcloneBackups(context.Background(), nil, "gdrive:pbs-backups/server1", nil, nil) + if err != nil { + t.Fatalf("discoverRcloneBackups() error = %v", err) + } + if len(candidates) != 0 { + t.Fatalf("discoverRcloneBackups() returned %d candidates; want 0", len(candidates)) + } + }) + } +} + +func TestInspectRcloneChecksumFile_RejectsOversizedInput(t *testing.T) { + tmpDir := t.TempDir() + checksumPath := filepath.Join(tmpDir, "node-backup.tar.xz.sha256") + oversized := strings.Repeat("a", checksumFileReadLimit+1) + "\n" + if err := os.WriteFile(checksumPath, []byte(oversized), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + scriptPath := filepath.Join(tmpDir, "rclone") + script := `#!/bin/sh +if [ "$1" != "cat" ]; then + echo "unexpected subcommand: $1" >&2 + exit 1 +fi +cat "$CHECKSUM_PATH" +` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write fake rclone: %v", err) + } + + t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Setenv("CHECKSUM_PATH", checksumPath) + + _, err := inspectRcloneChecksumFile(context.Background(), "gdrive:node-backup.tar.xz.sha256", nil) + if err == nil { + t.Fatal("inspectRcloneChecksumFile() error = nil; want oversize error") + } + if !strings.Contains(err.Error(), "exceeds") { + t.Fatalf("inspectRcloneChecksumFile() error = %v; want oversize error", err) + } +} + +func TestInspectRcloneChecksumFile_AcceptsBoundedInputWithoutTrailingNewline(t *testing.T) { + tmpDir := t.TempDir() + checksumPath := filepath.Join(tmpDir, "node-backup.tar.xz.sha256") + want := checksumHexForBytes([]byte("archive")) + if err := os.WriteFile(checksumPath, checksumLineForBytes("node-backup.tar.xz", []byte("archive")), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + scriptPath := filepath.Join(tmpDir, "rclone") + script := `#!/bin/sh +if [ "$1" != "cat" ]; then + echo "unexpected subcommand: $1" >&2 + exit 1 +fi +cat "$CHECKSUM_PATH" +` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write fake rclone: %v", err) + } + + t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Setenv("CHECKSUM_PATH", checksumPath) + + got, err := inspectRcloneChecksumFile(context.Background(), "gdrive:node-backup.tar.xz.sha256", nil) + if err != nil { + t.Fatalf("inspectRcloneChecksumFile() error = %v", err) + } + if got != want { + t.Fatalf("inspectRcloneChecksumFile() = %q; want %q", got, want) + } +} + +func TestInspectRcloneChecksumFile_SurfacesRcloneFailureAfterValidFirstLine(t *testing.T) { + tmpDir := t.TempDir() + checksumPath := filepath.Join(tmpDir, "node-backup.tar.xz.sha256") + if err := os.WriteFile(checksumPath, append(checksumLineForBytes("node-backup.tar.xz", []byte("archive")), '\n'), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + scriptPath := filepath.Join(tmpDir, "rclone") + script := `#!/bin/sh +if [ "$1" != "cat" ]; then + echo "unexpected subcommand: $1" >&2 + exit 1 +fi +cat "$CHECKSUM_PATH" +echo "simulated rclone failure" >&2 +exit 1 +` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("write fake rclone: %v", err) + } + + t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Setenv("CHECKSUM_PATH", checksumPath) + + _, err := inspectRcloneChecksumFile(context.Background(), "gdrive:node-backup.tar.xz.sha256", nil) + if err == nil { + t.Fatal("inspectRcloneChecksumFile() error = nil; want rclone failure") + } + if !strings.Contains(err.Error(), "rclone cat gdrive:node-backup.tar.xz.sha256 failed") { + t.Fatalf("inspectRcloneChecksumFile() error = %v; want rclone failure", err) } - if candidates[0].RawChecksumPath != "" { - t.Fatalf("RawChecksumPath should be empty when checksum missing; got %q", candidates[0].RawChecksumPath) + if !strings.Contains(err.Error(), "simulated rclone failure") { + t.Fatalf("inspectRcloneChecksumFile() error = %v; want stderr output", err) } } diff --git a/internal/orchestrator/compatibility.go b/internal/orchestrator/compatibility.go index cc2eb3f..bde7503 100644 --- a/internal/orchestrator/compatibility.go +++ b/internal/orchestrator/compatibility.go @@ -41,12 +41,8 @@ func DetectBackupType(manifest *backup.Manifest) SystemType { // Check ProxmoxType field if present if manifest.ProxmoxType != "" { - proxmoxType := strings.ToLower(manifest.ProxmoxType) - if strings.Contains(proxmoxType, "pve") || strings.Contains(proxmoxType, "proxmox-ve") { - return SystemTypePVE - } - if strings.Contains(proxmoxType, "pbs") || strings.Contains(proxmoxType, "proxmox-backup") { - return SystemTypePBS + if backupType := parseSystemTypeString(manifest.ProxmoxType); backupType != SystemTypeUnknown { + return backupType } } @@ -65,12 +61,25 @@ func DetectBackupType(manifest *backup.Manifest) SystemType { return SystemTypeUnknown } -// ValidateCompatibility checks if a backup is compatible with the current system -func ValidateCompatibility(manifest *backup.Manifest) error { - currentSystem := DetectCurrentSystem() - backupType := DetectBackupType(manifest) +func parseSystemTypeString(value string) SystemType { + normalized := strings.ToLower(strings.TrimSpace(value)) + switch { + case strings.Contains(normalized, "pve"), + strings.Contains(normalized, "proxmox-ve"), + strings.Contains(normalized, "proxmox ve"): + return SystemTypePVE + case strings.Contains(normalized, "pbs"), + strings.Contains(normalized, "proxmox-backup"), + strings.Contains(normalized, "proxmox backup"), + strings.Contains(normalized, "proxmox backup server"): + return SystemTypePBS + default: + return SystemTypeUnknown + } +} - // If we can't detect either, issue a warning but allow +// ValidateCompatibility checks if a backup is compatible with the current system. +func ValidateCompatibility(currentSystem, backupType SystemType) error { if currentSystem == SystemTypeUnknown { return fmt.Errorf("warning: cannot detect current system type - restoration may fail") } diff --git a/internal/orchestrator/compatibility_test.go b/internal/orchestrator/compatibility_test.go index 5c7ef98..70bafb1 100644 --- a/internal/orchestrator/compatibility_test.go +++ b/internal/orchestrator/compatibility_test.go @@ -3,8 +3,6 @@ package orchestrator import ( "os" "testing" - - "github.com/tis24dev/proxsave/internal/backup" ) func TestValidateCompatibility_Mismatch(t *testing.T) { @@ -18,8 +16,7 @@ func TestValidateCompatibility_Mismatch(t *testing.T) { t.Fatalf("mkdir: %v", err) } - manifest := &backup.Manifest{ProxmoxType: "pbs"} - if err := ValidateCompatibility(manifest); err == nil { + if err := ValidateCompatibility(SystemTypePVE, SystemTypePBS); err == nil { t.Fatalf("expected incompatibility error") } } @@ -36,6 +33,38 @@ func TestDetectCurrentSystem_Unknown(t *testing.T) { } } +func TestParseSystemTypeString_AcceptsFullNames(t *testing.T) { + tests := []struct { + name string + input string + want SystemType + }{ + { + name: "pve full name with space", + input: "Proxmox VE", + want: SystemTypePVE, + }, + { + name: "pbs generic full name", + input: "Proxmox Backup", + want: SystemTypePBS, + }, + { + name: "pbs full server name", + input: "Proxmox Backup Server", + want: SystemTypePBS, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseSystemTypeString(tt.input); got != tt.want { + t.Fatalf("parseSystemTypeString(%q) = %v; want %v", tt.input, got, tt.want) + } + }) + } +} + func TestGetSystemInfoDetectsPVE(t *testing.T) { orig := compatFS defer func() { compatFS = orig }() diff --git a/internal/orchestrator/decrypt.go b/internal/orchestrator/decrypt.go index c995530..83b44f7 100644 --- a/internal/orchestrator/decrypt.go +++ b/internal/orchestrator/decrypt.go @@ -43,6 +43,7 @@ type decryptCandidate struct { RawArchivePath string RawMetadataPath string RawChecksumPath string + Integrity *stagedIntegrityExpectation DisplayBase string IsRclone bool } @@ -54,10 +55,11 @@ type stagedFiles struct { } type preparedBundle struct { - ArchivePath string - Manifest backup.Manifest - Checksum string - cleanup func() + ArchivePath string + Manifest backup.Manifest + Checksum string + SourceChecksum string + cleanup func() } func (p *preparedBundle) Cleanup() { diff --git a/internal/orchestrator/decrypt_integrity.go b/internal/orchestrator/decrypt_integrity.go new file mode 100644 index 0000000..aba5784 --- /dev/null +++ b/internal/orchestrator/decrypt_integrity.go @@ -0,0 +1,120 @@ +package orchestrator + +import ( + "context" + "fmt" + "strings" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/logging" +) + +type stagedIntegrityExpectation struct { + Checksum string + Source string +} + +func resolveIntegrityExpectationValues(checksumFromFile, checksumFromManifest string) (*stagedIntegrityExpectation, error) { + switch { + case checksumFromFile != "" && checksumFromManifest != "" && checksumFromFile != checksumFromManifest: + return nil, fmt.Errorf("checksum mismatch between checksum file and manifest") + case checksumFromFile != "" && checksumFromManifest != "": + return &stagedIntegrityExpectation{Checksum: checksumFromFile, Source: "checksum file and manifest"}, nil + case checksumFromFile != "": + return &stagedIntegrityExpectation{Checksum: checksumFromFile, Source: "checksum file"}, nil + case checksumFromManifest != "": + return &stagedIntegrityExpectation{Checksum: checksumFromManifest, Source: "manifest"}, nil + default: + return nil, fmt.Errorf("backup has no checksum verification available") + } +} + +func resolveStagedIntegrityExpectation(staged stagedFiles, manifest *backup.Manifest) (*stagedIntegrityExpectation, error) { + var ( + checksumFromFile string + checksumFromManifest string + ) + + if strings.TrimSpace(staged.ChecksumPath) != "" { + data, err := restoreFS.ReadFile(staged.ChecksumPath) + if err != nil { + return nil, fmt.Errorf("read checksum file: %w", err) + } + checksumFromFile, err = backup.ParseChecksumData(data) + if err != nil { + return nil, fmt.Errorf("parse checksum file %s: %w", staged.ChecksumPath, err) + } + } + + if manifest != nil && strings.TrimSpace(manifest.SHA256) != "" { + normalized, err := backup.NormalizeChecksum(manifest.SHA256) + if err != nil { + return nil, fmt.Errorf("parse manifest checksum: %w", err) + } + checksumFromManifest = normalized + } + + return resolveIntegrityExpectationValues(checksumFromFile, checksumFromManifest) +} + +func resolveCandidateIntegrityExpectation(staged stagedFiles, cand *decryptCandidate) (*stagedIntegrityExpectation, error) { + if cand != nil && cand.Integrity != nil && strings.TrimSpace(cand.Integrity.Checksum) != "" { + normalized, err := backup.NormalizeChecksum(cand.Integrity.Checksum) + if err != nil { + return nil, fmt.Errorf("parse candidate checksum: %w", err) + } + expectation := &stagedIntegrityExpectation{ + Checksum: normalized, + Source: strings.TrimSpace(cand.Integrity.Source), + } + if expectation.Source == "" { + expectation.Source = "candidate" + } + + if strings.TrimSpace(staged.ChecksumPath) != "" { + data, err := restoreFS.ReadFile(staged.ChecksumPath) + if err != nil { + return nil, fmt.Errorf("read checksum file: %w", err) + } + checksumFromFile, err := backup.ParseChecksumData(data) + if err != nil { + return nil, fmt.Errorf("parse checksum file %s: %w", staged.ChecksumPath, err) + } + if checksumFromFile != expectation.Checksum { + return nil, fmt.Errorf("checksum mismatch between checksum file and selected candidate") + } + } + + return expectation, nil + } + + var manifest *backup.Manifest + if cand != nil { + manifest = cand.Manifest + } + return resolveStagedIntegrityExpectation(staged, manifest) +} + +func verifyStagedArchiveIntegrity(ctx context.Context, logger *logging.Logger, staged stagedFiles, cand *decryptCandidate) (string, error) { + if staged.ArchivePath == "" { + return "", fmt.Errorf("staged archive path is empty") + } + if logger == nil { + logger = logging.GetDefaultLogger() + } + + expectation, err := resolveCandidateIntegrityExpectation(staged, cand) + if err != nil { + return "", err + } + + logger.Info("Verifying staged archive integrity using %s", expectation.Source) + ok, err := backup.VerifyChecksum(ctx, logger, staged.ArchivePath, expectation.Checksum) + if err != nil { + return "", fmt.Errorf("verify staged archive: %w", err) + } + if !ok { + return "", fmt.Errorf("staged archive checksum mismatch") + } + return expectation.Checksum, nil +} diff --git a/internal/orchestrator/decrypt_integrity_test.go b/internal/orchestrator/decrypt_integrity_test.go new file mode 100644 index 0000000..a77b685 --- /dev/null +++ b/internal/orchestrator/decrypt_integrity_test.go @@ -0,0 +1,160 @@ +package orchestrator + +import ( + "bufio" + "context" + "encoding/json" + "os" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +func TestResolveStagedIntegrityExpectation_RejectsConflictingSources(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + checksumPath := dir + "/backup.tar.sha256" + if err := os.WriteFile(checksumPath, checksumLineForBytes("backup.tar", []byte("archive")), 0o640); err != nil { + t.Fatalf("write checksum: %v", err) + } + + _, err := resolveStagedIntegrityExpectation(stagedFiles{ChecksumPath: checksumPath}, &backup.Manifest{ + SHA256: checksumHexForBytes([]byte("different")), + }) + if err == nil { + t.Fatalf("expected conflict error") + } + if !strings.Contains(err.Error(), "checksum mismatch") { + t.Fatalf("expected checksum mismatch error, got %v", err) + } +} + +func TestPreparePlainBundle_RejectsMissingChecksumVerification(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + archivePath := dir + "/backup.tar" + if err := os.WriteFile(archivePath, []byte("archive"), 0o640); err != nil { + t.Fatalf("write archive: %v", err) + } + metaPath := archivePath + ".metadata" + manifest := &backup.Manifest{ + ArchivePath: archivePath, + CreatedAt: time.Now(), + EncryptionMode: "none", + } + data, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("marshal metadata: %v", err) + } + if err := os.WriteFile(metaPath, data, 0o640); err != nil { + t.Fatalf("write metadata: %v", err) + } + + cand := &decryptCandidate{ + Manifest: manifest, + Source: sourceRaw, + RawArchivePath: archivePath, + RawMetadataPath: metaPath, + DisplayBase: "backup.tar", + } + + reader := bufio.NewReader(strings.NewReader("")) + logger := logging.New(types.LogLevelError, false) + + _, err = preparePlainBundle(context.Background(), reader, cand, "", logger) + if err == nil { + t.Fatalf("expected missing checksum verification error") + } + if !strings.Contains(err.Error(), "no checksum verification available") { + t.Fatalf("expected checksum availability error, got %v", err) + } +} + +func TestPreparePlainBundle_RejectsChecksumMismatch(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + archiveData := []byte("archive") + archivePath := dir + "/backup.tar" + if err := os.WriteFile(archivePath, archiveData, 0o640); err != nil { + t.Fatalf("write archive: %v", err) + } + metaPath := archivePath + ".metadata" + manifest := &backup.Manifest{ + ArchivePath: archivePath, + CreatedAt: time.Now(), + EncryptionMode: "none", + } + data, err := json.Marshal(manifest) + if err != nil { + t.Fatalf("marshal metadata: %v", err) + } + if err := os.WriteFile(metaPath, data, 0o640); err != nil { + t.Fatalf("write metadata: %v", err) + } + checksumPath := archivePath + ".sha256" + if err := os.WriteFile(checksumPath, checksumLineForBytes("backup.tar", []byte("tampered")), 0o640); err != nil { + t.Fatalf("write checksum: %v", err) + } + + cand := &decryptCandidate{ + Manifest: manifest, + Source: sourceRaw, + RawArchivePath: archivePath, + RawMetadataPath: metaPath, + RawChecksumPath: checksumPath, + DisplayBase: "backup.tar", + } + + reader := bufio.NewReader(strings.NewReader("")) + logger := logging.New(types.LogLevelError, false) + + _, err = preparePlainBundle(context.Background(), reader, cand, "", logger) + if err == nil { + t.Fatalf("expected checksum mismatch error") + } + if !strings.Contains(err.Error(), "checksum mismatch") { + t.Fatalf("expected checksum mismatch error, got %v", err) + } +} + +func TestVerifyStagedArchiveIntegrity_UsesCandidateIntegrityExpectation(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + archiveData := []byte("archive") + archivePath := dir + "/backup.tar" + if err := os.WriteFile(archivePath, archiveData, 0o640); err != nil { + t.Fatalf("write archive: %v", err) + } + + got, err := verifyStagedArchiveIntegrity(context.Background(), logging.New(types.LogLevelError, false), stagedFiles{ + ArchivePath: archivePath, + }, &decryptCandidate{ + Integrity: &stagedIntegrityExpectation{ + Checksum: strings.ToUpper(checksumHexForBytes(archiveData)), + Source: "checksum file", + }, + }) + if err != nil { + t.Fatalf("verifyStagedArchiveIntegrity() error = %v", err) + } + want := checksumHexForBytes(archiveData) + if got != want { + t.Fatalf("verifyStagedArchiveIntegrity() = %q; want %q", got, want) + } +} diff --git a/internal/orchestrator/decrypt_integrity_test_helpers_test.go b/internal/orchestrator/decrypt_integrity_test_helpers_test.go new file mode 100644 index 0000000..f8f022a --- /dev/null +++ b/internal/orchestrator/decrypt_integrity_test_helpers_test.go @@ -0,0 +1,15 @@ +package orchestrator + +import ( + "crypto/sha256" + "fmt" +) + +func checksumHexForBytes(data []byte) string { + sum := sha256.Sum256(data) + return fmt.Sprintf("%x", sum) +} + +func checksumLineForBytes(filename string, data []byte) []byte { + return []byte(fmt.Sprintf("%s %s", checksumHexForBytes(data), filename)) +} diff --git a/internal/orchestrator/decrypt_prepare_common.go b/internal/orchestrator/decrypt_prepare_common.go new file mode 100644 index 0000000..20b85ad --- /dev/null +++ b/internal/orchestrator/decrypt_prepare_common.go @@ -0,0 +1,186 @@ +package orchestrator + +import ( + "context" + "fmt" + "path/filepath" + "strings" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/logging" +) + +type archiveDecryptFunc func(ctx context.Context, encryptedPath, outputPath, displayName string) error + +func createUniquePreparedArchivePath(workDir, baseName string) (string, error) { + pattern := strings.TrimSpace(baseName) + if pattern == "" { + pattern = "archive" + } + tempFile, err := restoreFS.CreateTemp(workDir, pattern+".decrypted-*") + if err != nil { + return "", fmt.Errorf("create archive output path: %w", err) + } + path := tempFile.Name() + if err := tempFile.Close(); err != nil { + return "", fmt.Errorf("close archive output path: %w", err) + } + return path, nil +} + +func resolvePreparedArchivePath(workDir, stagedArchivePath, currentEncryption string) (string, error) { + archiveBase := filepath.Base(stagedArchivePath) + if archiveBase == "." || archiveBase == string(filepath.Separator) || strings.TrimSpace(archiveBase) == "" { + return "", fmt.Errorf("invalid staged archive path %s", stagedArchivePath) + } + + if currentEncryption == "age" { + if !strings.HasSuffix(archiveBase, ".age") { + return "", fmt.Errorf("encrypted archive %s is missing .age suffix", stagedArchivePath) + } + + plainArchiveName := strings.TrimSuffix(archiveBase, ".age") + if strings.TrimSpace(plainArchiveName) == "" { + return createUniquePreparedArchivePath(workDir, archiveBase) + } + + plainArchivePath := filepath.Join(workDir, plainArchiveName) + if plainArchivePath == stagedArchivePath { + return createUniquePreparedArchivePath(workDir, plainArchiveName) + } + return plainArchivePath, nil + } + + if strings.HasSuffix(archiveBase, ".age") { + mode := currentEncryption + if mode == "" { + mode = "plain" + } + return "", fmt.Errorf("archive %s has .age suffix but encryption mode is %s", stagedArchivePath, mode) + } + + return filepath.Join(workDir, archiveBase), nil +} + +func preparePlainBundleCommon(ctx context.Context, cand *decryptCandidate, version string, logger *logging.Logger, decryptArchive archiveDecryptFunc) (bundle *preparedBundle, err error) { + if cand == nil || cand.Manifest == nil { + return nil, fmt.Errorf("invalid backup candidate") + } + if logger == nil { + logger = logging.GetDefaultLogger() + } + + var rcloneCleanup func() + if cand.IsRclone && cand.Source == sourceBundle { + logger.Debug("Detected rclone backup, downloading...") + localPath, cleanup, err := downloadRcloneBackup(ctx, cand.BundlePath, logger) + if err != nil { + return nil, fmt.Errorf("failed to download rclone backup: %w", err) + } + rcloneCleanup = cleanup + cand.BundlePath = localPath + } + + tempRoot := filepath.Join("/tmp", "proxsave") + if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil { + if rcloneCleanup != nil { + rcloneCleanup() + } + return nil, fmt.Errorf("create temp root: %w", err) + } + + workDir, err := restoreFS.MkdirTemp(tempRoot, "proxmox-decrypt-*") + if err != nil { + if rcloneCleanup != nil { + rcloneCleanup() + } + return nil, fmt.Errorf("create temp dir: %w", err) + } + + cleanup := func() { + _ = restoreFS.RemoveAll(workDir) + if rcloneCleanup != nil { + rcloneCleanup() + } + } + + var staged stagedFiles + switch cand.Source { + case sourceBundle: + logger.Info("Extracting bundle %s", filepath.Base(cand.BundlePath)) + staged, err = extractBundleToWorkdirWithLogger(cand.BundlePath, workDir, logger) + case sourceRaw: + logger.Info("Staging raw artifacts for %s", filepath.Base(cand.RawArchivePath)) + staged, err = copyRawArtifactsToWorkdirWithLogger(ctx, cand, workDir, logger) + default: + err = fmt.Errorf("unsupported candidate source") + } + if err != nil { + cleanup() + return nil, err + } + + sourceChecksum, err := verifyStagedArchiveIntegrity(ctx, logger, staged, cand) + if err != nil { + cleanup() + return nil, err + } + + manifestCopy := *cand.Manifest + currentEncryption := strings.ToLower(strings.TrimSpace(manifestCopy.EncryptionMode)) + logger.Info("Preparing archive %s for decryption (mode: %s)", manifestCopy.ArchivePath, statusFromManifest(&manifestCopy)) + + plainArchivePath, err := resolvePreparedArchivePath(workDir, staged.ArchivePath, currentEncryption) + if err != nil { + cleanup() + return nil, err + } + + if currentEncryption == "age" { + if decryptArchive == nil { + cleanup() + return nil, fmt.Errorf("decrypt function not available") + } + displayName := cand.DisplayBase + if strings.TrimSpace(displayName) == "" { + displayName = filepath.Base(manifestCopy.ArchivePath) + } + if err := decryptArchive(ctx, staged.ArchivePath, plainArchivePath, displayName); err != nil { + cleanup() + return nil, err + } + } else if staged.ArchivePath != plainArchivePath { + if err := copyFile(restoreFS, staged.ArchivePath, plainArchivePath); err != nil { + cleanup() + return nil, fmt.Errorf("copy archive: %w", err) + } + } + + archiveInfo, err := restoreFS.Stat(plainArchivePath) + if err != nil { + cleanup() + return nil, fmt.Errorf("stat decrypted archive: %w", err) + } + + plainChecksum, err := backup.GenerateChecksum(ctx, logger, plainArchivePath) + if err != nil { + cleanup() + return nil, fmt.Errorf("generate checksum: %w", err) + } + + manifestCopy.ArchivePath = plainArchivePath + manifestCopy.ArchiveSize = archiveInfo.Size() + manifestCopy.SHA256 = plainChecksum + manifestCopy.EncryptionMode = "none" + if version != "" { + manifestCopy.ScriptVersion = version + } + + return &preparedBundle{ + ArchivePath: plainArchivePath, + Manifest: manifestCopy, + Checksum: plainChecksum, + SourceChecksum: sourceChecksum, + cleanup: cleanup, + }, nil +} diff --git a/internal/orchestrator/decrypt_test.go b/internal/orchestrator/decrypt_test.go index 84fcee1..4c795df 100644 --- a/internal/orchestrator/decrypt_test.go +++ b/internal/orchestrator/decrypt_test.go @@ -1575,14 +1575,15 @@ func TestPreparePlainBundle_SourceBundleSuccess(t *testing.T) { dir := t.TempDir() // Create bundle with required files + archiveData := []byte("archive data") manifestData, _ := json.Marshal(&backup.Manifest{ ArchivePath: filepath.Join(dir, "archive.tar.xz"), EncryptionMode: "none", }) bundlePath := createTestBundle(t, []bundleEntry{ - {name: "archive.tar.xz", data: []byte("archive data")}, + {name: "archive.tar.xz", data: archiveData}, {name: "backup.metadata", data: manifestData}, - {name: "backup.sha256", data: []byte("abc123 archive.tar.xz")}, + {name: "backup.sha256", data: checksumLineForBytes("archive.tar.xz", archiveData)}, }) cand := &decryptCandidate{ @@ -2819,7 +2820,7 @@ func TestPreparePlainBundle_CopyFileSamePath(t *testing.T) { t.Fatalf("write metadata: %v", err) } checksumPath := archivePath + ".sha256" - if err := os.WriteFile(checksumPath, []byte("abc123 backup.tar.xz"), 0o644); err != nil { + if err := os.WriteFile(checksumPath, checksumLineForBytes("backup.tar.xz", []byte("archive content")), 0o644); err != nil { t.Fatalf("write checksum: %v", err) } @@ -2890,7 +2891,7 @@ func TestPreparePlainBundle_AgeDecryptionWithRclone(t *testing.T) { tw.Write(manifestData) // Add checksum - checksumData := []byte("abc123 backup.tar.xz.age") + checksumData := checksumLineForBytes("backup.tar.xz.age", archiveContent) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600}) tw.Write(checksumData) @@ -2939,6 +2940,196 @@ esac } } +func TestPreparePlainBundleCommon_TrimmedAgeEncryptionTriggersDecrypt(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + workArchive := filepath.Join(dir, "backup.tar.xz.age") + if err := os.WriteFile(workArchive, []byte("ciphertext"), 0o600); err != nil { + t.Fatalf("write archive: %v", err) + } + manifestPath := filepath.Join(dir, "backup.metadata") + if err := os.WriteFile(manifestPath, []byte(`{"encryption_mode":" age "}`), 0o600); err != nil { + t.Fatalf("write manifest: %v", err) + } + checksumPath := filepath.Join(dir, "backup.sha256") + checksumLine := checksumLineForBytes(filepath.Base(workArchive), []byte("ciphertext")) + if err := os.WriteFile(checksumPath, []byte(checksumLine), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + cand := &decryptCandidate{ + Manifest: &backup.Manifest{ + ArchivePath: workArchive, + EncryptionMode: " age ", + }, + Source: sourceRaw, + RawArchivePath: workArchive, + RawMetadataPath: manifestPath, + RawChecksumPath: checksumPath, + DisplayBase: "backup.tar.xz.age", + } + + logger := logging.New(types.LogLevelError, false) + logger.SetOutput(io.Discard) + + decryptCalled := false + prepared, err := preparePlainBundleCommon(context.Background(), cand, "1.0.0", logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { + decryptCalled = true + if encryptedPath == outputPath { + t.Fatalf("decrypt callback received identical input/output path %q", encryptedPath) + } + if displayName != "backup.tar.xz.age" { + t.Fatalf("displayName=%q; want %q", displayName, "backup.tar.xz.age") + } + return os.WriteFile(outputPath, []byte("plaintext"), 0o600) + }) + if err != nil { + t.Fatalf("preparePlainBundleCommon error: %v", err) + } + defer prepared.Cleanup() + + if !decryptCalled { + t.Fatal("expected decrypt callback to be invoked for trimmed age encryption mode") + } + if prepared.Manifest.EncryptionMode != "none" { + t.Fatalf("prepared manifest EncryptionMode=%q; want %q", prepared.Manifest.EncryptionMode, "none") + } + data, err := os.ReadFile(prepared.ArchivePath) + if err != nil { + t.Fatalf("read prepared archive: %v", err) + } + if string(data) != "plaintext" { + t.Fatalf("prepared archive content=%q; want %q", string(data), "plaintext") + } +} + +func TestPreparePlainBundleCommon_AgeModeRequiresAgeSuffix(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + workArchive := filepath.Join(dir, "backup.tar.xz") + if err := os.WriteFile(workArchive, []byte("ciphertext"), 0o600); err != nil { + t.Fatalf("write archive: %v", err) + } + manifestPath := filepath.Join(dir, "backup.metadata") + if err := os.WriteFile(manifestPath, []byte(`{"encryption_mode":"age"}`), 0o600); err != nil { + t.Fatalf("write manifest: %v", err) + } + checksumPath := filepath.Join(dir, "backup.sha256") + checksumLine := checksumLineForBytes(filepath.Base(workArchive), []byte("ciphertext")) + if err := os.WriteFile(checksumPath, []byte(checksumLine), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + cand := &decryptCandidate{ + Manifest: &backup.Manifest{ + ArchivePath: workArchive, + EncryptionMode: "age", + }, + Source: sourceRaw, + RawArchivePath: workArchive, + RawMetadataPath: manifestPath, + RawChecksumPath: checksumPath, + DisplayBase: "backup.tar.xz", + } + + logger := logging.New(types.LogLevelError, false) + logger.SetOutput(io.Discard) + + decryptCalled := false + _, err := preparePlainBundleCommon(context.Background(), cand, "1.0.0", logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { + decryptCalled = true + return nil + }) + if err == nil { + t.Fatal("preparePlainBundleCommon error = nil; want missing .age suffix error") + } + if !strings.Contains(err.Error(), "missing .age suffix") { + t.Fatalf("preparePlainBundleCommon error = %v; want missing .age suffix error", err) + } + if decryptCalled { + t.Fatal("decrypt callback was called for archive without .age suffix") + } +} + +func TestPreparePlainBundleCommon_NonAgeRejectsAgeSuffix(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + dir := t.TempDir() + workArchive := filepath.Join(dir, "backup.tar.xz.age") + if err := os.WriteFile(workArchive, []byte("ciphertext"), 0o600); err != nil { + t.Fatalf("write archive: %v", err) + } + manifestPath := filepath.Join(dir, "backup.metadata") + if err := os.WriteFile(manifestPath, []byte(`{"encryption_mode":"none"}`), 0o600); err != nil { + t.Fatalf("write manifest: %v", err) + } + checksumPath := filepath.Join(dir, "backup.sha256") + checksumLine := checksumLineForBytes(filepath.Base(workArchive), []byte("ciphertext")) + if err := os.WriteFile(checksumPath, []byte(checksumLine), 0o600); err != nil { + t.Fatalf("write checksum: %v", err) + } + + cand := &decryptCandidate{ + Manifest: &backup.Manifest{ + ArchivePath: workArchive, + EncryptionMode: "none", + }, + Source: sourceRaw, + RawArchivePath: workArchive, + RawMetadataPath: manifestPath, + RawChecksumPath: checksumPath, + DisplayBase: "backup.tar.xz.age", + } + + logger := logging.New(types.LogLevelError, false) + logger.SetOutput(io.Discard) + + _, err := preparePlainBundleCommon(context.Background(), cand, "1.0.0", logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { + t.Fatal("decrypt callback should not be called for non-age archive handling") + return nil + }) + if err == nil { + t.Fatal("preparePlainBundleCommon error = nil; want .age suffix mismatch error") + } + if !strings.Contains(err.Error(), "has .age suffix but encryption mode is none") { + t.Fatalf("preparePlainBundleCommon error = %v; want .age suffix mismatch error", err) + } +} + +func TestResolvePreparedArchivePath_AgeFallbackUsesUniqueOutput(t *testing.T) { + origFS := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = origFS }) + + workDir := t.TempDir() + stagedArchivePath := filepath.Join(workDir, ".age") + + got, err := resolvePreparedArchivePath(workDir, stagedArchivePath, "age") + if err != nil { + t.Fatalf("resolvePreparedArchivePath error: %v", err) + } + if got == stagedArchivePath { + t.Fatalf("resolvePreparedArchivePath() = %q; want unique output path", got) + } + if got == workDir { + t.Fatalf("resolvePreparedArchivePath() = %q; want file path inside workdir", got) + } + if filepath.Dir(got) != workDir { + t.Fatalf("resolvePreparedArchivePath() dir = %q; want %q", filepath.Dir(got), workDir) + } + if !strings.HasPrefix(filepath.Base(got), ".age.decrypted-") { + t.Fatalf("resolvePreparedArchivePath() base = %q; want .age.decrypted-*", filepath.Base(got)) + } +} + // ===================================== // extractBundleToWorkdirWithLogger coverage tests // ===================================== @@ -3017,7 +3208,7 @@ func TestPreparePlainBundle_SourceBundleAdditional(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(manifestData)), Mode: 0o600}) tw.Write(manifestData) - checksumData := []byte("abc123 backup.tar.xz") + checksumData := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600}) tw.Write(checksumData) @@ -3393,7 +3584,7 @@ func TestExtractBundleToWorkdir_OpenFileErrorOnExtract(t *testing.T) { } // Add checksum - checksum := []byte("checksum backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) if err := tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}); err != nil { t.Fatalf("write checksum header: %v", err) } @@ -3608,7 +3799,7 @@ func TestSelectDecryptCandidate_RequireEncryptedAllPlain(t *testing.T) { tw.Write(metaJSON) // Add checksum - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -3720,7 +3911,7 @@ exit 1 tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -3771,7 +3962,7 @@ func TestPreparePlainBundle_StatErrorAfterExtract(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -3869,8 +4060,9 @@ func TestPreparePlainBundle_MkdirTempErrorWithRcloneCleanup(t *testing.T) { metaJSON, _ := json.Marshal(backup.Manifest{EncryptionMode: "none", ArchivePath: "backup.tar.xz"}) tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: 5, Mode: 0o640}) - tw.Write([]byte("hash\n")) + checksum := checksumLineForBytes("backup.tar.xz", archiveData) + tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) + tw.Write(checksum) tw.Close() bundleFile.Close() @@ -4066,7 +4258,7 @@ func TestPreparePlainBundle_CopyFileError(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -4175,7 +4367,7 @@ func TestPreparePlainBundle_StatErrorOnPlainArchive(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() @@ -4311,7 +4503,7 @@ func TestPreparePlainBundle_GenerateChecksumErrorPath(t *testing.T) { tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640}) tw.Write(metaJSON) - checksum := []byte("abc123 backup.tar.xz\n") + checksum := checksumLineForBytes("backup.tar.xz", archiveData) tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}) tw.Write(checksum) tw.Close() diff --git a/internal/orchestrator/decrypt_tui.go b/internal/orchestrator/decrypt_tui.go index 655f7f5..7602d8a 100644 --- a/internal/orchestrator/decrypt_tui.go +++ b/internal/orchestrator/decrypt_tui.go @@ -233,109 +233,9 @@ func promptNewPathInput(defaultPath, configPath, buildSig string) (string, error } func preparePlainBundleTUI(ctx context.Context, cand *decryptCandidate, version string, logger *logging.Logger, configPath, buildSig string) (*preparedBundle, error) { - if cand == nil || cand.Manifest == nil { - return nil, fmt.Errorf("invalid backup candidate") - } - - // If this is an rclone-backed bundle, download it first into the local temp area. - var rcloneCleanup func() - if cand.IsRclone && cand.Source == sourceBundle { - logger.Debug("Detected rclone backup, downloading for TUI workflow...") - localPath, cleanupFn, err := downloadRcloneBackup(ctx, cand.BundlePath, logger) - if err != nil { - return nil, fmt.Errorf("failed to download rclone backup: %w", err) - } - rcloneCleanup = cleanupFn - cand.BundlePath = localPath - } - - tempRoot := filepath.Join("/tmp", "proxsave") - if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil { - if rcloneCleanup != nil { - rcloneCleanup() - } - return nil, fmt.Errorf("create temp root: %w", err) - } - workDir, err := restoreFS.MkdirTemp(tempRoot, "proxmox-decrypt-*") - if err != nil { - if rcloneCleanup != nil { - rcloneCleanup() - } - return nil, fmt.Errorf("create temp dir: %w", err) - } - cleanup := func() { - _ = restoreFS.RemoveAll(workDir) - if rcloneCleanup != nil { - rcloneCleanup() - } - } - - var staged stagedFiles - switch cand.Source { - case sourceBundle: - logger.Debug("Extracting bundle %s", filepath.Base(cand.BundlePath)) - staged, err = extractBundleToWorkdirWithLogger(cand.BundlePath, workDir, logger) - case sourceRaw: - logger.Debug("Staging raw artifacts for %s", filepath.Base(cand.RawArchivePath)) - staged, err = copyRawArtifactsToWorkdirWithLogger(ctx, cand, workDir, logger) - default: - err = fmt.Errorf("unsupported candidate source") - } - if err != nil { - cleanup() - return nil, err - } - - manifestCopy := *cand.Manifest - currentEncryption := strings.ToLower(manifestCopy.EncryptionMode) - - logger.Debug("Preparing archive %s for decryption (mode: %s)", filepath.Base(manifestCopy.ArchivePath), statusFromManifest(&manifestCopy)) - - plainArchiveName := strings.TrimSuffix(filepath.Base(staged.ArchivePath), ".age") - plainArchivePath := filepath.Join(workDir, plainArchiveName) - - if currentEncryption == "age" { - displayName := cand.DisplayBase - if displayName == "" { - displayName = filepath.Base(manifestCopy.ArchivePath) - } - if err := decryptArchiveWithTUIPrompts(ctx, staged.ArchivePath, plainArchivePath, displayName, configPath, buildSig, logger); err != nil { - cleanup() - return nil, err - } - } else if staged.ArchivePath != plainArchivePath { - if err := copyFile(restoreFS, staged.ArchivePath, plainArchivePath); err != nil { - cleanup() - return nil, fmt.Errorf("copy archive: %w", err) - } - } - - archiveInfo, err := restoreFS.Stat(plainArchivePath) - if err != nil { - cleanup() - return nil, fmt.Errorf("stat decrypted archive: %w", err) - } - - checksum, err := backup.GenerateChecksum(ctx, logger, plainArchivePath) - if err != nil { - cleanup() - return nil, fmt.Errorf("generate checksum: %w", err) - } - - manifestCopy.ArchivePath = plainArchivePath - manifestCopy.ArchiveSize = archiveInfo.Size() - manifestCopy.SHA256 = checksum - manifestCopy.EncryptionMode = "none" - if version != "" { - manifestCopy.ScriptVersion = version - } - - return &preparedBundle{ - ArchivePath: plainArchivePath, - Manifest: manifestCopy, - Checksum: checksum, - cleanup: cleanup, - }, nil + return preparePlainBundleCommon(ctx, cand, version, logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { + return decryptArchiveWithTUIPrompts(ctx, encryptedPath, outputPath, displayName, configPath, buildSig, logger) + }) } func decryptArchiveWithTUIPrompts(ctx context.Context, encryptedPath, outputPath, displayName, configPath, buildSig string, logger *logging.Logger) error { diff --git a/internal/orchestrator/decrypt_tui_test.go b/internal/orchestrator/decrypt_tui_test.go index 6404b76..f08f9a0 100644 --- a/internal/orchestrator/decrypt_tui_test.go +++ b/internal/orchestrator/decrypt_tui_test.go @@ -204,7 +204,7 @@ func TestPreparePlainBundleTUICopiesRawArtifacts(t *testing.T) { if err := os.WriteFile(rawMetadata, []byte(`{"manifest":true}`), 0o640); err != nil { t.Fatalf("write metadata: %v", err) } - if err := os.WriteFile(rawChecksum, []byte("checksum backup.tar\n"), 0o640); err != nil { + if err := os.WriteFile(rawChecksum, checksumLineForBytes("backup.tar", []byte("payload-data")), 0o640); err != nil { t.Fatalf("write checksum: %v", err) } diff --git a/internal/orchestrator/decrypt_workflow_test.go b/internal/orchestrator/decrypt_workflow_test.go index 3f68fa4..e145b7d 100644 --- a/internal/orchestrator/decrypt_workflow_test.go +++ b/internal/orchestrator/decrypt_workflow_test.go @@ -18,7 +18,8 @@ import ( func writeRawBackup(t *testing.T, dir, name string) *backup.Manifest { t.Helper() archive := filepath.Join(dir, name) - if err := os.WriteFile(archive, []byte("data"), 0o640); err != nil { + archiveData := []byte("data") + if err := os.WriteFile(archive, archiveData, 0o640); err != nil { t.Fatalf("write archive: %v", err) } manifest := &backup.Manifest{ @@ -31,7 +32,7 @@ func writeRawBackup(t *testing.T, dir, name string) *backup.Manifest { if err := os.WriteFile(manifestPath, data, 0o640); err != nil { t.Fatalf("write metadata: %v", err) } - if err := os.WriteFile(archive+".sha256", []byte("checksum file"), 0o640); err != nil { + if err := os.WriteFile(archive+".sha256", checksumLineForBytes(filepath.Base(archive), archiveData), 0o640); err != nil { t.Fatalf("write checksum: %v", err) } return manifest @@ -70,21 +71,23 @@ func TestRunDecryptWorkflow_BundleNotFound(t *testing.T) { func TestPreparePlainBundle_AllowsMissingRawChecksumSidecar(t *testing.T) { dir := t.TempDir() archive := filepath.Join(dir, "bad.bundle.tar") - if err := os.WriteFile(archive, []byte("data"), 0o640); err != nil { + archiveData := []byte("data") + if err := os.WriteFile(archive, archiveData, 0o640); err != nil { t.Fatalf("write archive: %v", err) } manifest := &backup.Manifest{ ArchivePath: archive, CreatedAt: time.Now(), Hostname: "host", + SHA256: checksumHexForBytes(archiveData), } metaPath := archive + ".metadata" data, _ := json.Marshal(manifest) if err := os.WriteFile(metaPath, data, 0o640); err != nil { t.Fatalf("write metadata: %v", err) } - // No checksum file: ProxSave should still allow restore/decrypt to proceed - // (it re-computes checksums on the staged/plain archive anyway). + // No checksum sidecar: restore/decrypt should still proceed when the manifest + // already carries the expected archive checksum. cand := &decryptCandidate{ Manifest: manifest, @@ -100,7 +103,7 @@ func TestPreparePlainBundle_AllowsMissingRawChecksumSidecar(t *testing.T) { t.Cleanup(func() { restoreFS = osFS{} }) if _, err := preparePlainBundle(context.Background(), reader, cand, "", logging.New(types.LogLevelInfo, false)); err != nil { - t.Fatalf("expected missing checksum to be tolerated, got error: %v", err) + t.Fatalf("expected manifest checksum to cover missing sidecar, got error: %v", err) } } diff --git a/internal/orchestrator/decrypt_workflow_ui.go b/internal/orchestrator/decrypt_workflow_ui.go index 2ae37d7..7de3f69 100644 --- a/internal/orchestrator/decrypt_workflow_ui.go +++ b/internal/orchestrator/decrypt_workflow_ui.go @@ -179,113 +179,9 @@ func preparePlainBundleWithUI(ctx context.Context, cand *decryptCandidate, versi }) (bundle *preparedBundle, err error) { done := logging.DebugStart(logger, "prepare plain bundle (ui)", "source=%v rclone=%v", cand.Source, cand.IsRclone) defer func() { done(err) }() - - if cand == nil || cand.Manifest == nil { - return nil, fmt.Errorf("invalid backup candidate") - } - - var rcloneCleanup func() - if cand.IsRclone && cand.Source == sourceBundle { - logger.Debug("Detected rclone backup, downloading...") - localPath, cleanup, err := downloadRcloneBackup(ctx, cand.BundlePath, logger) - if err != nil { - return nil, fmt.Errorf("failed to download rclone backup: %w", err) - } - rcloneCleanup = cleanup - cand.BundlePath = localPath - } - - tempRoot := filepath.Join("/tmp", "proxsave") - if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil { - if rcloneCleanup != nil { - rcloneCleanup() - } - return nil, fmt.Errorf("create temp root: %w", err) - } - - workDir, err := restoreFS.MkdirTemp(tempRoot, "proxmox-decrypt-*") - if err != nil { - if rcloneCleanup != nil { - rcloneCleanup() - } - return nil, fmt.Errorf("create temp dir: %w", err) - } - - cleanup := func() { - _ = restoreFS.RemoveAll(workDir) - if rcloneCleanup != nil { - rcloneCleanup() - } - } - - var staged stagedFiles - switch cand.Source { - case sourceBundle: - logger.Info("Extracting bundle %s", filepath.Base(cand.BundlePath)) - staged, err = extractBundleToWorkdirWithLogger(cand.BundlePath, workDir, logger) - case sourceRaw: - logger.Info("Staging raw artifacts for %s", filepath.Base(cand.RawArchivePath)) - staged, err = copyRawArtifactsToWorkdirWithLogger(ctx, cand, workDir, logger) - default: - err = fmt.Errorf("unsupported candidate source") - } - if err != nil { - cleanup() - return nil, err - } - - manifestCopy := *cand.Manifest - currentEncryption := strings.ToLower(manifestCopy.EncryptionMode) - logger.Info("Preparing archive %s for decryption (mode: %s)", manifestCopy.ArchivePath, statusFromManifest(&manifestCopy)) - - plainArchiveName := strings.TrimSuffix(filepath.Base(staged.ArchivePath), ".age") - plainArchivePath := filepath.Join(workDir, plainArchiveName) - - if currentEncryption == "age" { - displayName := cand.DisplayBase - if strings.TrimSpace(displayName) == "" { - displayName = filepath.Base(manifestCopy.ArchivePath) - } - if err := decryptArchiveWithSecretPrompt(ctx, staged.ArchivePath, plainArchivePath, displayName, ui.PromptDecryptSecret); err != nil { - cleanup() - return nil, err - } - } else { - if staged.ArchivePath != plainArchivePath { - if err := copyFile(restoreFS, staged.ArchivePath, plainArchivePath); err != nil { - cleanup() - return nil, fmt.Errorf("copy archive: %w", err) - } - } - } - - archiveInfo, err := restoreFS.Stat(plainArchivePath) - if err != nil { - cleanup() - return nil, fmt.Errorf("stat decrypted archive: %w", err) - } - - checksum, err := backup.GenerateChecksum(ctx, logger, plainArchivePath) - if err != nil { - cleanup() - return nil, fmt.Errorf("generate checksum: %w", err) - } - - manifestCopy.ArchivePath = plainArchivePath - manifestCopy.ArchiveSize = archiveInfo.Size() - manifestCopy.SHA256 = checksum - manifestCopy.EncryptionMode = "none" - if version != "" { - manifestCopy.ScriptVersion = version - } - - bundle = &preparedBundle{ - ArchivePath: plainArchivePath, - Manifest: manifestCopy, - Checksum: checksum, - cleanup: cleanup, - } - return bundle, nil + return preparePlainBundleCommon(ctx, cand, version, logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error { + return decryptArchiveWithSecretPrompt(ctx, encryptedPath, outputPath, displayName, ui.PromptDecryptSecret) + }) } func runDecryptWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui DecryptWorkflowUI) (err error) { diff --git a/internal/orchestrator/deps.go b/internal/orchestrator/deps.go index 648e20b..6530c30 100644 --- a/internal/orchestrator/deps.go +++ b/internal/orchestrator/deps.go @@ -16,6 +16,7 @@ import ( // FS abstracts filesystem operations to simplify testing. type FS interface { Stat(path string) (os.FileInfo, error) + Lstat(path string) (os.FileInfo, error) ReadFile(path string) ([]byte, error) Open(path string) (*os.File, error) OpenFile(path string, flag int, perm fs.FileMode) (*os.File, error) @@ -69,9 +70,10 @@ type Deps struct { type osFS struct{} -func (osFS) Stat(path string) (os.FileInfo, error) { return os.Stat(path) } -func (osFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(path) } -func (osFS) Open(path string) (*os.File, error) { return os.Open(path) } +func (osFS) Stat(path string) (os.FileInfo, error) { return os.Stat(path) } +func (osFS) Lstat(path string) (os.FileInfo, error) { return os.Lstat(path) } +func (osFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(path) } +func (osFS) Open(path string) (*os.File, error) { return os.Open(path) } func (osFS) OpenFile(path string, flag int, perm fs.FileMode) (*os.File, error) { return os.OpenFile(path, flag, perm) } diff --git a/internal/orchestrator/deps_test.go b/internal/orchestrator/deps_test.go index ba4d404..986051f 100644 --- a/internal/orchestrator/deps_test.go +++ b/internal/orchestrator/deps_test.go @@ -60,6 +60,16 @@ func (f *FakeFS) Stat(path string) (os.FileInfo, error) { return os.Stat(f.onDisk(path)) } +func (f *FakeFS) Lstat(path string) (os.FileInfo, error) { + if err, ok := f.StatErr[filepath.Clean(path)]; ok { + return nil, err + } + if err, ok := f.StatErrors[filepath.Clean(path)]; ok { + return nil, err + } + return os.Lstat(f.onDisk(path)) +} + func (f *FakeFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(f.onDisk(path)) } diff --git a/internal/orchestrator/network_apply_additional_test.go b/internal/orchestrator/network_apply_additional_test.go new file mode 100644 index 0000000..b3e2780 --- /dev/null +++ b/internal/orchestrator/network_apply_additional_test.go @@ -0,0 +1,1791 @@ +package orchestrator + +import ( + "archive/tar" + "bufio" + "context" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +type writeFailFS struct { + FS + failPath string + err error +} + +func (f writeFailFS) WriteFile(path string, data []byte, perm fs.FileMode) error { + if filepath.Clean(path) == filepath.Clean(f.failPath) { + return f.err + } + return f.FS.WriteFile(path, data, perm) +} + +func newDiscardLogger() *logging.Logger { + logger := logging.New(types.LogLevelDebug, false) + logger.SetOutput(io.Discard) + return logger +} + +func TestNetworkApplyNotCommittedError_ErrorAndUnwrap(t *testing.T) { + var err *NetworkApplyNotCommittedError + if got := err.Error(); got != ErrNetworkApplyNotCommitted.Error() { + t.Fatalf("Error()=%q want %q", got, ErrNetworkApplyNotCommitted.Error()) + } + if got := err.Unwrap(); !errors.Is(got, ErrNetworkApplyNotCommitted) { + t.Fatalf("Unwrap()=%v want %v", got, ErrNetworkApplyNotCommitted) + } + + err = &NetworkApplyNotCommittedError{RollbackLog: "/tmp/log"} + if got := err.Error(); got != ErrNetworkApplyNotCommitted.Error() { + t.Fatalf("Error()=%q want %q", got, ErrNetworkApplyNotCommitted.Error()) + } + if got := err.Unwrap(); got != ErrNetworkApplyNotCommitted { + t.Fatalf("Unwrap()=%v want %v", got, ErrNetworkApplyNotCommitted) + } +} + +func TestNetworkRollbackHandleRemaining(t *testing.T) { + var h *networkRollbackHandle + if got := h.remaining(time.Now()); got != 0 { + t.Fatalf("nil remaining=%s want 0", got) + } + + h = &networkRollbackHandle{ + armedAt: time.Date(2026, 2, 1, 1, 2, 3, 0, time.UTC), + timeout: 10 * time.Second, + } + if got := h.remaining(h.armedAt); got != 10*time.Second { + t.Fatalf("remaining=%s want %s", got, 10*time.Second) + } + if got := h.remaining(h.armedAt.Add(4 * time.Second)); got != 6*time.Second { + t.Fatalf("remaining=%s want %s", got, 6*time.Second) + } + if got := h.remaining(h.armedAt.Add(20 * time.Second)); got != 0 { + t.Fatalf("remaining=%s want 0", got) + } +} + +func TestShouldAttemptNetworkApply(t *testing.T) { + if shouldAttemptNetworkApply(nil) { + t.Fatalf("expected false for nil plan") + } + if shouldAttemptNetworkApply(&RestorePlan{NormalCategories: []Category{{ID: "storage_pve"}}}) { + t.Fatalf("expected false when network category not present") + } + if !shouldAttemptNetworkApply(&RestorePlan{NormalCategories: []Category{{ID: "network"}}}) { + t.Fatalf("expected true when network category present") + } +} + +func TestExtractIPFromSnapshot_EmptyArgsReturnUnknown(t *testing.T) { + if got := extractIPFromSnapshot("", "vmbr0"); got != "unknown" { + t.Fatalf("got %q want unknown", got) + } + if got := extractIPFromSnapshot("/snap.txt", ""); got != "unknown" { + t.Fatalf("got %q want unknown", got) + } +} + +func TestExtractIPFromSnapshot_IgnoresLinesOutsideAddrSection(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip -br addr", + "lo UNKNOWN 127.0.0.1/8", + "$ ip route show", + "vmbr0 UP 192.0.2.10/24", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "unknown" { + t.Fatalf("got %q want unknown", got) + } +} + +func TestExtractIPFromSnapshot_SkipsInvalidTokensButReturnsFirstValid(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip -br addr", + "vmbr0 UP not-an-ip 2a01:db8::2/64", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "2a01:db8::2/64" { + t.Fatalf("got %q want %q", got, "2a01:db8::2/64") + } +} + +func TestExtractIPFromSnapshot_SkipsErrorLinesAndParsesNext(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip -br addr", + "ERROR: ip failed", + "vmbr0 UP 192.0.2.55/24", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "192.0.2.55/24" { + t.Fatalf("got %q want %q", got, "192.0.2.55/24") + } +} + +func TestExtractIPFromSnapshot_ReturnsUnknownWhenNoValidAddressTokens(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip -br addr", + "vmbr0 UP not-an-ip also-bad", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "unknown" { + t.Fatalf("got %q want unknown", got) + } +} + +func TestExtractIPFromSnapshot_IgnoresCommandsBeforeAddrSection(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + snapshot := strings.Join([]string{ + "$ ip route show", + "default via 192.0.2.1 dev vmbr0", + "$ ip -br addr", + "vmbr0 UP 192.0.2.99/24", + "", + }, "\n") + if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil { + t.Fatalf("write snapshot: %v", err) + } + if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "192.0.2.99/24" { + t.Fatalf("got %q want %q", got, "192.0.2.99/24") + } +} + +func TestBuildNetworkApplyNotCommittedError_HandleNilIfaceEmpty(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + restoreFS = NewFakeFS() + restoreCmd = &FakeCommandRunner{} + + logger := newDiscardLogger() + + got := buildNetworkApplyNotCommittedError(context.Background(), logger, "", nil) + if got.RollbackArmed { + t.Fatalf("RollbackArmed=true want false") + } + if got.RollbackLog != "" || got.RollbackMarker != "" { + t.Fatalf("expected empty rollback paths, got log=%q marker=%q", got.RollbackLog, got.RollbackMarker) + } + if got.RestoredIP != "unknown" || got.OriginalIP != "unknown" { + t.Fatalf("restored=%q original=%q want unknown/unknown", got.RestoredIP, got.OriginalIP) + } + if !got.RollbackDeadline.IsZero() { + t.Fatalf("RollbackDeadline=%s want zero", got.RollbackDeadline) + } +} + +func TestBuildNetworkApplyNotCommittedError_ArmedWithMarkerAndSnapshots(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + fakeCmd := &FakeCommandRunner{ + Outputs: map[string][]byte{ + "ip -o addr show dev vmbr0 scope global": []byte(strings.Join([]string{ + "2: vmbr0 inet 192.0.2.10/24 brd 192.0.2.255 scope global vmbr0", + "2: vmbr0 inet6 2001:db8::1/64 scope global", + }, "\n")), + "ip route show default": []byte("default via 192.0.2.1 dev vmbr0\n"), + }, + } + restoreCmd = fakeCmd + + logger := newDiscardLogger() + + handle := &networkRollbackHandle{ + workDir: "/work", + markerPath: "/work/marker", + logPath: "/work/log", + armedAt: time.Date(2026, 2, 1, 12, 0, 0, 0, time.UTC), + timeout: 3 * time.Minute, + } + + if err := fakeFS.WriteFile(handle.markerPath, []byte("pending\n"), 0o600); err != nil { + t.Fatalf("write marker: %v", err) + } + beforeSnapshot := strings.Join([]string{ + "$ ip -br addr", + "vmbr0 UP 10.0.0.2/24", + "", + }, "\n") + if err := fakeFS.WriteFile("/work/before.txt", []byte(beforeSnapshot), 0o600); err != nil { + t.Fatalf("write before snapshot: %v", err) + } + + got := buildNetworkApplyNotCommittedError(context.Background(), logger, "vmbr0", handle) + if !got.RollbackArmed { + t.Fatalf("RollbackArmed=false want true") + } + if got.RollbackLog != "/work/log" || got.RollbackMarker != "/work/marker" { + t.Fatalf("paths log=%q marker=%q want /work/log /work/marker", got.RollbackLog, got.RollbackMarker) + } + if got.RestoredIP != "192.0.2.10/24, 2001:db8::1/64" { + t.Fatalf("RestoredIP=%q", got.RestoredIP) + } + if got.OriginalIP != "10.0.0.2/24" { + t.Fatalf("OriginalIP=%q want %q", got.OriginalIP, "10.0.0.2/24") + } + wantDeadline := handle.armedAt.Add(handle.timeout) + if !got.RollbackDeadline.Equal(wantDeadline) { + t.Fatalf("RollbackDeadline=%s want %s", got.RollbackDeadline, wantDeadline) + } +} + +func TestBuildNetworkApplyNotCommittedError_MarkerMissingAndIPQueryFails(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "ip -o addr show dev vmbr0 scope global": errors.New("boom"), + }, + } + + logger := newDiscardLogger() + handle := &networkRollbackHandle{ + workDir: "/work", + markerPath: "/work/missing", + armedAt: time.Date(2026, 2, 1, 12, 0, 0, 0, time.UTC), + timeout: 1 * time.Minute, + } + + got := buildNetworkApplyNotCommittedError(context.Background(), logger, "vmbr0", handle) + if got.RollbackArmed { + t.Fatalf("RollbackArmed=true want false when marker missing") + } + if got.RestoredIP != "unknown" { + t.Fatalf("RestoredIP=%q want unknown", got.RestoredIP) + } + if got.OriginalIP != "unknown" { + t.Fatalf("OriginalIP=%q want unknown", got.OriginalIP) + } +} + +func TestRollbackAlreadyRunning(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + logger := newDiscardLogger() + pathDir := t.TempDir() + + t.Run("skip when handle nil", func(t *testing.T) { + t.Setenv("PATH", pathDir) + restoreCmd = &FakeCommandRunner{} + if rollbackAlreadyRunning(context.Background(), logger, nil) { + t.Fatalf("expected false for nil handle") + } + }) + + t.Run("skip when unit name empty", func(t *testing.T) { + t.Setenv("PATH", pathDir) + restoreCmd = &FakeCommandRunner{} + if rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: ""}) { + t.Fatalf("expected false when unitName empty") + } + }) + + t.Run("skip when systemctl missing", func(t *testing.T) { + t.Setenv("PATH", pathDir) + restoreCmd = &FakeCommandRunner{} + if rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: "x"}) { + t.Fatalf("expected false when systemctl not available") + } + }) + + t.Run("running when active or activating", func(t *testing.T) { + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + + for _, tc := range []struct { + state string + want bool + }{ + {state: "active\n", want: true}, + {state: "activating\n", want: true}, + {state: "inactive\n", want: false}, + } { + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "systemctl is-active unit.service": []byte(tc.state), + }, + } + if got := rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: "unit"}); got != tc.want { + t.Fatalf("state=%q got=%v want=%v", tc.state, got, tc.want) + } + } + }) + + t.Run("not running when systemctl errors", func(t *testing.T) { + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "systemctl is-active unit.service": errors.New("boom"), + }, + } + if rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: "unit"}) { + t.Fatalf("expected false when systemctl is-active errors") + } + }) +} + +func TestArmNetworkRollback_ValidationErrors(t *testing.T) { + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "", 10*time.Second, ""); err == nil { + t.Fatalf("expected error for empty backup path") + } + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 0, ""); err == nil { + t.Fatalf("expected error for invalid timeout") + } +} + +func TestArmNetworkRollback_CreateRollbackDirFailureReturnsError(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.MkdirAllErr = errors.New("boom") + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 30*time.Second, "/secure"); err == nil || !strings.Contains(err.Error(), "create rollback directory") { + t.Fatalf("err=%v want create rollback directory error", err) + } +} + +func TestArmNetworkRollback_MarkerWriteFailureReturnsError(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + fakeFS.WriteErr = errors.New("boom") + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 30*time.Second, "/secure"); err == nil || !strings.Contains(err.Error(), "write rollback marker") { + t.Fatalf("err=%v want write rollback marker error", err) + } +} + +func TestArmNetworkRollback_ScriptWriteFailureReturnsError(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + scriptPath := "/secure/network_rollback_20260201_123456.sh" + restoreFS = writeFailFS{FS: fakeFS, failPath: scriptPath, err: errors.New("boom")} + + if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 30*time.Second, "/secure"); err == nil || !strings.Contains(err.Error(), "write rollback script") { + t.Fatalf("err=%v want write rollback script error", err) + } +} + +func TestArmNetworkRollback_UsesSystemdRunWhenAvailable(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemd-run") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{ + Outputs: map[string][]byte{}, + } + restoreCmd = fakeCmd + + logger := newDiscardLogger() + handle, err := armNetworkRollback(context.Background(), logger, "/backup.tar", 30*time.Second, "") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil || handle.unitName == "" { + t.Fatalf("expected handle with unitName set") + } + if !strings.Contains(handle.unitName, "20260201_123456") { + t.Fatalf("unitName=%q want timestamp", handle.unitName) + } + + data, err := fakeFS.ReadFile(handle.scriptPath) + if err != nil { + t.Fatalf("read script: %v", err) + } + if !strings.Contains(string(data), "Restart networking after rollback") { + t.Fatalf("expected restartNetworking block in script") + } + + foundSystemdRun := false + for _, call := range fakeCmd.CallsList() { + if strings.HasPrefix(call, "systemd-run --unit=") { + foundSystemdRun = true + } + if strings.HasPrefix(call, "sh -c ") { + t.Fatalf("unexpected fallback sh -c call: %s", call) + } + } + if !foundSystemdRun { + t.Fatalf("expected systemd-run to be called; calls=%v", fakeCmd.CallsList()) + } +} + +func TestArmNetworkRollback_CustomWorkDirAndSystemdRunOutput(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemd-run") + t.Setenv("PATH", pathDir) + + expectedSystemdRun := "systemd-run --unit=proxsave-network-rollback-20260201_123456 --on-active=2s /bin/sh /secure/network_rollback_20260201_123456.sh" + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + expectedSystemdRun: []byte("unit started\n"), + }, + } + + handle, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 2*time.Second, "/secure") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil || handle.workDir != "/secure" { + t.Fatalf("handle=%#v want workDir=/secure", handle) + } +} + +func TestArmNetworkRollback_SystemdRunFailureFallsBackToNohup(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemd-run") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{}, + } + restoreCmd = fakeCmd + + logger := newDiscardLogger() + expectedSystemdRun := "systemd-run --unit=proxsave-network-rollback-20260201_123456 --on-active=30s /bin/sh /tmp/proxsave/network_rollback_20260201_123456.sh" + fakeCmd.Errors[expectedSystemdRun] = errors.New("boom") + + handle, err := armNetworkRollback(context.Background(), logger, "/backup.tar", 30*time.Second, "") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + if handle.unitName != "" { + t.Fatalf("unitName=%q want empty after systemd-run failure", handle.unitName) + } + + foundSystemdRun := false + foundFallback := false + for _, call := range fakeCmd.CallsList() { + if strings.HasPrefix(call, "systemd-run ") { + foundSystemdRun = true + } + if strings.HasPrefix(call, "sh -c nohup sh -c 'sleep ") { + foundFallback = true + } + } + if !foundSystemdRun || !foundFallback { + t.Fatalf("expected both systemd-run and fallback; calls=%v", fakeCmd.CallsList()) + } +} + +func TestArmNetworkRollback_WithoutSystemdRunUsesNohup(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + logger := newDiscardLogger() + handle, err := armNetworkRollback(context.Background(), logger, "/backup.tar", 1*time.Second, "") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + if handle.unitName != "" { + t.Fatalf("unitName=%q want empty in nohup mode", handle.unitName) + } + + foundFallback := false + for _, call := range fakeCmd.CallsList() { + if strings.HasPrefix(call, "sh -c nohup sh -c 'sleep ") { + foundFallback = true + } + } + if !foundFallback { + t.Fatalf("expected fallback sh -c call; calls=%v", fakeCmd.CallsList()) + } +} + +func TestArmNetworkRollback_SubSecondTimeoutArmsAtLeastOneSecond(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 500*time.Millisecond, "") + if err != nil { + t.Fatalf("armNetworkRollback error: %v", err) + } + if handle == nil { + t.Fatalf("expected handle") + } + + foundSleep1 := false + for _, call := range fakeCmd.CallsList() { + if strings.Contains(call, "sleep 1;") { + foundSleep1 = true + } + } + if !foundSleep1 { + t.Fatalf("expected sleep 1 in nohup command; calls=%v", fakeCmd.CallsList()) + } +} + +func TestArmNetworkRollback_FallbackCommandFailureReturnsError(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "sh -c nohup sh -c 'sleep 1; /bin/sh /tmp/proxsave/network_rollback_20260201_123456.sh' >/dev/null 2>&1 &": errors.New("boom"), + }, + } + + _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 1*time.Second, "") + if err == nil || !strings.Contains(err.Error(), "failed to arm rollback timer") { + t.Fatalf("err=%v want failed to arm rollback timer", err) + } +} + +func TestDisarmNetworkRollback_RemovesMarkerAndStopsTimer(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + handle := &networkRollbackHandle{ + markerPath: "/tmp/marker", + unitName: "proxsave-network-rollback-test", + } + if err := fakeFS.WriteFile(handle.markerPath, []byte("pending\n"), 0o600); err != nil { + t.Fatalf("write marker: %v", err) + } + + disarmNetworkRollback(context.Background(), newDiscardLogger(), handle) + if _, err := fakeFS.Stat(handle.markerPath); !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected marker to be removed, stat err=%v", err) + } + + calls := strings.Join(fakeCmd.CallsList(), "\n") + if !strings.Contains(calls, "systemctl stop "+handle.unitName+".timer") { + t.Fatalf("expected systemctl stop call; calls=%v", fakeCmd.CallsList()) + } + if !strings.Contains(calls, "systemctl reset-failed "+handle.unitName+".service "+handle.unitName+".timer") { + t.Fatalf("expected systemctl reset-failed call; calls=%v", fakeCmd.CallsList()) + } +} + +func TestDisarmNetworkRollback_NilHandleNoop(t *testing.T) { + disarmNetworkRollback(context.Background(), newDiscardLogger(), nil) +} + +func TestDisarmNetworkRollback_MarkerRemoveFailureAndStopError(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{ + Errors: map[string]error{ + "systemctl stop unit.timer": errors.New("boom"), + }, + } + restoreCmd = fakeCmd + + handle := &networkRollbackHandle{ + markerPath: "/tmp/marker", + unitName: "unit", + } + if err := fakeFS.WriteFile(handle.markerPath, []byte("pending\n"), 0o600); err != nil { + t.Fatalf("write marker: %v", err) + } + + restoreFS = removeFailFS{FS: fakeFS, failPath: handle.markerPath, err: errors.New("remove boom")} + disarmNetworkRollback(context.Background(), newDiscardLogger(), handle) + + calls := strings.Join(fakeCmd.CallsList(), "\n") + if !strings.Contains(calls, "systemctl stop unit.timer") { + t.Fatalf("expected systemctl stop call; calls=%v", fakeCmd.CallsList()) + } + if !strings.Contains(calls, "systemctl reset-failed unit.service unit.timer") { + t.Fatalf("expected systemctl reset-failed call; calls=%v", fakeCmd.CallsList()) + } +} + +func TestMaybeRepairNICNamesCLI_SkippedWhenArchiveMissing(t *testing.T) { + origTime := restoreTime + t.Cleanup(func() { restoreTime = origTime }) + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "") + if result == nil { + t.Fatalf("expected result") + } + if !strings.Contains(result.SkippedReason, "backup archive not available") { + t.Fatalf("SkippedReason=%q", result.SkippedReason) + } + if !result.AppliedAt.Equal(nowRestore()) { + t.Fatalf("AppliedAt=%s want %s", result.AppliedAt, nowRestore()) + } +} + +func TestMaybeRepairNICNamesCLI_ReturnsNilOnPlanError(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.WriteFile("/backup.zip", []byte("not a tar"), 0o600); err != nil { + t.Fatalf("write archive: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("")) + if got := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.zip"); got != nil { + t.Fatalf("expected nil on plan error, got %#v", got) + } +} + +func TestMaybeRepairNICNamesCLI_AppliesMappingWithoutConflicts(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result, got %#v", result) + } + + updated, err := fakeFS.ReadFile("/etc/network/interfaces") + if err != nil { + t.Fatalf("read updated interfaces: %v", err) + } + if string(updated) != "auto enp3s0\niface enp3s0 inet manual\n" { + t.Fatalf("updated=%q", string(updated)) + } +} + +func TestMaybeRepairNICNamesCLI_SkipsWhenNamingOverridesDetectedAndUserConfirms(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.MkdirAll("/etc/udev/rules.d", 0o755); err != nil { + t.Fatalf("mkdir udev: %v", err) + } + rule := `SUBSYSTEM=="net", ATTR{address}=="aa:bb:cc:dd:ee:ff", NAME="eth0"` + if err := fakeFS.WriteFile("/etc/udev/rules.d/70-persistent-net.rules", []byte(rule+"\n"), 0o644); err != nil { + t.Fatalf("write rule: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("y\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil { + t.Fatalf("expected result") + } + if !strings.Contains(result.SkippedReason, "persistent NIC naming rules") { + t.Fatalf("SkippedReason=%q", result.SkippedReason) + } +} + +func TestMaybeRepairNICNamesCLI_OverridesDetectedUserChoosesProceed(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.MkdirAll("/etc/udev/rules.d", 0o755); err != nil { + t.Fatalf("mkdir udev: %v", err) + } + rule := `SUBSYSTEM=="net", ATTR{address}=="aa:bb:cc:dd:ee:ff", NAME="eth0"` + if err := fakeFS.WriteFile("/etc/udev/rules.d/70-persistent-net.rules", []byte(rule+"\n"), 0o644); err != nil { + t.Fatalf("write rule: %v", err) + } + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("n\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_OverridesDetectionErrorStillApplies(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + restoreFS = readDirFailFS{FS: fakeFS, failPath: "/etc/udev/rules.d", err: errors.New("boom")} + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_ConflictsPromptAndSkip(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.MkdirAll(filepath.Join(sysDir, "eno1"), 0o755); err != nil { + t.Fatalf("mkdir eno1: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write enp3s0 address: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "eno1", "address"), []byte("11:22:33:44:55:66\n"), 0o644); err != nil { + t.Fatalf("write eno1 address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + reader := bufio.NewReader(strings.NewReader("n\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil { + t.Fatalf("expected result") + } + if result.Applied() { + t.Fatalf("expected no changes when conflicts skipped, got %#v", result) + } + if !strings.Contains(result.SkippedReason, "conflicting NIC mappings") { + t.Fatalf("SkippedReason=%q", result.SkippedReason) + } +} + +func TestMaybeRepairNICNamesCLI_ConflictsPromptAndApply(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.MkdirAll(filepath.Join(sysDir, "eno1"), 0o755); err != nil { + t.Fatalf("mkdir eno1: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write enp3s0 address: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "eno1", "address"), []byte("11:22:33:44:55:66\n"), 0o644); err != nil { + t.Fatalf("write eno1 address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("y\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_OverridesPromptErrorContinues(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.MkdirAll("/etc/udev/rules.d", 0o755); err != nil { + t.Fatalf("mkdir udev: %v", err) + } + rule := `SUBSYSTEM=="net", ATTR{address}=="aa:bb:cc:dd:ee:ff", NAME="eth0"` + if err := fakeFS.WriteFile("/etc/udev/rules.d/70-persistent-net.rules", []byte(rule+"\n"), 0o644); err != nil { + t.Fatalf("write rule: %v", err) + } + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil || !result.Applied() { + t.Fatalf("expected applied result despite prompt error, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_ConflictPromptErrorLeavesConflictsExcluded(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.MkdirAll(filepath.Join(sysDir, "eno1"), 0o755); err != nil { + t.Fatalf("mkdir eno1: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write enp3s0 address: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "eno1", "address"), []byte("11:22:33:44:55:66\n"), 0o644); err != nil { + t.Fatalf("write eno1 address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + reader := bufio.NewReader(strings.NewReader("")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil { + t.Fatalf("expected result") + } + if result.Applied() { + t.Fatalf("expected conflicts excluded due to prompt error, got %#v", result) + } + if !strings.Contains(result.SkippedReason, "conflicting NIC mappings") { + t.Fatalf("SkippedReason=%q", result.SkippedReason) + } +} + +func TestMaybeRepairNICNamesCLI_MoreThan32ConflictsTriggersTruncationBranch(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + + var ifaceJSON []string + for i := 0; i < 33; i++ { + oldName := fmt.Sprintf("eno%d", i) + newName := fmt.Sprintf("enp%d", i) + mac := fmt.Sprintf("aa:bb:cc:dd:ee:%02x", i) + ifaceJSON = append(ifaceJSON, fmt.Sprintf(`{"name":"%s","mac":"%s"}`, oldName, mac)) + + if err := os.MkdirAll(filepath.Join(sysDir, newName), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", newName, err) + } + if err := os.WriteFile(filepath.Join(sysDir, newName, "address"), []byte(mac+"\n"), 0o644); err != nil { + t.Fatalf("write %s address: %v", newName, err) + } + + if err := os.MkdirAll(filepath.Join(sysDir, oldName), 0o755); err != nil { + t.Fatalf("mkdir %s: %v", oldName, err) + } + conflictMAC := fmt.Sprintf("11:22:33:44:55:%02x", i) + if err := os.WriteFile(filepath.Join(sysDir, oldName, "address"), []byte(conflictMAC+"\n"), 0o644); err != nil { + t.Fatalf("write %s address: %v", oldName, err) + } + } + + invJSON := []byte(`{"interfaces":[` + strings.Join(ifaceJSON, ",") + `]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + reader := bufio.NewReader(strings.NewReader("n\n")) + result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar") + if result == nil { + t.Fatalf("expected result") + } + if result.Applied() { + t.Fatalf("expected no changes when conflicts skipped, got %#v", result) + } +} + +func TestMaybeRepairNICNamesCLI_ReturnsNilOnApplyError(t *testing.T) { + origFS := restoreFS + origTime := restoreTime + origSysNet := sysClassNetPath + t.Cleanup(func() { + restoreFS = origFS + restoreTime = origTime + sysClassNetPath = origSysNet + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)} + + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + sysDir := t.TempDir() + sysClassNetPath = sysDir + if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil { + t.Fatalf("mkdir enp3s0: %v", err) + } + if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil { + t.Fatalf("write address: %v", err) + } + + invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`) + writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{ + { + Name: "./var/lib/proxsave-info/commands/system/network_inventory.json", + Typeflag: tar.TypeReg, + Mode: 0o644, + Data: invJSON, + }, + }) + + if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil { + t.Fatalf("write interfaces: %v", err) + } + + restoreFS = mkdirAllFailFS{FS: fakeFS, failPath: "/tmp/proxsave", err: errors.New("boom")} + + reader := bufio.NewReader(strings.NewReader("")) + if got := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar"); got != nil { + t.Fatalf("expected nil on apply error, got %#v", got) + } +} + +func TestApplyNetworkConfig_SelectsAvailableCommand(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + logger := newDiscardLogger() + + t.Run("ifreload", func(t *testing.T) { + pathDir := t.TempDir() + writeExecutable(t, pathDir, "ifreload") + writeExecutable(t, pathDir, "systemctl") + writeExecutable(t, pathDir, "ifup") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + if err := applyNetworkConfig(context.Background(), logger); err != nil { + t.Fatalf("applyNetworkConfig error: %v", err) + } + if calls := fakeCmd.CallsList(); len(calls) != 1 || calls[0] != "ifreload -a" { + t.Fatalf("calls=%v want [ifreload -a]", calls) + } + }) + + t.Run("systemctl", func(t *testing.T) { + pathDir := t.TempDir() + writeExecutable(t, pathDir, "systemctl") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + if err := applyNetworkConfig(context.Background(), logger); err != nil { + t.Fatalf("applyNetworkConfig error: %v", err) + } + if calls := fakeCmd.CallsList(); len(calls) != 1 || calls[0] != "systemctl restart networking" { + t.Fatalf("calls=%v want [systemctl restart networking]", calls) + } + }) + + t.Run("ifup", func(t *testing.T) { + pathDir := t.TempDir() + writeExecutable(t, pathDir, "ifup") + t.Setenv("PATH", pathDir) + + fakeCmd := &FakeCommandRunner{} + restoreCmd = fakeCmd + + if err := applyNetworkConfig(context.Background(), logger); err != nil { + t.Fatalf("applyNetworkConfig error: %v", err) + } + if calls := fakeCmd.CallsList(); len(calls) != 1 || calls[0] != "ifup -a" { + t.Fatalf("calls=%v want [ifup -a]", calls) + } + }) + + t.Run("none", func(t *testing.T) { + pathDir := t.TempDir() + t.Setenv("PATH", pathDir) + + restoreCmd = &FakeCommandRunner{} + if err := applyNetworkConfig(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "no supported network reload command") { + t.Fatalf("err=%v want supported network reload command error", err) + } + }) +} + +func TestParseSSHClientIP(t *testing.T) { + t.Run("SSH_CONNECTION", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "203.0.113.9 1234 10.0.0.1 22") + t.Setenv("SSH_CLIENT", "") + if got := parseSSHClientIP(); got != "203.0.113.9" { + t.Fatalf("got %q want %q", got, "203.0.113.9") + } + }) + + t.Run("SSH_CLIENT", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "") + t.Setenv("SSH_CLIENT", "203.0.113.8 2222 22") + if got := parseSSHClientIP(); got != "203.0.113.8" { + t.Fatalf("got %q want %q", got, "203.0.113.8") + } + }) + + t.Run("none", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "") + t.Setenv("SSH_CLIENT", "") + if got := parseSSHClientIP(); got != "" { + t.Fatalf("got %q want empty", got) + } + }) +} + +func TestDetectManagementInterface(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + logger := newDiscardLogger() + + t.Run("ssh route", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "203.0.113.9 1234 10.0.0.1 22") + t.Setenv("SSH_CLIENT", "") + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "ip route get 203.0.113.9": []byte("203.0.113.9 dev vmbr0 src 192.0.2.2\n"), + }, + } + iface, src := detectManagementInterface(context.Background(), logger) + if iface != "vmbr0" || src != "ssh" { + t.Fatalf("iface=%q src=%q want vmbr0 ssh", iface, src) + } + }) + + t.Run("default route fallback", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "203.0.113.9 1234 10.0.0.1 22") + t.Setenv("SSH_CLIENT", "") + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "ip route get 203.0.113.9": errors.New("boom"), + }, + Outputs: map[string][]byte{ + "ip route show default": []byte("default via 192.0.2.1 dev nic1\n"), + }, + } + iface, src := detectManagementInterface(context.Background(), logger) + if iface != "nic1" || src != "default-route" { + t.Fatalf("iface=%q src=%q want nic1 default-route", iface, src) + } + }) + + t.Run("none", func(t *testing.T) { + t.Setenv("SSH_CONNECTION", "") + t.Setenv("SSH_CLIENT", "") + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "ip route show default": errors.New("boom"), + }, + } + iface, src := detectManagementInterface(context.Background(), logger) + if iface != "" || src != "" { + t.Fatalf("iface=%q src=%q want empty", iface, src) + } + }) +} + +func TestRouteInterfaceForIPAndDefaultRouteInterface_ErrorCases(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "ip route get 203.0.113.9": errors.New("boom"), + "ip route show default": errors.New("boom"), + "ip route show default -x": errors.New("boom"), + "ip route show default --y": errors.New("boom"), + }, + } + if got := routeInterfaceForIP(context.Background(), "203.0.113.9"); got != "" { + t.Fatalf("routeInterfaceForIP=%q want empty on error", got) + } + if got := defaultRouteInterface(context.Background()); got != "" { + t.Fatalf("defaultRouteInterface=%q want empty on error", got) + } +} + +func TestDefaultRouteInterface_EmptyOutputReturnsEmpty(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "ip route show default": []byte(""), + }, + } + if got := defaultRouteInterface(context.Background()); got != "" { + t.Fatalf("defaultRouteInterface=%q want empty", got) + } +} + +func TestDefaultNetworkPortChecks(t *testing.T) { + if got := defaultNetworkPortChecks(SystemTypePVE); len(got) != 1 || got[0].Port != 8006 { + t.Fatalf("PVE checks=%v", got) + } + if got := defaultNetworkPortChecks(SystemTypePBS); len(got) != 1 || got[0].Port != 8007 { + t.Fatalf("PBS checks=%v", got) + } + if got := defaultNetworkPortChecks(SystemTypeUnknown); got != nil { + t.Fatalf("unknown checks=%v want nil", got) + } +} + +func TestPromptNetworkCommitWithCountdown_InputAborted(t *testing.T) { + reader := bufio.NewReader(strings.NewReader("")) + logger := newDiscardLogger() + + committed, err := promptNetworkCommitWithCountdown(context.Background(), reader, logger, 2*time.Second) + if committed { + t.Fatalf("expected committed=false") + } + if err == nil { + t.Fatalf("expected error") + } +} + +func TestRollbackNetworkFilesNow_ErrorCasesAndScriptFailure(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + logger := newDiscardLogger() + + if _, err := rollbackNetworkFilesNow(context.Background(), logger, "", ""); err == nil { + t.Fatalf("expected error for empty backup path") + } + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + t.Run("mkdir error", func(t *testing.T) { + restoreFS = &FakeFS{Root: fakeFS.Root, MkdirAllErr: errors.New("boom"), StatErr: make(map[string]error), StatErrors: make(map[string]error), OpenFileErr: make(map[string]error)} + restoreCmd = &FakeCommandRunner{} + if _, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work"); err == nil { + t.Fatalf("expected error for mkdir failure") + } + }) + + t.Run("marker write error", func(t *testing.T) { + failFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(failFS.Root) }) + failFS.WriteErr = errors.New("boom") + restoreFS = failFS + restoreCmd = &FakeCommandRunner{} + + if _, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work"); err == nil || !strings.Contains(err.Error(), "write rollback marker") { + t.Fatalf("err=%v want write rollback marker error", err) + } + }) + + t.Run("script write error removes marker", func(t *testing.T) { + restoreFS = fakeFS + restoreCmd = &FakeCommandRunner{} + + scriptPath := "/work/network_rollback_now_20260201_123456.sh" + restoreFS = writeFailFS{FS: fakeFS, failPath: scriptPath, err: errors.New("boom")} + + _, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work") + if err == nil || !strings.Contains(err.Error(), "write rollback script") { + t.Fatalf("err=%v want write rollback script error", err) + } + + markerPath := "/work/network_rollback_now_pending_20260201_123456" + if _, statErr := fakeFS.Stat(markerPath); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("expected marker to be removed on script write failure, stat err=%v", statErr) + } + }) + + t.Run("marker remove error is non-fatal", func(t *testing.T) { + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "sh /work/network_rollback_now_20260201_123456.sh": []byte("ok\n"), + }, + } + + markerPath := "/work/network_rollback_now_pending_20260201_123456" + restoreFS = removeFailFS{FS: fakeFS, failPath: markerPath, err: errors.New("remove boom")} + + logPath, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if logPath != "/work/network_rollback_now_20260201_123456.log" { + t.Fatalf("logPath=%q", logPath) + } + }) + + t.Run("script run error returns log path", func(t *testing.T) { + restoreFS = fakeFS + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "sh /work/network_rollback_now_20260201_123456.sh": errors.New("boom"), + }, + } + + logPath, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work") + if err == nil || !strings.Contains(err.Error(), "rollback script failed") { + t.Fatalf("err=%v want rollback script failed", err) + } + if logPath != "/work/network_rollback_now_20260201_123456.log" { + t.Fatalf("logPath=%q", logPath) + } + }) +} + +func TestRollbackNetworkFilesNow_DefaultWorkDirUsesTmpProxsave(t *testing.T) { + origFS := restoreFS + origCmd := restoreCmd + origTime := restoreTime + t.Cleanup(func() { + restoreFS = origFS + restoreCmd = origCmd + restoreTime = origTime + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)} + + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "sh /tmp/proxsave/network_rollback_now_20260201_123456.sh": []byte("ok\n"), + }, + } + + logPath, err := rollbackNetworkFilesNow(context.Background(), newDiscardLogger(), "/backup.tar", "") + if err != nil { + t.Fatalf("rollbackNetworkFilesNow error: %v", err) + } + if logPath != "/tmp/proxsave/network_rollback_now_20260201_123456.log" { + t.Fatalf("logPath=%q", logPath) + } +} + +func TestBuildRollbackScript_IncludesRestartBlockWhenEnabled(t *testing.T) { + script := buildRollbackScript("/marker", "/backup with spaces.tar", "/tmp/log file.log", true) + if !strings.Contains(script, "Restart networking after rollback") { + t.Fatalf("expected restartNetworking block") + } + if !strings.Contains(script, "LOG='/tmp/log file.log'") { + t.Fatalf("expected quoted LOG, got:\n%s", script) + } + if !strings.Contains(script, "BACKUP='/backup with spaces.tar'") { + t.Fatalf("expected quoted BACKUP, got:\n%s", script) + } +} + +func TestShellQuote(t *testing.T) { + if got := shellQuote(""); got != "''" { + t.Fatalf("shellQuote empty=%q", got) + } + if got := shellQuote("simple"); got != "simple" { + t.Fatalf("shellQuote simple=%q", got) + } + want := `'a b '\''c'\'''` + if got := shellQuote("a b 'c'"); got != want { + t.Fatalf("shellQuote=%q want %q", got, want) + } +} + +func TestRunCommandLogged_SuccessAndFailure(t *testing.T) { + origCmd := restoreCmd + t.Cleanup(func() { restoreCmd = origCmd }) + + logger := newDiscardLogger() + + restoreCmd = &FakeCommandRunner{ + Outputs: map[string][]byte{ + "echo hi": []byte("hi\n"), + }, + } + if err := runCommandLogged(context.Background(), logger, "echo", "hi"); err != nil { + t.Fatalf("runCommandLogged error: %v", err) + } + + restoreCmd = &FakeCommandRunner{ + Errors: map[string]error{ + "false x": errors.New("boom"), + }, + } + err := runCommandLogged(context.Background(), logger, "false", "x") + if err == nil || !strings.Contains(err.Error(), "false") || !strings.Contains(err.Error(), "failed") { + t.Fatalf("err=%v want wrapped failure", err) + } +} diff --git a/internal/orchestrator/network_staged_apply.go b/internal/orchestrator/network_staged_apply.go index 3412d15..da93844 100644 --- a/internal/orchestrator/network_staged_apply.go +++ b/internal/orchestrator/network_staged_apply.go @@ -59,13 +59,20 @@ func applyNetworkFilesFromStage(logger *logging.Logger, stageRoot string) (appli } func copyDirOverlay(srcDir, destDir string) ([]string, error) { - info, err := restoreFS.Stat(srcDir) + return copyDirOverlayWithinRoot(srcDir, destDir, destDir) +} + +func copyDirOverlayWithinRoot(srcDir, destDir, destRoot string) ([]string, error) { + info, err := restoreFS.Lstat(srcDir) if err != nil { if os.IsNotExist(err) { return nil, nil } return nil, fmt.Errorf("stat %s: %w", srcDir, err) } + if info.Mode()&os.ModeSymlink != 0 { + return nil, fmt.Errorf("staged directory must not be a symlink: %s", srcDir) + } if !info.IsDir() { return nil, nil } @@ -91,8 +98,27 @@ func copyDirOverlay(srcDir, destDir string) ([]string, error) { src := filepath.Join(srcDir, name) dest := filepath.Join(destDir, name) - if entry.IsDir() { - paths, err := copyDirOverlay(src, dest) + info, err := restoreFS.Lstat(src) + if err != nil { + if os.IsNotExist(err) { + continue + } + return applied, fmt.Errorf("stat %s: %w", src, err) + } + + if info.Mode()&os.ModeSymlink != 0 { + ok, err := copySymlinkOverlayWithinRoot(src, dest, destRoot) + if err != nil { + return applied, err + } + if ok { + applied = append(applied, dest) + } + continue + } + + if info.IsDir() { + paths, err := copyDirOverlayWithinRoot(src, dest, destRoot) if err != nil { return applied, err } @@ -100,7 +126,7 @@ func copyDirOverlay(srcDir, destDir string) ([]string, error) { continue } - ok, err := copyFileOverlay(src, dest) + ok, err := copyFileOverlayWithinRoot(src, dest, destRoot) if err != nil { return applied, err } @@ -113,16 +139,26 @@ func copyDirOverlay(srcDir, destDir string) ([]string, error) { } func copyFileOverlay(src, dest string) (bool, error) { - info, err := restoreFS.Stat(src) + return copyFileOverlayWithinRoot(src, dest, filepath.Dir(dest)) +} + +func copyFileOverlayWithinRoot(src, dest, destRoot string) (bool, error) { + info, err := restoreFS.Lstat(src) if err != nil { if os.IsNotExist(err) { return false, nil } return false, fmt.Errorf("stat %s: %w", src, err) } + if info.Mode()&os.ModeSymlink != 0 { + return copySymlinkOverlayWithinRoot(src, dest, destRoot) + } if info.IsDir() { return false, nil } + if !info.Mode().IsRegular() { + return false, fmt.Errorf("unsupported staged file type %s (mode=%s)", src, info.Mode()) + } data, err := restoreFS.ReadFile(src) if err != nil { @@ -141,3 +177,102 @@ func copyFileOverlay(src, dest string) (bool, error) { } return true, nil } + +func copySymlinkOverlay(src, dest string) (bool, error) { + return copySymlinkOverlayWithinRoot(src, dest, filepath.Dir(dest)) +} + +func copySymlinkOverlayWithinRoot(src, dest, destRoot string) (bool, error) { + info, err := restoreFS.Lstat(src) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, fmt.Errorf("stat %s: %w", src, err) + } + if info.Mode()&os.ModeSymlink == 0 { + return false, fmt.Errorf("source is not a symlink: %s", src) + } + + target, err := restoreFS.Readlink(src) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, fmt.Errorf("readlink %s: %w", src, err) + } + + validatedTarget, err := validateOverlaySymlinkTargetWithinRoot(destRoot, dest, target) + if err != nil { + return false, fmt.Errorf("unsafe symlink target %s -> %s: %w", dest, target, err) + } + + if err := ensureDirExistsWithInheritedMeta(filepath.Dir(dest)); err != nil { + return false, fmt.Errorf("ensure %s: %w", filepath.Dir(dest), err) + } + + if existing, err := restoreFS.Lstat(dest); err == nil { + if existing.IsDir() { + return false, fmt.Errorf("destination exists as directory: %s", dest) + } + if err := restoreFS.Remove(dest); err != nil && !os.IsNotExist(err) { + return false, fmt.Errorf("remove %s: %w", dest, err) + } + } else if !os.IsNotExist(err) { + return false, fmt.Errorf("stat %s: %w", dest, err) + } + + if err := restoreFS.Symlink(validatedTarget, dest); err != nil { + return false, fmt.Errorf("symlink %s -> %s: %w", dest, validatedTarget, err) + } + if err := verifyCreatedSymlinkWithinRootFS(restoreFS, destRoot, dest); err != nil { + return false, err + } + return true, nil +} + +// verifyCreatedSymlinkWithinRootFS re-reads a newly created symlink and ensures +// its effective target still resolves within destRoot. On validation failure it +// removes the symlink, mirroring extractSymlink's cleanup-on-failure behavior. +func verifyCreatedSymlinkWithinRootFS(fsys FS, destRoot, dest string) error { + destRoot = filepath.Clean(strings.TrimSpace(destRoot)) + dest = filepath.Clean(strings.TrimSpace(dest)) + + actualTarget, err := fsys.Readlink(dest) + if err != nil { + _ = fsys.Remove(dest) + return fmt.Errorf("read created symlink %s: %w", dest, err) + } + + if _, err := resolvePathRelativeToBaseWithinRootFS(fsys, destRoot, filepath.Dir(dest), actualTarget); err != nil { + _ = fsys.Remove(dest) + return fmt.Errorf("symlink target escapes root after creation: %s -> %s: %w", dest, actualTarget, err) + } + + return nil +} + +func validateOverlaySymlinkTargetWithinRoot(destRoot, dest, target string) (string, error) { + destRoot = filepath.Clean(strings.TrimSpace(destRoot)) + dest = filepath.Clean(strings.TrimSpace(dest)) + + resolved, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(dest), target) + if err != nil { + return "", err + } + + if !filepath.IsAbs(target) { + return target, nil + } + + resolvedParent, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(dest), ".") + if err != nil { + return "", err + } + + rewrittenTarget, err := filepath.Rel(resolvedParent, resolved) + if err != nil { + return "", fmt.Errorf("rewrite symlink target %s -> %s: %w", dest, resolved, err) + } + return filepath.Clean(rewrittenTarget), nil +} diff --git a/internal/orchestrator/network_staged_apply_test.go b/internal/orchestrator/network_staged_apply_test.go new file mode 100644 index 0000000..bc61f30 --- /dev/null +++ b/internal/orchestrator/network_staged_apply_test.go @@ -0,0 +1,240 @@ +package orchestrator + +import ( + "errors" + "os" + "path/filepath" + "strings" + "testing" +) + +type preservingSymlinkFS struct { + *FakeFS +} + +func newPreservingSymlinkFS() *preservingSymlinkFS { + return &preservingSymlinkFS{FakeFS: NewFakeFS()} +} + +func (f *preservingSymlinkFS) Symlink(oldname, newname string) error { + if err := os.MkdirAll(filepath.Dir(f.onDisk(newname)), 0o755); err != nil { + return err + } + return os.Symlink(oldname, f.onDisk(newname)) +} + +type tamperSymlinkFS struct { + FS + destPath string + target string +} + +func (f tamperSymlinkFS) Symlink(oldname, newname string) error { + if filepath.Clean(newname) == filepath.Clean(f.destPath) { + return f.FS.Symlink(f.target, newname) + } + return f.FS.Symlink(oldname, newname) +} + +func TestApplyNetworkFilesFromStage_RewritesSafeAbsoluteSymlinkTargetsWithinDestinationRoot(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.MkdirAll("/stage/etc/network", 0o755); err != nil { + t.Fatalf("create stage dir: %v", err) + } + if err := fakeFS.Symlink("/etc/network/interfaces.real", "/stage/etc/network/interfaces"); err != nil { + t.Fatalf("create staged symlink: %v", err) + } + + applied, err := applyNetworkFilesFromStage(newTestLogger(), "/stage") + if err != nil { + t.Fatalf("applyNetworkFilesFromStage: %v", err) + } + + info, err := fakeFS.Lstat("/etc/network/interfaces") + if err != nil { + t.Fatalf("lstat dest symlink: %v", err) + } + if info.Mode()&os.ModeSymlink == 0 { + t.Fatalf("expected symlink mode, got %s", info.Mode()) + } + + gotTarget, err := fakeFS.Readlink("/etc/network/interfaces") + if err != nil { + t.Fatalf("readlink dest symlink: %v", err) + } + if gotTarget != "interfaces.real" { + t.Fatalf("symlink target=%q, want %q", gotTarget, "interfaces.real") + } + + found := false + for _, path := range applied { + if path == "/etc/network/interfaces" { + found = true + break + } + } + if !found { + t.Fatalf("expected applied paths to include /etc/network/interfaces, got %#v", applied) + } +} + +func TestApplyNetworkFilesFromStage_RejectsEscapingSymlinkTargets(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.MkdirAll("/stage/etc/network", 0o755); err != nil { + t.Fatalf("create stage dir: %v", err) + } + if err := fakeFS.Symlink("/stage/etc/network/interfaces.real", "/stage/etc/network/interfaces"); err != nil { + t.Fatalf("create staged symlink: %v", err) + } + + _, err := applyNetworkFilesFromStage(newTestLogger(), "/stage") + if err == nil || !strings.Contains(err.Error(), "unsafe symlink target") { + t.Fatalf("expected unsafe symlink target error, got %v", err) + } +} + +func TestApplyNetworkFilesFromStage_RejectsRelativeSymlinkTargetsThatEscapeDestinationRoot(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.MkdirAll("/stage/etc/network/interfaces.d", 0o755); err != nil { + t.Fatalf("create stage dir: %v", err) + } + if err := fakeFS.Symlink("../../shadow", "/stage/etc/network/interfaces.d/uplink"); err != nil { + t.Fatalf("create staged symlink: %v", err) + } + + _, err := applyNetworkFilesFromStage(newTestLogger(), "/stage") + if err == nil || !strings.Contains(err.Error(), "unsafe symlink target") { + t.Fatalf("expected unsafe symlink target error, got %v", err) + } +} + +func TestApplyNetworkFilesFromStage_RejectsSymlinkStageDirectory(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.MkdirAll("/stage/etc", 0o755); err != nil { + t.Fatalf("create stage parent dir: %v", err) + } + if err := fakeFS.Symlink("/outside/network", "/stage/etc/network"); err != nil { + t.Fatalf("create staged dir symlink: %v", err) + } + + _, err := applyNetworkFilesFromStage(newTestLogger(), "/stage") + if err == nil || !strings.Contains(err.Error(), "must not be a symlink") { + t.Fatalf("expected staged directory symlink error, got %v", err) + } +} + +func TestValidateOverlaySymlinkTargetWithinRoot_RewritesAbsoluteTargetFromResolvedParent(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.MkdirAll("/dest/materialized/network", 0o755); err != nil { + t.Fatalf("create materialized dir: %v", err) + } + if err := fakeFS.Symlink("/dest/materialized", "/dest/etc"); err != nil { + t.Fatalf("create parent symlink: %v", err) + } + + rewritten, err := validateOverlaySymlinkTargetWithinRoot( + "/dest", + "/dest/etc/network/interfaces", + "/dest/materialized/network/interfaces.real", + ) + if err != nil { + t.Fatalf("validateOverlaySymlinkTargetWithinRoot error: %v", err) + } + if rewritten != "interfaces.real" { + t.Fatalf("rewritten target = %q, want %q", rewritten, "interfaces.real") + } +} + +func TestCopySymlinkOverlayWithinRoot_CleansUpWhenCreatedSymlinkReadbackFails(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + if err := fakeFS.MkdirAll("/stage", 0o755); err != nil { + t.Fatalf("create stage dir: %v", err) + } + if err := fakeFS.Symlink("/dest/target", "/stage/link"); err != nil { + t.Fatalf("create staged symlink: %v", err) + } + + restoreFS = readlinkFailFS{ + FS: fakeFS, + failPath: "/dest/link", + err: errors.New("boom"), + } + + ok, err := copySymlinkOverlayWithinRoot("/stage/link", "/dest/link", "/dest") + if ok { + t.Fatal("copySymlinkOverlayWithinRoot reported success; want failure") + } + if err == nil || !strings.Contains(err.Error(), "read created symlink /dest/link") { + t.Fatalf("expected readback validation error, got %v", err) + } + if _, statErr := fakeFS.Lstat("/dest/link"); !os.IsNotExist(statErr) { + t.Fatalf("expected created symlink cleanup, lstat err = %v", statErr) + } +} + +func TestCopySymlinkOverlayWithinRoot_CleansUpWhenCreatedSymlinkEscapesAfterCreation(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + if err := fakeFS.MkdirAll("/stage", 0o755); err != nil { + t.Fatalf("create stage dir: %v", err) + } + if err := fakeFS.Symlink("/dest/target", "/stage/link"); err != nil { + t.Fatalf("create staged symlink: %v", err) + } + + restoreFS = tamperSymlinkFS{ + FS: fakeFS, + destPath: "/dest/link", + target: "/outside/evil", + } + + ok, err := copySymlinkOverlayWithinRoot("/stage/link", "/dest/link", "/dest") + if ok { + t.Fatal("copySymlinkOverlayWithinRoot reported success; want failure") + } + if err == nil || !strings.Contains(err.Error(), "escapes root after creation") { + t.Fatalf("expected post-create escape error, got %v", err) + } + if _, statErr := fakeFS.Lstat("/dest/link"); !os.IsNotExist(statErr) { + t.Fatalf("expected created symlink cleanup, lstat err = %v", statErr) + } +} diff --git a/internal/orchestrator/path_security.go b/internal/orchestrator/path_security.go new file mode 100644 index 0000000..7b2a425 --- /dev/null +++ b/internal/orchestrator/path_security.go @@ -0,0 +1,286 @@ +package orchestrator + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" +) + +const maxPathSecuritySymlinkHops = 40 + +type pathResolutionErrorKind string + +const ( + pathResolutionErrorSecurity pathResolutionErrorKind = "security" + pathResolutionErrorOperational pathResolutionErrorKind = "operational" +) + +type pathResolutionError struct { + kind pathResolutionErrorKind + msg string + cause error +} + +func (e *pathResolutionError) Error() string { + if e == nil { + return "" + } + if e.cause != nil { + return fmt.Sprintf("%s: %v", e.msg, e.cause) + } + return e.msg +} + +func (e *pathResolutionError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + +func newPathResolutionError(kind pathResolutionErrorKind, cause error, format string, args ...any) error { + return &pathResolutionError{ + kind: kind, + msg: fmt.Sprintf(format, args...), + cause: cause, + } +} + +func newPathSecurityError(format string, args ...any) error { + return newPathResolutionError(pathResolutionErrorSecurity, nil, format, args...) +} + +func wrapPathOperationalError(cause error, format string, args ...any) error { + return newPathResolutionError(pathResolutionErrorOperational, cause, format, args...) +} + +func isPathSecurityError(err error) bool { + var target *pathResolutionError + return errors.As(err, &target) && target.kind == pathResolutionErrorSecurity +} + +func isPathOperationalError(err error) bool { + var target *pathResolutionError + return errors.As(err, &target) && target.kind == pathResolutionErrorOperational +} + +func resolvePathWithinRootFS(fsys FS, destRoot, candidate string) (string, error) { + lexicalRoot, canonicalRoot, err := prepareRootPathsFS(fsys, destRoot) + if err != nil { + return "", err + } + + candidateAbs, err := normalizeCandidateWithinRoot(lexicalRoot, canonicalRoot, candidate) + if err != nil { + return "", err + } + + return resolvePathWithinPreparedRootFS(fsys, lexicalRoot, canonicalRoot, candidateAbs, true, maxPathSecuritySymlinkHops) +} + +func resolvePathRelativeToBaseWithinRootFS(fsys FS, destRoot, baseDir, candidate string) (string, error) { + lexicalRoot, canonicalRoot, err := prepareRootPathsFS(fsys, destRoot) + if err != nil { + return "", err + } + + baseAbs, err := filepath.Abs(filepath.Clean(baseDir)) + if err != nil { + return "", fmt.Errorf("resolve base directory: %w", err) + } + normalizedBaseDir, err := normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, baseAbs) + if err != nil { + return "", err + } + + // Resolve symlinks already present in the base directory before validating + // a relative target against it. The kernel creates the new symlink under the + // resolved parent directory, not the lexical parent path. + canonicalBaseDir, err := resolvePathWithinPreparedRootFS( + fsys, + lexicalRoot, + canonicalRoot, + normalizedBaseDir, + true, + maxPathSecuritySymlinkHops, + ) + if err != nil { + return "", err + } + + var candidateAbs string + if filepath.IsAbs(candidate) { + candidateAbs, err = normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, filepath.Clean(candidate)) + if err != nil { + return "", err + } + } else { + candidateAbs = filepath.Clean(filepath.Join(canonicalBaseDir, candidate)) + } + + return resolvePathWithinPreparedRootFS(fsys, lexicalRoot, canonicalRoot, candidateAbs, true, maxPathSecuritySymlinkHops) +} + +func prepareRootPathsFS(fsys FS, destRoot string) (string, string, error) { + fsys = pathSecurityFS(fsys) + + cleanRoot := filepath.Clean(destRoot) + absRoot, err := filepath.Abs(cleanRoot) + if err != nil { + return "", "", fmt.Errorf("resolve destination root: %w", err) + } + + canonicalRoot, err := resolvePathFromFilesystemRootFS(fsys, absRoot, true, maxPathSecuritySymlinkHops) + if err != nil { + return "", "", fmt.Errorf("resolve destination root: %w", err) + } + + return absRoot, canonicalRoot, nil +} + +func normalizeCandidateWithinRoot(lexicalRoot, canonicalRoot, candidate string) (string, error) { + if filepath.IsAbs(candidate) { + return normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, filepath.Clean(candidate)) + } + + return filepath.Clean(filepath.Join(canonicalRoot, candidate)), nil +} + +func normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, candidateAbs string) (string, error) { + rel, ok, err := relativePathWithinRoot(lexicalRoot, candidateAbs) + if err != nil { + return "", fmt.Errorf("cannot compute relative path: %w", err) + } + if ok { + return filepath.Clean(filepath.Join(canonicalRoot, rel)), nil + } + + rel, ok, err = relativePathWithinRoot(canonicalRoot, candidateAbs) + if err != nil { + return "", fmt.Errorf("cannot compute relative path: %w", err) + } + if ok { + return filepath.Clean(filepath.Join(canonicalRoot, rel)), nil + } + + return "", newPathSecurityError("resolved path escapes destination: %s", candidateAbs) +} + +func resolvePathFromFilesystemRootFS(fsys FS, candidateAbs string, allowMissingTail bool, hopsRemaining int) (string, error) { + root := string(os.PathSeparator) + return resolvePathWithinPreparedRootFS(fsys, root, root, candidateAbs, allowMissingTail, hopsRemaining) +} + +func resolvePathWithinPreparedRootFS(fsys FS, lexicalRoot, canonicalRoot, candidateAbs string, allowMissingTail bool, hopsRemaining int) (string, error) { + fsys = pathSecurityFS(fsys) + + lexicalRoot = filepath.Clean(lexicalRoot) + canonicalRoot = filepath.Clean(canonicalRoot) + candidateAbs = filepath.Clean(candidateAbs) + + if !filepath.IsAbs(lexicalRoot) { + return "", fmt.Errorf("destination root must be absolute: %s", lexicalRoot) + } + if !filepath.IsAbs(canonicalRoot) { + return "", fmt.Errorf("destination root must be absolute: %s", canonicalRoot) + } + if !filepath.IsAbs(candidateAbs) { + return "", fmt.Errorf("candidate path must be absolute: %s", candidateAbs) + } + + rel, ok, err := relativePathWithinRoot(canonicalRoot, candidateAbs) + if err != nil { + return "", fmt.Errorf("cannot compute relative path: %w", err) + } + if !ok { + return "", newPathSecurityError("resolved path escapes destination: %s", candidateAbs) + } + if rel == "." { + return canonicalRoot, nil + } + + current := canonicalRoot + parts := splitRelativePath(rel) + for idx, part := range parts { + next := filepath.Join(current, part) + info, err := fsys.Lstat(next) + if err != nil { + if allowMissingTail && os.IsNotExist(err) { + return filepath.Clean(filepath.Join(current, filepath.Join(parts[idx:]...))), nil + } + return "", wrapPathOperationalError(err, "lstat %s", next) + } + + if info.Mode()&os.ModeSymlink != 0 { + if hopsRemaining <= 0 { + return "", newPathSecurityError("too many symlink resolutions for %s", candidateAbs) + } + target, err := fsys.Readlink(next) + if err != nil { + return "", wrapPathOperationalError(err, "readlink %s", next) + } + + var resolvedLink string + if filepath.IsAbs(target) { + resolvedLink, err = normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, filepath.Clean(target)) + if err != nil { + return "", err + } + } else { + resolvedLink = filepath.Join(current, target) + } + resolvedLink = filepath.Clean(resolvedLink) + + if _, ok, err := relativePathWithinRoot(canonicalRoot, resolvedLink); err != nil { + return "", fmt.Errorf("cannot compute relative path: %w", err) + } else if !ok { + return "", newPathSecurityError("resolved path escapes destination: %s", resolvedLink) + } + + remainder := filepath.Join(parts[idx+1:]...) + if remainder != "" { + resolvedLink = filepath.Join(resolvedLink, remainder) + } + + return resolvePathWithinPreparedRootFS(fsys, lexicalRoot, canonicalRoot, resolvedLink, allowMissingTail, hopsRemaining-1) + } + + if !info.IsDir() && idx < len(parts)-1 { + return "", newPathResolutionError(pathResolutionErrorOperational, nil, "path component is not a directory: %s", next) + } + + current = next + } + + return current, nil +} + +func relativePathWithinRoot(root, candidate string) (string, bool, error) { + rel, err := filepath.Rel(root, candidate) + if err != nil { + return "", false, err + } + if rel == "." { + return rel, true, nil + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || filepath.IsAbs(rel) { + return rel, false, nil + } + return rel, true, nil +} + +func splitRelativePath(rel string) []string { + if rel == "" || rel == "." { + return nil + } + return strings.Split(rel, string(os.PathSeparator)) +} + +func pathSecurityFS(fsys FS) FS { + if fsys == nil { + return osFS{} + } + return fsys +} diff --git a/internal/orchestrator/path_security_test.go b/internal/orchestrator/path_security_test.go new file mode 100644 index 0000000..b3685b4 --- /dev/null +++ b/internal/orchestrator/path_security_test.go @@ -0,0 +1,184 @@ +package orchestrator + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestResolvePathWithinRootFS_AllowsMissingTailAfterSafeSymlink(t *testing.T) { + root := t.TempDir() + if err := os.MkdirAll(filepath.Join(root, "safe-target"), 0o755); err != nil { + t.Fatalf("mkdir safe-target: %v", err) + } + if err := os.Symlink("safe-target", filepath.Join(root, "safe-link")); err != nil { + t.Fatalf("create safe symlink: %v", err) + } + + resolved, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("safe-link", "missing", "file.txt")) + if err != nil { + t.Fatalf("resolvePathWithinRootFS returned error: %v", err) + } + + want := filepath.Join(root, "safe-target", "missing", "file.txt") + if resolved != want { + t.Fatalf("resolved path = %q, want %q", resolved, want) + } +} + +func TestResolvePathWithinRootFS_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(root, "escape-link")); err != nil { + t.Fatalf("create escape symlink: %v", err) + } + + _, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("escape-link", "missing", "file.txt")) + if err == nil || !strings.Contains(err.Error(), "escapes destination") { + t.Fatalf("expected escape rejection, got %v", err) + } +} + +func TestResolvePathWithinRootFS_RejectsSymlinkLoop(t *testing.T) { + root := t.TempDir() + if err := os.Symlink("loop", filepath.Join(root, "loop")); err != nil { + t.Fatalf("create loop symlink: %v", err) + } + + _, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("loop", "file.txt")) + if err == nil || !strings.Contains(err.Error(), "too many symlink resolutions") { + t.Fatalf("expected symlink loop rejection, got %v", err) + } +} + +func TestResolvePathWithinRootFS_ClassifiesPathComponentNotDirectoryAsOperational(t *testing.T) { + root := t.TempDir() + blocker := filepath.Join(root, "deep") + if err := os.WriteFile(blocker, []byte("block"), 0o644); err != nil { + t.Fatalf("write blocker: %v", err) + } + + _, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("deep", "nested", "file.txt")) + if err == nil { + t.Fatal("expected error for non-directory path component") + } + if !isPathOperationalError(err) { + t.Fatalf("expected operational error classification, got %v", err) + } + if isPathSecurityError(err) { + t.Fatalf("expected non-security error classification, got %v", err) + } +} + +func TestResolvePathWithinRootFS_ClassifiesPermissionDeniedAsOperational(t *testing.T) { + fsys := NewFakeFS() + root := filepath.Join(string(os.PathSeparator), "restore-root") + if err := fsys.AddDir(root); err != nil { + t.Fatalf("add root dir: %v", err) + } + if err := fsys.AddDir(filepath.Join(root, "subdir")); err != nil { + t.Fatalf("add subdir: %v", err) + } + fsys.StatErrors[filepath.Clean(filepath.Join(root, "subdir", "file.txt"))] = os.ErrPermission + + _, err := resolvePathWithinRootFS(fsys, root, filepath.Join("subdir", "file.txt")) + if err == nil { + t.Fatal("expected permission error") + } + if !isPathOperationalError(err) { + t.Fatalf("expected operational error classification, got %v", err) + } + if isPathSecurityError(err) { + t.Fatalf("expected non-security error classification, got %v", err) + } +} + +func TestResolvePathWithinRootFS_AllowsAbsoluteSymlinkTargetViaLexicalRoot(t *testing.T) { + parent := t.TempDir() + realRoot := filepath.Join(parent, "real-root") + linkRoot := filepath.Join(parent, "link-root") + if err := os.MkdirAll(filepath.Join(realRoot, "etc-real"), 0o755); err != nil { + t.Fatalf("mkdir real root: %v", err) + } + if err := os.Symlink(realRoot, linkRoot); err != nil { + t.Fatalf("create root symlink: %v", err) + } + if err := os.Symlink(filepath.Join(linkRoot, "etc-real"), filepath.Join(realRoot, "etc")); err != nil { + t.Fatalf("create nested symlink: %v", err) + } + + resolved, err := resolvePathWithinRootFS(osFS{}, linkRoot, filepath.Join("etc", "missing", "file.txt")) + if err != nil { + t.Fatalf("resolvePathWithinRootFS returned error: %v", err) + } + + want := filepath.Join(realRoot, "etc-real", "missing", "file.txt") + if resolved != want { + t.Fatalf("resolved path = %q, want %q", resolved, want) + } +} + +func TestResolvePathRelativeToBaseWithinRootFS_RejectsEscapeViaSymlinkedBaseDir(t *testing.T) { + root := t.TempDir() + if err := os.Symlink(".", filepath.Join(root, "linkdir")); err != nil { + t.Fatalf("create base symlink: %v", err) + } + + _, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), "../outside") + if err == nil || !strings.Contains(err.Error(), "escapes destination") { + t.Fatalf("expected escape rejection, got %v", err) + } +} + +func TestResolvePathRelativeToBaseWithinRootFS_AllowsRelativeTargetAfterSafeBaseResolution(t *testing.T) { + root := t.TempDir() + if err := os.MkdirAll(filepath.Join(root, "subdir"), 0o755); err != nil { + t.Fatalf("mkdir subdir: %v", err) + } + if err := os.Symlink("subdir", filepath.Join(root, "linkdir")); err != nil { + t.Fatalf("create base symlink: %v", err) + } + + resolved, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), "file.txt") + if err != nil { + t.Fatalf("resolvePathRelativeToBaseWithinRootFS returned error: %v", err) + } + + want := filepath.Join(root, "subdir", "file.txt") + if resolved != want { + t.Fatalf("resolved path = %q, want %q", resolved, want) + } +} + +func TestResolvePathRelativeToBaseWithinRootFS_RejectsBaseDirSymlinkOutsideRoot(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(root, "linkdir")); err != nil { + t.Fatalf("create base symlink: %v", err) + } + + _, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), "file.txt") + if err == nil || !strings.Contains(err.Error(), "escapes destination") { + t.Fatalf("expected escape rejection, got %v", err) + } +} + +func TestResolvePathRelativeToBaseWithinRootFS_PreservesAbsoluteCandidateBehavior(t *testing.T) { + root := t.TempDir() + if err := os.MkdirAll(filepath.Join(root, "subdir"), 0o755); err != nil { + t.Fatalf("mkdir subdir: %v", err) + } + if err := os.Symlink("subdir", filepath.Join(root, "linkdir")); err != nil { + t.Fatalf("create base symlink: %v", err) + } + + candidate := filepath.Join(root, "subdir", "file.txt") + resolved, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), candidate) + if err != nil { + t.Fatalf("resolvePathRelativeToBaseWithinRootFS returned error: %v", err) + } + if resolved != candidate { + t.Fatalf("resolved path = %q, want %q", resolved, candidate) + } +} diff --git a/internal/orchestrator/pbs_staged_apply.go b/internal/orchestrator/pbs_staged_apply.go index 265910c..2bd91a1 100644 --- a/internal/orchestrator/pbs_staged_apply.go +++ b/internal/orchestrator/pbs_staged_apply.go @@ -359,6 +359,8 @@ type pbsDatastoreInventoryRestoreLite struct { Name string `json:"name"` Path string `json:"path"` Comment string `json:"comment"` + Origin string `json:"origin"` + CLIName string `json:"cli_name"` } `json:"datastores"` } @@ -387,7 +389,14 @@ func loadPBSDatastoreCfgFromInventory(stageRoot string) (string, string, error) // Fallback: generate a minimal datastore.cfg from the inventory's datastore list. var out strings.Builder for _, ds := range report.Datastores { + if strings.TrimSpace(ds.Origin) == "override" { + continue + } + name := strings.TrimSpace(ds.Name) + if name == "" { + name = strings.TrimSpace(ds.CLIName) + } path := strings.TrimSpace(ds.Path) if name == "" || path == "" { continue diff --git a/internal/orchestrator/pbs_staged_apply_additional_test.go b/internal/orchestrator/pbs_staged_apply_additional_test.go index 80cb253..5c0d1d2 100644 --- a/internal/orchestrator/pbs_staged_apply_additional_test.go +++ b/internal/orchestrator/pbs_staged_apply_additional_test.go @@ -280,6 +280,40 @@ func TestLoadPBSDatastoreCfgFromInventory_FallsBackToDatastoreList(t *testing.T) } } +func TestLoadPBSDatastoreCfgFromInventory_IgnoresOverrideEntries(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + stageRoot := "/stage" + inventory := `{"datastores":[{"name":"DS1","path":"/mnt/ds1","comment":"primary","origin":"merged"},{"name":"DS1","path":"/mnt/scan-root","comment":"configured via PBS_DATASTORE_PATH","origin":"override"}]}` + if err := fakeFS.WriteFile(stageRoot+"/var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json", []byte(inventory), 0o640); err != nil { + t.Fatalf("write inventory: %v", err) + } + + content, src, err := loadPBSDatastoreCfgFromInventory(stageRoot) + if err != nil { + t.Fatalf("loadPBSDatastoreCfgFromInventory: %v", err) + } + if src != "pbs_datastore_inventory.json.datastores" { + t.Fatalf("src=%q", src) + } + if strings.Contains(content, "/mnt/scan-root") { + t.Fatalf("override entry should not be present in generated datastore.cfg: %s", content) + } + + blocks, err := parsePBSDatastoreCfgBlocks(content) + if err != nil { + t.Fatalf("parsePBSDatastoreCfgBlocks: %v", err) + } + if len(blocks) != 1 || blocks[0].Name != "DS1" || blocks[0].Path != "/mnt/ds1" { + t.Fatalf("unexpected blocks: %+v", blocks) + } +} + func TestLoadPBSDatastoreCfgFromInventory_PropagatesErrors(t *testing.T) { origFS := restoreFS t.Cleanup(func() { restoreFS = origFS }) diff --git a/internal/orchestrator/restore.go b/internal/orchestrator/restore.go index 2f14d20..a390360 100644 --- a/internal/orchestrator/restore.go +++ b/internal/orchestrator/restore.go @@ -1448,6 +1448,10 @@ func runRestoreCommandStream(ctx context.Context, name string, stdin io.Reader, } func sanitizeRestoreEntryTarget(destRoot, entryName string) (string, string, error) { + return sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, entryName) +} + +func sanitizeRestoreEntryTargetWithFS(fsys FS, destRoot, entryName string) (string, string, error) { cleanDestRoot := filepath.Clean(destRoot) if cleanDestRoot == "" { cleanDestRoot = string(os.PathSeparator) @@ -1490,23 +1494,13 @@ func sanitizeRestoreEntryTarget(destRoot, entryName string) (string, string, err return "", "", fmt.Errorf("illegal path: %s", entryName) } - parentDir := filepath.Dir(absTarget) - resolvedParentDir, err := filepath.EvalSymlinks(parentDir) - if err != nil { - // If the path doesn't exist yet, EvalSymlinks will fail; fallback to the cleaned path. - resolvedParentDir = filepath.Clean(parentDir) - } - absResolvedParentDir, err := filepath.Abs(resolvedParentDir) - if err != nil { - return "", "", fmt.Errorf("resolve extraction parent directory: %w", err) - } - - relParent, err := filepath.Rel(absDestRoot, absResolvedParentDir) - if err != nil { - return "", "", fmt.Errorf("illegal path: %s: %w", entryName, err) - } - if strings.HasPrefix(relParent, ".."+string(os.PathSeparator)) || relParent == ".." || filepath.IsAbs(relParent) { - return "", "", fmt.Errorf("illegal path: %s", entryName) + if _, err := resolvePathWithinRootFS(fsys, absDestRoot, absTarget); err != nil { + if isPathSecurityError(err) { + return "", "", fmt.Errorf("illegal path: %s: %w", entryName, err) + } + if !isPathOperationalError(err) { + return "", "", fmt.Errorf("resolve extraction target: %w", err) + } } return absTarget, absDestRoot, nil @@ -1537,7 +1531,7 @@ func shouldSkipProxmoxSystemRestore(relTarget string) (bool, string) { // extractTarEntry extracts a single TAR entry, preserving all attributes including atime/ctime func extractTarEntry(tarReader *tar.Reader, header *tar.Header, destRoot string, logger *logging.Logger) error { - target, cleanDestRoot, err := sanitizeRestoreEntryTarget(destRoot, header.Name) + target, cleanDestRoot, err := sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, header.Name) if err != nil { return err } @@ -1647,26 +1641,8 @@ func extractRegularFile(tarReader *tar.Reader, target string, header *tar.Header func extractSymlink(target string, header *tar.Header, destRoot string, logger *logging.Logger) error { linkTarget := header.Linkname - // Pre-validation: ensure the resolved target would stay within destRoot before creating the symlink - relativeTarget, err := filepath.Rel(destRoot, target) - if err != nil { - return fmt.Errorf("determine relative path for symlink %s: %w", target, err) - } - if strings.HasPrefix(relativeTarget, ".."+string(os.PathSeparator)) || relativeTarget == ".." { - return fmt.Errorf("sanitized symlink path escapes root: %s", target) - } - - symlinkArchivePath := path.Clean(filepath.ToSlash(relativeTarget)) - symlinkArchiveDir := path.Dir(symlinkArchivePath) - if symlinkArchiveDir == "." { - symlinkArchiveDir = "" - } - potentialTarget := linkTarget - if !filepath.IsAbs(linkTarget) { - potentialTarget = filepath.FromSlash(path.Join(symlinkArchiveDir, linkTarget)) - } - - if _, err := resolveAndCheckPath(destRoot, potentialTarget); err != nil { + // Pre-validation: ensure the symlink target resolves within destRoot before creation. + if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), linkTarget); err != nil { return fmt.Errorf("symlink target escapes root before creation: %s -> %s: %w", header.Name, linkTarget, err) } @@ -1685,32 +1661,9 @@ func extractSymlink(target string, header *tar.Header, destRoot string, logger * return fmt.Errorf("read created symlink %s: %w", target, err) } - // Resolve the symlink target relative to the symlink's directory - symlinkDir := filepath.Dir(target) - resolvedTarget := actualTarget - if !filepath.IsAbs(actualTarget) { - resolvedTarget = filepath.Join(symlinkDir, actualTarget) - } - - // Validate the resolved target stays within destRoot using absolute paths - absDestRoot, err := filepath.Abs(destRoot) - if err != nil { - restoreFS.Remove(target) - return fmt.Errorf("resolve destination root: %w", err) - } - - absResolvedTarget, err := filepath.Abs(resolvedTarget) - if err != nil { - restoreFS.Remove(target) - return fmt.Errorf("resolve symlink target: %w", err) - } - - // Check if resolved target is within destRoot - rel, err := filepath.Rel(absDestRoot, absResolvedTarget) - if err != nil || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." { + if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), actualTarget); err != nil { restoreFS.Remove(target) - return fmt.Errorf("symlink target escapes root after creation: %s -> %s (resolves to %s)", - header.Name, linkTarget, absResolvedTarget) + return fmt.Errorf("symlink target escapes root after creation: %s -> %s: %w", header.Name, actualTarget, err) } // Set ownership (on the symlink itself, not the target) @@ -1733,7 +1686,7 @@ func extractHardlink(target string, header *tar.Header, destRoot string) error { } // Validate the hard link target stays within extraction root - if _, err := resolveAndCheckPath(destRoot, linkName); err != nil { + if _, err := resolvePathWithinRootFS(restoreFS, destRoot, linkName); err != nil { return fmt.Errorf("hardlink target escapes root: %s -> %s: %w", header.Name, linkName, err) } diff --git a/internal/orchestrator/restore_decision.go b/internal/orchestrator/restore_decision.go new file mode 100644 index 0000000..9630114 --- /dev/null +++ b/internal/orchestrator/restore_decision.go @@ -0,0 +1,305 @@ +package orchestrator + +import ( + "archive/tar" + "bufio" + "bytes" + "context" + "fmt" + "io" + "path" + "strings" + + "github.com/tis24dev/proxsave/internal/logging" +) + +// RestoreDecisionSource describes where trusted restore facts were derived from. +type RestoreDecisionSource string + +const ( + RestoreDecisionSourceUnknown RestoreDecisionSource = "unknown" + RestoreDecisionSourceInternalMetadata RestoreDecisionSource = "internal_metadata" + RestoreDecisionSourceCategories RestoreDecisionSource = "categories" + RestoreDecisionSourceAmbiguous RestoreDecisionSource = "ambiguous" +) + +// RestoreDecisionInfo contains the archive-derived facts used for restore decisions. +type RestoreDecisionInfo struct { + BackupType SystemType + ClusterPayload bool + BackupHostname string + Source RestoreDecisionSource +} + +type restoreDecisionMetadata struct { + BackupType SystemType + ClusterMode string + Hostname string +} + +type restoreArchiveInspection struct { + AvailableCategories []Category + Decision *RestoreDecisionInfo +} + +const ( + restoreDecisionMetadataPath = "var/lib/proxsave-info/backup_metadata.txt" + restoreDecisionMetadataMaxBytes = 8 * 1024 +) + +// AnalyzeRestoreArchive inspects the archive once and derives trusted restore facts +// from archive contents plus internal backup metadata when present. +func AnalyzeRestoreArchive(archivePath string, logger *logging.Logger) (categories []Category, decision *RestoreDecisionInfo, err error) { + if logger == nil { + logger = logging.GetDefaultLogger() + } + + done := logging.DebugStart(logger, "analyze restore archive", "archive=%s", archivePath) + defer func() { done(err) }() + logger.Info("Analyzing backup contents...") + + inspection, err := inspectRestoreArchiveContents(archivePath, logger) + if err != nil { + return nil, nil, err + } + + for _, cat := range inspection.AvailableCategories { + logger.Debug("Category available: %s (%s)", cat.ID, cat.Name) + } + logger.Info("Detected %d available categories", len(inspection.AvailableCategories)) + if inspection.Decision != nil { + logger.Debug( + "Restore decision facts: backup_type=%s cluster_payload=%v hostname=%q source=%s", + inspection.Decision.BackupType, + inspection.Decision.ClusterPayload, + inspection.Decision.BackupHostname, + inspection.Decision.Source, + ) + } + + return inspection.AvailableCategories, inspection.Decision, nil +} + +func inspectRestoreArchiveContents(archivePath string, logger *logging.Logger) (inspection *restoreArchiveInspection, err error) { + file, err := restoreFS.Open(archivePath) + if err != nil { + return nil, fmt.Errorf("open archive: %w", err) + } + defer file.Close() + + reader, err := createDecompressionReader(context.Background(), file, archivePath) + if err != nil { + return nil, err + } + if closer, ok := reader.(interface{ Close() error }); ok { + defer func() { + if closeErr := closer.Close(); closeErr != nil && err == nil { + inspection = nil + err = fmt.Errorf("inspect archive: %w", closeErr) + } + }() + } + + tarReader := tar.NewReader(reader) + archivePaths, metadata, metadataErr, collectErr := collectRestoreArchiveFacts(tarReader) + if collectErr != nil { + return nil, fmt.Errorf("inspect archive: %w", collectErr) + } + if metadataErr != nil { + logger.Warning("Could not parse internal backup metadata: %v", metadataErr) + } + + logger.Debug("Found %d entries in archive", len(archivePaths)) + availableCategories := AnalyzeArchivePaths(archivePaths, GetAllCategories()) + + decision := buildRestoreDecisionInfo(metadata, availableCategories, logger) + inspection = &restoreArchiveInspection{ + AvailableCategories: availableCategories, + Decision: decision, + } + return inspection, nil +} + +func collectRestoreArchiveFacts(tarReader *tar.Reader) ([]string, *restoreDecisionMetadata, error, error) { + var ( + archivePaths []string + metadata *restoreDecisionMetadata + metadataErr error + ) + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return nil, nil, nil, err + } + + archivePaths = append(archivePaths, header.Name) + if metadata != nil || header.FileInfo().IsDir() { + continue + } + if !isRestoreDecisionMetadataEntry(header.Name) { + continue + } + + data, readErr := readRestoreDecisionMetadata(tarReader, header) + if readErr != nil { + metadataErr = readErr + continue + } + parsed, parseErr := parseRestoreDecisionMetadata(data) + if parseErr != nil { + metadataErr = parseErr + continue + } + metadata = parsed + } + + return archivePaths, metadata, metadataErr, nil +} + +func readRestoreDecisionMetadata(tarReader *tar.Reader, header *tar.Header) ([]byte, error) { + if header == nil { + return nil, fmt.Errorf("restore metadata entry is missing a tar header") + } + if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeRegA { + return nil, fmt.Errorf("archive entry %s is not a regular file", header.Name) + } + + limited := io.LimitReader(tarReader, restoreDecisionMetadataMaxBytes+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > restoreDecisionMetadataMaxBytes { + size := header.Size + if size <= 0 { + size = int64(len(data)) + } + return nil, fmt.Errorf("archive entry %s too large (%d bytes)", header.Name, size) + } + return data, nil +} + +func isRestoreDecisionMetadataEntry(entryName string) bool { + return normalizeRestoreEntryPath(entryName) == restoreDecisionMetadataPath +} + +func normalizeRestoreEntryPath(entryName string) string { + clean := strings.TrimSpace(strings.ReplaceAll(entryName, "\\", "/")) + if clean == "" { + return "" + } + + clean = path.Clean(clean) + clean = strings.TrimPrefix(clean, "./") + clean = strings.TrimPrefix(clean, "/") + if clean == "." { + return "" + } + return clean +} + +func parseRestoreDecisionMetadata(data []byte) (*restoreDecisionMetadata, error) { + meta := &restoreDecisionMetadata{} + scanner := bufio.NewScanner(bytes.NewReader(data)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + switch key { + case "BACKUP_TYPE": + meta.BackupType = parseSystemTypeString(value) + case "PVE_CLUSTER_MODE", "CLUSTER_MODE": + meta.ClusterMode = value + case "HOSTNAME": + meta.Hostname = value + } + } + if err := scanner.Err(); err != nil { + return nil, err + } + return meta, nil +} + +func buildRestoreDecisionInfo(metadata *restoreDecisionMetadata, categories []Category, logger *logging.Logger) *RestoreDecisionInfo { + if logger == nil { + logger = logging.GetDefaultLogger() + } + + info := &RestoreDecisionInfo{ + BackupType: SystemTypeUnknown, + ClusterPayload: hasCategoryID(categories, "pve_cluster"), + Source: RestoreDecisionSourceUnknown, + } + + if metadata != nil { + info.BackupHostname = strings.TrimSpace(metadata.Hostname) + } + + categoryType, ambiguousType := detectBackupTypeFromCategories(categories) + switch { + case ambiguousType: + logger.Warning("Archive contains both PVE and PBS-specific payloads; treating backup type as unknown for compatibility checks") + info.Source = RestoreDecisionSourceAmbiguous + case metadata != nil && metadata.BackupType != SystemTypeUnknown && categoryType == SystemTypeUnknown: + info.BackupType = metadata.BackupType + info.Source = RestoreDecisionSourceInternalMetadata + case metadata != nil && metadata.BackupType != SystemTypeUnknown && categoryType != SystemTypeUnknown && metadata.BackupType != categoryType: + logger.Warning("Internal backup metadata and archive payload disagree on backup type; using archive-derived type %s", strings.ToUpper(string(categoryType))) + info.BackupType = categoryType + info.Source = RestoreDecisionSourceCategories + case categoryType != SystemTypeUnknown: + info.BackupType = categoryType + info.Source = RestoreDecisionSourceCategories + case metadata != nil && metadata.BackupType != SystemTypeUnknown: + info.BackupType = metadata.BackupType + info.Source = RestoreDecisionSourceInternalMetadata + } + + if metadata != nil { + metadataCluster := strings.EqualFold(strings.TrimSpace(metadata.ClusterMode), "cluster") + switch { + case metadataCluster && !info.ClusterPayload: + logger.Warning("Internal backup metadata reports cluster mode, but no pve_cluster payload was found; guarded cluster restore remains disabled") + case !metadataCluster && info.ClusterPayload: + logger.Warning("Cluster payload detected in archive despite metadata reporting non-cluster backup; guarded cluster restore remains enabled") + } + } + + return info +} + +func detectBackupTypeFromCategories(categories []Category) (SystemType, bool) { + var hasPVE, hasPBS bool + for _, cat := range categories { + switch cat.Type { + case CategoryTypePVE: + hasPVE = true + case CategoryTypePBS: + hasPBS = true + } + } + + switch { + case hasPVE && hasPBS: + return SystemTypeUnknown, true + case hasPVE: + return SystemTypePVE, false + case hasPBS: + return SystemTypePBS, false + default: + return SystemTypeUnknown, false + } +} diff --git a/internal/orchestrator/restore_decision_test.go b/internal/orchestrator/restore_decision_test.go new file mode 100644 index 0000000..d049f33 --- /dev/null +++ b/internal/orchestrator/restore_decision_test.go @@ -0,0 +1,281 @@ +package orchestrator + +import ( + "archive/tar" + "bytes" + "context" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/tis24dev/proxsave/internal/logging" +) + +type closeErrorReadCloser struct { + *bytes.Reader + closeErr error +} + +func (r *closeErrorReadCloser) Close() error { + return r.closeErr +} + +type streamCommandRunner struct { + stream io.ReadCloser + calls []string +} + +func (r *streamCommandRunner) Run(context.Context, string, ...string) ([]byte, error) { + return nil, nil +} + +func (r *streamCommandRunner) RunStream(_ context.Context, name string, _ io.Reader, args ...string) (io.ReadCloser, error) { + r.calls = append(r.calls, commandKey(name, args)) + return r.stream, nil +} + +func tarBytes(t *testing.T, files map[string]string) []byte { + t.Helper() + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + for name, content := range files { + data := []byte(content) + hdr := &tar.Header{ + Name: name, + Mode: 0o640, + Size: int64(len(data)), + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("WriteHeader(%s): %v", name, err) + } + if _, err := tw.Write(data); err != nil { + t.Fatalf("Write(%s): %v", name, err) + } + } + if err := tw.Close(); err != nil { + t.Fatalf("Close tar writer: %v", err) + } + return buf.Bytes() +} + +func TestAnalyzeRestoreArchive_UsesInternalMetadataWhenCategoriesAreCommonOnly(t *testing.T) { + origRestoreFS := restoreFS + t.Cleanup(func() { restoreFS = origRestoreFS }) + restoreFS = osFS{} + + archivePath := filepath.Join(t.TempDir(), "backup.tar") + if err := writeTarFile(archivePath, map[string]string{ + "etc/hosts": "127.0.0.1 localhost\n", + "var/lib/proxsave-info/backup_metadata.txt": "# ProxSave Metadata\nBACKUP_TYPE=pbs\nHOSTNAME=pbs-node\nPVE_CLUSTER_MODE=cluster\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + categories, decision, err := AnalyzeRestoreArchive(archivePath, logger) + if err != nil { + t.Fatalf("AnalyzeRestoreArchive() error: %v", err) + } + if backupType, ambiguous := detectBackupTypeFromCategories(categories); backupType != SystemTypeUnknown || ambiguous { + t.Fatalf("detectBackupTypeFromCategories() = (%s, %v); want (%s, false)", backupType, ambiguous, SystemTypeUnknown) + } + if decision == nil { + t.Fatalf("decision info is nil") + } + if decision.BackupType != SystemTypePBS { + t.Fatalf("BackupType=%s; want %s", decision.BackupType, SystemTypePBS) + } + if decision.Source != RestoreDecisionSourceInternalMetadata { + t.Fatalf("Source=%s; want %s", decision.Source, RestoreDecisionSourceInternalMetadata) + } + if decision.BackupHostname != "pbs-node" { + t.Fatalf("BackupHostname=%q; want %q", decision.BackupHostname, "pbs-node") + } + if decision.ClusterPayload { + t.Fatalf("ClusterPayload should stay false without pve_cluster payload") + } +} + +func TestAnalyzeRestoreArchive_ClusterPayloadUsesArchiveContents(t *testing.T) { + origRestoreFS := restoreFS + t.Cleanup(func() { restoreFS = origRestoreFS }) + restoreFS = osFS{} + + archivePath := filepath.Join(t.TempDir(), "backup.tar") + if err := writeTarFile(archivePath, map[string]string{ + "var/lib/pve-cluster/config.db": "db\n", + "var/lib/proxsave-info/backup_metadata.txt": "BACKUP_TYPE=pve\nPVE_CLUSTER_MODE=standalone\nHOSTNAME=node1\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + _, decision, err := AnalyzeRestoreArchive(archivePath, logger) + if err != nil { + t.Fatalf("AnalyzeRestoreArchive() error: %v", err) + } + if decision == nil { + t.Fatalf("decision info is nil") + } + if !decision.ClusterPayload { + t.Fatalf("ClusterPayload should be true when pve_cluster payload exists") + } + if decision.BackupType != SystemTypePVE { + t.Fatalf("BackupType=%s; want %s", decision.BackupType, SystemTypePVE) + } +} + +func TestCollectRestoreArchiveFacts_RejectsOversizedMetadata(t *testing.T) { + archivePath := filepath.Join(t.TempDir(), "backup.tar") + oversized := "BACKUP_TYPE=pbs\nHOSTNAME=pbs-node\n" + strings.Repeat("A", restoreDecisionMetadataMaxBytes) + if err := writeTarFile(archivePath, map[string]string{ + "var/lib/proxsave-info/backup_metadata.txt": oversized, + "var/lib/pve-cluster/config.db": "db\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + + file, err := os.Open(archivePath) + if err != nil { + t.Fatalf("os.Open: %v", err) + } + defer file.Close() + + archivePaths, metadata, metadataErr, err := collectRestoreArchiveFacts(tar.NewReader(file)) + if err != nil { + t.Fatalf("collectRestoreArchiveFacts() error: %v", err) + } + if metadata != nil { + t.Fatalf("metadata = %#v; want nil for oversized entry", metadata) + } + if metadataErr == nil { + t.Fatalf("metadataErr = nil; want oversize error") + } + if !strings.Contains(metadataErr.Error(), "too large") { + t.Fatalf("metadataErr = %v; want oversize error", metadataErr) + } + + foundMeta := false + foundCluster := false + for _, archivePath := range archivePaths { + if archivePath == restoreDecisionMetadataPath { + foundMeta = true + } + if archivePath == "var/lib/pve-cluster/config.db" { + foundCluster = true + } + } + if !foundMeta || !foundCluster { + t.Fatalf("archivePaths = %#v; want metadata and cluster entries present", archivePaths) + } +} + +func TestAnalyzeRestoreArchive_IgnoresOversizedInternalMetadata(t *testing.T) { + origRestoreFS := restoreFS + t.Cleanup(func() { restoreFS = origRestoreFS }) + restoreFS = osFS{} + + archivePath := filepath.Join(t.TempDir(), "backup.tar") + oversized := "BACKUP_TYPE=pbs\nHOSTNAME=pbs-node\n" + strings.Repeat("A", restoreDecisionMetadataMaxBytes) + if err := writeTarFile(archivePath, map[string]string{ + "etc/hosts": "127.0.0.1 localhost\n", + "var/lib/proxsave-info/backup_metadata.txt": oversized, + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + _, decision, err := AnalyzeRestoreArchive(archivePath, logger) + if err != nil { + t.Fatalf("AnalyzeRestoreArchive() error: %v", err) + } + if decision == nil { + t.Fatalf("decision info is nil") + } + if decision.BackupType != SystemTypeUnknown { + t.Fatalf("BackupType=%s; want %s when metadata is oversized", decision.BackupType, SystemTypeUnknown) + } + if decision.Source != RestoreDecisionSourceUnknown { + t.Fatalf("Source=%s; want %s when metadata is oversized", decision.Source, RestoreDecisionSourceUnknown) + } + if decision.BackupHostname != "" { + t.Fatalf("BackupHostname=%q; want empty string when metadata is oversized", decision.BackupHostname) + } +} + +func TestAnalyzeRestoreArchive_PropagatesDecompressionCloseError(t *testing.T) { + origRestoreFS := restoreFS + origRestoreCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origRestoreFS + restoreCmd = origRestoreCmd + }) + restoreFS = osFS{} + + closeErr := errors.New("decompressor wait failed") + restoreCmd = &streamCommandRunner{ + stream: &closeErrorReadCloser{ + Reader: bytes.NewReader(tarBytes(t, map[string]string{"etc/hosts": "127.0.0.1 localhost\n"})), + closeErr: closeErr, + }, + } + + archivePath := filepath.Join(t.TempDir(), "backup.tar.zst") + if err := os.WriteFile(archivePath, []byte("compressed payload"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + categories, decision, err := AnalyzeRestoreArchive(archivePath, logger) + if categories != nil { + t.Fatalf("categories = %#v; want nil on close error", categories) + } + if decision != nil { + t.Fatalf("decision = %#v; want nil on close error", decision) + } + if !errors.Is(err, closeErr) { + t.Fatalf("AnalyzeRestoreArchive() err = %v; want close error", err) + } +} + +func TestInspectRestoreArchiveContents_PrefersInspectErrorOverCloseError(t *testing.T) { + origRestoreFS := restoreFS + origRestoreCmd := restoreCmd + t.Cleanup(func() { + restoreFS = origRestoreFS + restoreCmd = origRestoreCmd + }) + restoreFS = osFS{} + + closeErr := errors.New("decompressor wait failed") + restoreCmd = &streamCommandRunner{ + stream: &closeErrorReadCloser{ + Reader: bytes.NewReader([]byte("not a tar archive")), + closeErr: closeErr, + }, + } + + archivePath := filepath.Join(t.TempDir(), "backup.tar.zst") + if err := os.WriteFile(archivePath, []byte("compressed payload"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + logger := logging.New(logging.GetDefaultLogger().GetLevel(), false) + inspection, err := inspectRestoreArchiveContents(archivePath, logger) + if inspection != nil { + t.Fatalf("inspection = %#v; want nil on inspect error", inspection) + } + if err == nil { + t.Fatal("inspectRestoreArchiveContents() err = nil; want inspect error") + } + if errors.Is(err, closeErr) { + t.Fatalf("inspectRestoreArchiveContents() err = %v; want inspect error to take precedence over close error", err) + } + if !strings.Contains(err.Error(), "inspect archive") { + t.Fatalf("inspectRestoreArchiveContents() err = %v; want inspect archive context", err) + } +} diff --git a/internal/orchestrator/restore_errors_test.go b/internal/orchestrator/restore_errors_test.go index cad33a8..0faf024 100644 --- a/internal/orchestrator/restore_errors_test.go +++ b/internal/orchestrator/restore_errors_test.go @@ -798,10 +798,11 @@ type ErrorInjectingFS struct { linkErr error } -func (f *ErrorInjectingFS) Stat(path string) (os.FileInfo, error) { return f.base.Stat(path) } -func (f *ErrorInjectingFS) ReadFile(path string) ([]byte, error) { return f.base.ReadFile(path) } -func (f *ErrorInjectingFS) Open(path string) (*os.File, error) { return f.base.Open(path) } -func (f *ErrorInjectingFS) Create(name string) (*os.File, error) { return f.base.Create(name) } +func (f *ErrorInjectingFS) Stat(path string) (os.FileInfo, error) { return f.base.Stat(path) } +func (f *ErrorInjectingFS) Lstat(path string) (os.FileInfo, error) { return f.base.Lstat(path) } +func (f *ErrorInjectingFS) ReadFile(path string) ([]byte, error) { return f.base.ReadFile(path) } +func (f *ErrorInjectingFS) Open(path string) (*os.File, error) { return f.base.Open(path) } +func (f *ErrorInjectingFS) Create(name string) (*os.File, error) { return f.base.Create(name) } func (f *ErrorInjectingFS) WriteFile(path string, data []byte, perm os.FileMode) error { return f.base.WriteFile(path, data, perm) } diff --git a/internal/orchestrator/restore_firewall.go b/internal/orchestrator/restore_firewall.go index eeb5daa..1c2f5fb 100644 --- a/internal/orchestrator/restore_firewall.go +++ b/internal/orchestrator/restore_firewall.go @@ -673,12 +673,19 @@ func syncDirExact(srcDir, destDir string) ([]string, error) { if err != nil { return fmt.Errorf("readlink %s: %w", src, err) } + validatedTarget, err := validateOverlaySymlinkTargetWithinRoot(destDir, dest, target) + if err != nil { + return fmt.Errorf("unsafe symlink target %s -> %s: %w", dest, target, err) + } if err := ensureDirExistsWithInheritedMeta(filepath.Dir(dest)); err != nil { return fmt.Errorf("ensure %s: %w", filepath.Dir(dest), err) } _ = restoreFS.Remove(dest) - if err := restoreFS.Symlink(target, dest); err != nil { - return fmt.Errorf("symlink %s -> %s: %w", dest, target, err) + if err := restoreFS.Symlink(validatedTarget, dest); err != nil { + return fmt.Errorf("symlink %s -> %s: %w", dest, validatedTarget, err) + } + if err := verifyCreatedSymlinkWithinRootFS(restoreFS, destDir, dest); err != nil { + return err } applied = append(applied, dest) continue diff --git a/internal/orchestrator/restore_firewall_additional_test.go b/internal/orchestrator/restore_firewall_additional_test.go index 525c269..97e007b 100644 --- a/internal/orchestrator/restore_firewall_additional_test.go +++ b/internal/orchestrator/restore_firewall_additional_test.go @@ -372,14 +372,14 @@ func TestSyncDirExact_CopiesSymlinks(t *testing.T) { origFS := restoreFS t.Cleanup(func() { restoreFS = origFS }) - fakeFS := NewFakeFS() + fakeFS := newPreservingSymlinkFS() t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) restoreFS = fakeFS - if err := fakeFS.AddFile("/stage/target", []byte("x")); err != nil { - t.Fatalf("add target: %v", err) + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) } - if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + if err := fakeFS.Symlink("/dest/target", "/stage/link"); err != nil { t.Fatalf("add symlink: %v", err) } @@ -392,8 +392,8 @@ func TestSyncDirExact_CopiesSymlinks(t *testing.T) { if err != nil { t.Fatalf("read dest symlink: %v", err) } - if strings.TrimSpace(destTarget) == "" { - t.Fatalf("expected non-empty symlink target") + if destTarget != "target" { + t.Fatalf("dest symlink target=%q want %q", destTarget, "target") } found := false for _, p := range applied { @@ -407,6 +407,70 @@ func TestSyncDirExact_CopiesSymlinks(t *testing.T) { } } +func TestSyncDirExact_RejectsEscapingSymlinkTargets(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + t.Fatalf("add symlink: %v", err) + } + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "unsafe symlink target") { + t.Fatalf("expected unsafe symlink target error, got %v", err) + } +} + +func TestSyncDirExact_RejectsRelativeSymlinkTargetsThatEscapeDestinationRoot(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + + if err := fakeFS.AddDir("/stage/sub"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + if err := fakeFS.Symlink("../../outside", "/stage/sub/link"); err != nil { + t.Fatalf("add symlink: %v", err) + } + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "unsafe symlink target") { + t.Fatalf("expected unsafe symlink target error, got %v", err) + } +} + +func TestSyncDirExact_CleansUpWhenCreatedSymlinkReadbackFails(t *testing.T) { + origFS := restoreFS + t.Cleanup(func() { restoreFS = origFS }) + + fakeFS := newPreservingSymlinkFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + + if err := fakeFS.AddDir("/stage"); err != nil { + t.Fatalf("add stage dir: %v", err) + } + if err := fakeFS.Symlink("/dest/target", "/stage/link"); err != nil { + t.Fatalf("add symlink: %v", err) + } + + restoreFS = readlinkFailFS{FS: fakeFS, failPath: "/dest/link", err: fmt.Errorf("boom")} + + if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "read created symlink /dest/link") { + t.Fatalf("expected post-create readlink error, got %v", err) + } + if _, err := fakeFS.Lstat("/dest/link"); !os.IsNotExist(err) { + t.Fatalf("expected created symlink cleanup, lstat err = %v", err) + } +} + func TestSelectStageHostFirewall_ErrorsOnReadDirFailure(t *testing.T) { origFS := restoreFS t.Cleanup(func() { restoreFS = origFS }) @@ -1975,13 +2039,13 @@ func TestSyncDirExact_AdditionalEdgeCases(t *testing.T) { }) t.Run("symlink parent ensure error bubbles", func(t *testing.T) { - fakeFS := NewFakeFS() + fakeFS := newPreservingSymlinkFS() t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) if err := fakeFS.AddDir("/stage"); err != nil { t.Fatalf("add stage dir: %v", err) } - if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + if err := fakeFS.Symlink("target", "/stage/link"); err != nil { t.Fatalf("add stage symlink: %v", err) } if err := fakeFS.AddDir("/dest"); err != nil { @@ -1995,13 +2059,13 @@ func TestSyncDirExact_AdditionalEdgeCases(t *testing.T) { }) t.Run("symlink creation error bubbles", func(t *testing.T) { - fakeFS := NewFakeFS() + fakeFS := newPreservingSymlinkFS() t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) if err := fakeFS.AddDir("/stage"); err != nil { t.Fatalf("add stage dir: %v", err) } - if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil { + if err := fakeFS.Symlink("target", "/stage/link"); err != nil { t.Fatalf("add stage symlink: %v", err) } diff --git a/internal/orchestrator/restore_plan.go b/internal/orchestrator/restore_plan.go index 6d54716..6eba614 100644 --- a/internal/orchestrator/restore_plan.go +++ b/internal/orchestrator/restore_plan.go @@ -1,11 +1,5 @@ package orchestrator -import ( - "strings" - - "github.com/tis24dev/proxsave/internal/backup" -) - // RestorePlan contains a pure, side-effect-free description of a restore run. type RestorePlan struct { Mode RestoreMode @@ -22,7 +16,7 @@ type RestorePlan struct { // PlanRestore computes the restore plan without performing any I/O or prompts. func PlanRestore( - manifest *backup.Manifest, + clusterBackup bool, selectedCategories []Category, systemType SystemType, mode RestoreMode, @@ -35,7 +29,7 @@ func PlanRestore( NormalCategories: normal, StagedCategories: staged, ExportCategories: export, - ClusterBackup: manifest != nil && strings.EqualFold(strings.TrimSpace(manifest.ClusterMode), "cluster"), + ClusterBackup: clusterBackup, } plan.NeedsClusterRestore = systemType == SystemTypePVE && hasCategoryID(normal, "pve_cluster") diff --git a/internal/orchestrator/restore_plan_test.go b/internal/orchestrator/restore_plan_test.go index c38b562..f27ba99 100644 --- a/internal/orchestrator/restore_plan_test.go +++ b/internal/orchestrator/restore_plan_test.go @@ -2,16 +2,13 @@ package orchestrator import ( "testing" - - "github.com/tis24dev/proxsave/internal/backup" ) func TestPlanRestoreClusterSafeToggle(t *testing.T) { clusterCat := Category{ID: "pve_cluster", Type: CategoryTypePVE} storageCat := Category{ID: "storage_pve", Type: CategoryTypePVE} - manifest := &backup.Manifest{ClusterMode: "cluster"} - plan := PlanRestore(manifest, []Category{clusterCat, storageCat}, SystemTypePVE, RestoreModeCustom) + plan := PlanRestore(true, []Category{clusterCat, storageCat}, SystemTypePVE, RestoreModeCustom) if !plan.NeedsClusterRestore { t.Fatalf("expected NeedsClusterRestore true") @@ -50,7 +47,7 @@ func TestPlanRestorePBSCategories(t *testing.T) { pbsCat := Category{ID: "pbs_config", Type: CategoryTypePBS, ExportOnly: true} normalCat := Category{ID: "network", Type: CategoryTypeCommon} - plan := PlanRestore(nil, []Category{pbsCat, normalCat}, SystemTypePBS, RestoreModeCustom) + plan := PlanRestore(false, []Category{pbsCat, normalCat}, SystemTypePBS, RestoreModeCustom) if len(plan.ExportCategories) != 1 || !hasCategoryID(plan.ExportCategories, "pbs_config") { t.Fatalf("expected pbs_config to be exported, got %+v", plan.ExportCategories) } @@ -66,7 +63,7 @@ func TestPlanRestoreKeepsExportCategoriesFromFullSelection(t *testing.T) { exportCat := Category{ID: "pve_config_export", ExportOnly: true} normalCat := Category{ID: "network"} - plan := PlanRestore(nil, []Category{normalCat, exportCat}, SystemTypePVE, RestoreModeFull) + plan := PlanRestore(false, []Category{normalCat, exportCat}, SystemTypePVE, RestoreModeFull) if len(plan.StagedCategories) != 1 || plan.StagedCategories[0].ID != "network" { t.Fatalf("expected staged categories to keep network, got %+v", plan.StagedCategories) } diff --git a/internal/orchestrator/restore_test.go b/internal/orchestrator/restore_test.go index 80a1f10..d65e449 100644 --- a/internal/orchestrator/restore_test.go +++ b/internal/orchestrator/restore_test.go @@ -334,6 +334,155 @@ func TestExtractSymlink_SecurityValidation(t *testing.T) { } } +func TestExtractTarEntry_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + orig := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = orig }) + + destRoot := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(destRoot, "escape-link")); err != nil { + t.Fatalf("create escape symlink: %v", err) + } + + header := &tar.Header{ + Name: "escape-link/missing", + Typeflag: tar.TypeDir, + Mode: 0o755, + } + + var tr *tar.Reader + err := extractTarEntry(tr, header, destRoot, logger) + if err == nil || !strings.Contains(err.Error(), "illegal path") { + t.Fatalf("expected illegal path error, got %v", err) + } + + if _, err := os.Stat(filepath.Join(outside, "missing")); !os.IsNotExist(err) { + t.Fatalf("outside path should not be created, got err=%v", err) + } +} + +func TestSanitizeRestoreEntryTargetWithFS_AllowsOperationalResolverErrorsWithinRoot(t *testing.T) { + fsys := NewFakeFS() + destRoot := filepath.Join(string(os.PathSeparator), "restore-root") + if err := fsys.AddDir(destRoot); err != nil { + t.Fatalf("add root dir: %v", err) + } + if err := fsys.AddDir(filepath.Join(destRoot, "subdir")); err != nil { + t.Fatalf("add subdir: %v", err) + } + fsys.StatErrors[filepath.Clean(filepath.Join(destRoot, "subdir", "file.txt"))] = os.ErrPermission + + target, cleanRoot, err := sanitizeRestoreEntryTargetWithFS(fsys, destRoot, filepath.Join("subdir", "file.txt")) + if err != nil { + t.Fatalf("sanitizeRestoreEntryTargetWithFS returned error: %v", err) + } + + wantTarget := filepath.Join(destRoot, "subdir", "file.txt") + if target != wantTarget { + t.Fatalf("target = %q, want %q", target, wantTarget) + } + if cleanRoot != filepath.Clean(destRoot) { + t.Fatalf("cleanRoot = %q, want %q", cleanRoot, filepath.Clean(destRoot)) + } +} + +func TestExtractSymlink_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + orig := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = orig }) + + destRoot := t.TempDir() + outside := t.TempDir() + if err := os.Symlink(outside, filepath.Join(destRoot, "escape-link")); err != nil { + t.Fatalf("create escape symlink: %v", err) + } + + header := &tar.Header{ + Name: "link", + Typeflag: tar.TypeSymlink, + Linkname: "escape-link/missing/file.txt", + } + target := filepath.Join(destRoot, header.Name) + + err := extractSymlink(target, header, destRoot, logger) + if err == nil || !strings.Contains(err.Error(), "escapes root") { + t.Fatalf("expected escapes root error, got %v", err) + } + + if _, err := os.Lstat(target); !os.IsNotExist(err) { + t.Fatalf("symlink should not be created, got err=%v", err) + } +} + +func TestExtractSymlink_RejectsEscapeWhenParentPathIsSymlink(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + orig := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = orig }) + + destRoot := t.TempDir() + if err := os.Symlink(".", filepath.Join(destRoot, "linkdir")); err != nil { + t.Fatalf("create parent symlink: %v", err) + } + + header := &tar.Header{ + Name: "linkdir/escape", + Typeflag: tar.TypeSymlink, + Linkname: "../outside", + } + target := filepath.Join(destRoot, header.Name) + + err := extractSymlink(target, header, destRoot, logger) + if err == nil || !strings.Contains(err.Error(), "escapes root") { + t.Fatalf("expected escapes root error, got %v", err) + } + + if _, err := os.Lstat(filepath.Join(destRoot, "escape")); !os.IsNotExist(err) { + t.Fatalf("escaping symlink should not be created, got err=%v", err) + } +} + +func TestExtractSymlink_AllowsSafeTargetWhenParentPathIsSymlink(t *testing.T) { + logger := logging.New(types.LogLevelDebug, false) + orig := restoreFS + restoreFS = osFS{} + t.Cleanup(func() { restoreFS = orig }) + + destRoot := t.TempDir() + if err := os.MkdirAll(filepath.Join(destRoot, "subdir"), 0o755); err != nil { + t.Fatalf("mkdir subdir: %v", err) + } + if err := os.Symlink("subdir", filepath.Join(destRoot, "linkdir")); err != nil { + t.Fatalf("create parent symlink: %v", err) + } + + header := &tar.Header{ + Name: "linkdir/ok", + Typeflag: tar.TypeSymlink, + Linkname: "file.txt", + } + target := filepath.Join(destRoot, header.Name) + + if err := extractSymlink(target, header, destRoot, logger); err != nil { + t.Fatalf("safe symlink should succeed: %v", err) + } + + linkTarget, err := os.Readlink(target) + if err != nil { + t.Fatalf("safe symlink should exist: %v", err) + } + if linkTarget != "file.txt" { + t.Fatalf("symlink target = %q, want %q", linkTarget, "file.txt") + } + + if _, err := os.Lstat(filepath.Join(destRoot, "subdir", "ok")); err != nil { + t.Fatalf("expected symlink at resolved parent path: %v", err) + } +} + func TestExtractTarEntry_DoesNotFollowExistingSymlinkTargetPath(t *testing.T) { logger := logging.New(types.LogLevelDebug, false) orig := restoreFS diff --git a/internal/orchestrator/restore_workflow_decision_test.go b/internal/orchestrator/restore_workflow_decision_test.go new file mode 100644 index 0000000..9f07c80 --- /dev/null +++ b/internal/orchestrator/restore_workflow_decision_test.go @@ -0,0 +1,315 @@ +package orchestrator + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + "time" + + "github.com/tis24dev/proxsave/internal/backup" + "github.com/tis24dev/proxsave/internal/config" + "github.com/tis24dev/proxsave/internal/logging" + "github.com/tis24dev/proxsave/internal/types" +) + +func stubPreparedRestoreBundle(archivePath string, manifest *backup.Manifest) func(context.Context, *config.Config, *logging.Logger, string, RestoreWorkflowUI) (*decryptCandidate, *preparedBundle, error) { + return func(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui RestoreWorkflowUI) (*decryptCandidate, *preparedBundle, error) { + return &decryptCandidate{ + DisplayBase: "test", + Manifest: manifest, + }, &preparedBundle{ + ArchivePath: archivePath, + Manifest: backup.Manifest{ArchivePath: archivePath}, + cleanup: func() {}, + }, nil + } +} + +func TestRunRestoreWorkflow_ClusterPromptUsesArchivePayloadNotManifest(t *testing.T) { + origRestoreFS := restoreFS + origRestoreCmd := restoreCmd + origRestoreSystem := restoreSystem + origCompatFS := compatFS + origPrepare := prepareRestoreBundleFunc + origSafetyFS := safetyFS + t.Cleanup(func() { + restoreFS = origRestoreFS + restoreCmd = origRestoreCmd + restoreSystem = origRestoreSystem + compatFS = origCompatFS + prepareRestoreBundleFunc = origPrepare + safetyFS = origSafetyFS + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + compatFS = fakeFS + safetyFS = fakeFS + restoreCmd = runOnlyRunner{} + restoreSystem = fakeSystemDetector{systemType: SystemTypePVE} + + if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil { + t.Fatalf("fakeFS.AddFile: %v", err) + } + + tmpTar := filepath.Join(t.TempDir(), "bundle.tar") + if err := writeTarFile(tmpTar, map[string]string{ + "etc/hosts": "127.0.0.1 localhost\n", + "var/lib/pve-cluster/config.db": "db\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + tarBytes, err := os.ReadFile(tmpTar) + if err != nil { + t.Fatalf("ReadFile tar: %v", err) + } + if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil { + t.Fatalf("fakeFS.WriteFile: %v", err) + } + + prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{ + CreatedAt: time.Unix(1700000000, 0), + ClusterMode: "standalone", + ProxmoxType: "pve", + ScriptVersion: "vtest", + }) + + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{BaseDir: "/base"} + ui := &fakeRestoreWorkflowUI{ + mode: RestoreModeCustom, + categories: []Category{ + mustCategoryByID(t, "pve_cluster"), + }, + confirmRestore: true, + clusterMode: ClusterRestoreSafe, + } + + if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil { + t.Fatalf("runRestoreWorkflowWithUI error: %v", err) + } + if ui.clusterRestoreModeCalls != 1 { + t.Fatalf("clusterRestoreModeCalls=%d; want 1", ui.clusterRestoreModeCalls) + } +} + +func TestRunRestoreWorkflow_CompatibilityUsesArchivePayloadNotManifest(t *testing.T) { + origRestoreFS := restoreFS + origRestoreCmd := restoreCmd + origRestoreSystem := restoreSystem + origCompatFS := compatFS + origPrepare := prepareRestoreBundleFunc + origSafetyFS := safetyFS + t.Cleanup(func() { + restoreFS = origRestoreFS + restoreCmd = origRestoreCmd + restoreSystem = origRestoreSystem + compatFS = origCompatFS + prepareRestoreBundleFunc = origPrepare + safetyFS = origSafetyFS + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + compatFS = fakeFS + safetyFS = fakeFS + restoreCmd = runOnlyRunner{} + restoreSystem = fakeSystemDetector{systemType: SystemTypePVE} + + if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil { + t.Fatalf("fakeFS.AddFile: %v", err) + } + + tmpTar := filepath.Join(t.TempDir(), "bundle.tar") + if err := writeTarFile(tmpTar, map[string]string{ + "etc/ssh/sshd_config": "Port 22\n", + "etc/pve/jobs.cfg": "jobs\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + tarBytes, err := os.ReadFile(tmpTar) + if err != nil { + t.Fatalf("ReadFile tar: %v", err) + } + if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil { + t.Fatalf("fakeFS.WriteFile: %v", err) + } + + prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{ + CreatedAt: time.Unix(1700000000, 0), + ClusterMode: "standalone", + ProxmoxType: "pbs", + ScriptVersion: "vtest", + }) + + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{BaseDir: "/base"} + ui := &fakeRestoreWorkflowUI{ + mode: RestoreModeCustom, + categories: []Category{ + mustCategoryByID(t, "ssh"), + }, + confirmRestore: true, + confirmCompatible: false, + } + + if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil { + t.Fatalf("runRestoreWorkflowWithUI error: %v", err) + } + if ui.confirmCompatibilityCalls != 0 { + t.Fatalf("confirmCompatibilityCalls=%d; want 0", ui.confirmCompatibilityCalls) + } +} + +func TestRunRestoreWorkflow_CompatibilityWarnsOnArchiveMismatchDespiteManifest(t *testing.T) { + origRestoreFS := restoreFS + origRestoreCmd := restoreCmd + origRestoreSystem := restoreSystem + origCompatFS := compatFS + origPrepare := prepareRestoreBundleFunc + origAnalyze := analyzeRestoreArchiveFunc + origSafetyFS := safetyFS + t.Cleanup(func() { + restoreFS = origRestoreFS + restoreCmd = origRestoreCmd + restoreSystem = origRestoreSystem + compatFS = origCompatFS + prepareRestoreBundleFunc = origPrepare + analyzeRestoreArchiveFunc = origAnalyze + safetyFS = origSafetyFS + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + compatFS = fakeFS + safetyFS = fakeFS + restoreCmd = runOnlyRunner{} + restoreSystem = fakeSystemDetector{systemType: SystemTypePVE} + + if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil { + t.Fatalf("fakeFS.AddFile: %v", err) + } + + tmpTar := filepath.Join(t.TempDir(), "bundle.tar") + if err := writeTarFile(tmpTar, map[string]string{ + "etc/ssh/sshd_config": "Port 22\n", + "etc/proxmox-backup/sync.cfg": "sync\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + tarBytes, err := os.ReadFile(tmpTar) + if err != nil { + t.Fatalf("ReadFile tar: %v", err) + } + if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil { + t.Fatalf("fakeFS.WriteFile: %v", err) + } + + prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{ + CreatedAt: time.Unix(1700000000, 0), + ClusterMode: "standalone", + ProxmoxType: "pve", + ScriptVersion: "vtest", + }) + + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{BaseDir: "/base"} + ui := &fakeRestoreWorkflowUI{ + mode: RestoreModeCustom, + categories: []Category{ + mustCategoryByID(t, "ssh"), + }, + confirmRestore: true, + confirmCompatible: true, + } + + if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil { + t.Fatalf("runRestoreWorkflowWithUI error: %v", err) + } + if ui.confirmCompatibilityCalls != 1 { + t.Fatalf("confirmCompatibilityCalls=%d; want 1", ui.confirmCompatibilityCalls) + } + if ui.lastCompatibilityWarning == nil { + t.Fatalf("expected compatibility warning to be passed to UI") + } +} + +func TestRunRestoreWorkflow_CompatibilityWarningStillRunsBeforeFullFallbackOnAnalysisError(t *testing.T) { + origRestoreFS := restoreFS + origRestoreCmd := restoreCmd + origRestoreSystem := restoreSystem + origCompatFS := compatFS + origPrepare := prepareRestoreBundleFunc + origAnalyze := analyzeRestoreArchiveFunc + origSafetyFS := safetyFS + t.Cleanup(func() { + restoreFS = origRestoreFS + restoreCmd = origRestoreCmd + restoreSystem = origRestoreSystem + compatFS = origCompatFS + prepareRestoreBundleFunc = origPrepare + analyzeRestoreArchiveFunc = origAnalyze + safetyFS = origSafetyFS + }) + + fakeFS := NewFakeFS() + t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) }) + restoreFS = fakeFS + compatFS = fakeFS + safetyFS = fakeFS + restoreCmd = runOnlyRunner{} + restoreSystem = fakeSystemDetector{systemType: SystemTypePVE} + + if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil { + t.Fatalf("fakeFS.AddFile: %v", err) + } + + tmpTar := filepath.Join(t.TempDir(), "bundle.tar") + if err := writeTarFile(tmpTar, map[string]string{ + "etc/hosts": "127.0.0.1 localhost\n", + }); err != nil { + t.Fatalf("writeTarFile: %v", err) + } + tarBytes, err := os.ReadFile(tmpTar) + if err != nil { + t.Fatalf("ReadFile tar: %v", err) + } + if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil { + t.Fatalf("fakeFS.WriteFile: %v", err) + } + + prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{ + CreatedAt: time.Unix(1700000000, 0), + ClusterMode: "standalone", + ProxmoxType: "pbs", + ScriptVersion: "vtest", + }) + analyzeRestoreArchiveFunc = func(archivePath string, logger *logging.Logger) ([]Category, *RestoreDecisionInfo, error) { + return nil, nil, errors.New("boom") + } + + logger := logging.New(types.LogLevelError, false) + cfg := &config.Config{BaseDir: "/base"} + ui := &fakeRestoreWorkflowUI{ + confirmRestore: true, + confirmCompatible: true, + } + + if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil { + t.Fatalf("runRestoreWorkflowWithUI error: %v", err) + } + if ui.confirmCompatibilityCalls != 1 { + t.Fatalf("confirmCompatibilityCalls=%d; want 1", ui.confirmCompatibilityCalls) + } + if ui.lastCompatibilityWarning == nil { + t.Fatalf("expected compatibility warning before fallback") + } + if _, err := fakeFS.ReadFile("/etc/hosts"); err != nil { + t.Fatalf("expected full restore fallback to extract /etc/hosts: %v", err) + } +} diff --git a/internal/orchestrator/restore_workflow_more_test.go b/internal/orchestrator/restore_workflow_more_test.go index 41fa569..e60eb56 100644 --- a/internal/orchestrator/restore_workflow_more_test.go +++ b/internal/orchestrator/restore_workflow_more_test.go @@ -321,7 +321,8 @@ func TestRunRestoreWorkflow_IncompatibilityAndSafetyBackupFailureCanContinue(t * tmpTar := filepath.Join(t.TempDir(), "bundle.tar") if err := writeTarFile(tmpTar, map[string]string{ - "etc/hosts": "127.0.0.1 localhost\n", + "etc/hosts": "127.0.0.1 localhost\n", + "etc/proxmox-backup/sync.cfg": "sync\n", }); err != nil { t.Fatalf("writeTarFile: %v", err) } diff --git a/internal/orchestrator/restore_workflow_ui.go b/internal/orchestrator/restore_workflow_ui.go index 9ce2057..48e5ab5 100644 --- a/internal/orchestrator/restore_workflow_ui.go +++ b/internal/orchestrator/restore_workflow_ui.go @@ -11,12 +11,26 @@ import ( "strings" "time" + "github.com/tis24dev/proxsave/internal/backup" "github.com/tis24dev/proxsave/internal/config" "github.com/tis24dev/proxsave/internal/input" "github.com/tis24dev/proxsave/internal/logging" ) var prepareRestoreBundleFunc = prepareRestoreBundleWithUI +var analyzeRestoreArchiveFunc = AnalyzeRestoreArchive + +func fallbackRestoreDecisionInfoFromManifest(manifest *backup.Manifest) *RestoreDecisionInfo { + info := &RestoreDecisionInfo{Source: RestoreDecisionSourceUnknown} + if manifest == nil { + return info + } + + info.BackupType = DetectBackupType(manifest) + info.ClusterPayload = strings.EqualFold(strings.TrimSpace(manifest.ClusterMode), "cluster") + info.BackupHostname = strings.TrimSpace(manifest.Hostname) + return info +} func prepareRestoreBundleWithUI(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui RestoreWorkflowUI) (*decryptCandidate, *preparedBundle, error) { candidate, err := selectBackupCandidateWithUI(ctx, ui, cfg, logger, false) @@ -76,7 +90,19 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l systemType := restoreSystem.DetectCurrentSystem() logger.Info("Detected system type: %s", GetSystemTypeString(systemType)) - if warn := ValidateCompatibility(candidate.Manifest); warn != nil { + availableCategories, decisionInfo, err := analyzeRestoreArchiveFunc(prepared.ArchivePath, logger) + fallbackToFullRestore := false + if err != nil { + logger.Warning("Could not analyze categories: %v", err) + availableCategories = nil + decisionInfo = fallbackRestoreDecisionInfoFromManifest(candidate.Manifest) + fallbackToFullRestore = true + } + if decisionInfo == nil { + decisionInfo = &RestoreDecisionInfo{} + } + + if warn := ValidateCompatibility(systemType, decisionInfo.BackupType); warn != nil { logger.Warning("Compatibility check: %v", warn) proceed, perr := ui.ConfirmCompatibility(ctx, warn) if perr != nil { @@ -86,11 +112,7 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l return ErrRestoreAborted } } - - logger.Info("Analyzing backup contents...") - availableCategories, err := AnalyzeBackupCategories(prepared.ArchivePath, logger) - if err != nil { - logger.Warning("Could not analyze categories: %v", err) + if fallbackToFullRestore { logger.Info("Falling back to full restore mode") return runFullRestoreWithUI(ctx, ui, candidate, prepared, destRoot, logger, cfg.DryRun) } @@ -127,7 +149,7 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l } } - plan := PlanRestore(candidate.Manifest, selectedCategories, systemType, mode) + plan := PlanRestore(decisionInfo.ClusterPayload, selectedCategories, systemType, mode) if plan.SystemType == SystemTypePBS && (plan.HasCategoryID("pbs_host") || @@ -145,9 +167,8 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l logger.Info("PBS restore behavior: %s", behavior.DisplayName()) } - clusterBackup := strings.EqualFold(strings.TrimSpace(candidate.Manifest.ClusterMode), "cluster") - if plan.NeedsClusterRestore && clusterBackup { - logger.Info("Backup marked as cluster node; enabling guarded restore options for pve_cluster") + if plan.NeedsClusterRestore && plan.ClusterBackup { + logger.Info("Cluster payload detected in backup; enabling guarded restore options for pve_cluster") choice, promptErr := ui.SelectClusterRestoreMode(ctx) if promptErr != nil { return promptErr @@ -168,8 +189,8 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l if plan.HasCategoryID("pve_access_control") || plan.HasCategoryID("pbs_access_control") { currentHost, hostErr := os.Hostname() - if hostErr == nil && strings.TrimSpace(candidate.Manifest.Hostname) != "" && strings.TrimSpace(currentHost) != "" { - backupHost := strings.TrimSpace(candidate.Manifest.Hostname) + if hostErr == nil && strings.TrimSpace(decisionInfo.BackupHostname) != "" && strings.TrimSpace(currentHost) != "" { + backupHost := strings.TrimSpace(decisionInfo.BackupHostname) if !strings.EqualFold(strings.TrimSpace(currentHost), backupHost) { logger.Warning("Access control/TFA: backup hostname=%s current hostname=%s; WebAuthn users may require re-enrollment if the UI origin (FQDN/port) changes", backupHost, currentHost) } diff --git a/internal/orchestrator/restore_workflow_ui_helpers_test.go b/internal/orchestrator/restore_workflow_ui_helpers_test.go index f2a03724..ca1a8fe 100644 --- a/internal/orchestrator/restore_workflow_ui_helpers_test.go +++ b/internal/orchestrator/restore_workflow_ui_helpers_test.go @@ -35,6 +35,10 @@ type fakeRestoreWorkflowUI struct { confirmActionErr error repairNICNamesErr error networkCommitErr error + + confirmCompatibilityCalls int + clusterRestoreModeCalls int + lastCompatibilityWarning error } func (f *fakeRestoreWorkflowUI) RunTask(ctx context.Context, title, initialMessage string, run func(ctx context.Context, report ProgressReporter) error) error { @@ -84,10 +88,13 @@ func (f *fakeRestoreWorkflowUI) ConfirmRestore(ctx context.Context) (bool, error } func (f *fakeRestoreWorkflowUI) ConfirmCompatibility(ctx context.Context, warning error) (bool, error) { + f.confirmCompatibilityCalls++ + f.lastCompatibilityWarning = warning return f.confirmCompatible, f.confirmCompatibleErr } func (f *fakeRestoreWorkflowUI) SelectClusterRestoreMode(ctx context.Context) (ClusterRestoreMode, error) { + f.clusterRestoreModeCalls++ return f.clusterMode, f.clusterModeErr } diff --git a/internal/orchestrator/selective.go b/internal/orchestrator/selective.go index 4b05ea2..5b0b8af 100644 --- a/internal/orchestrator/selective.go +++ b/internal/orchestrator/selective.go @@ -27,40 +27,8 @@ type SelectiveRestoreConfig struct { // AnalyzeBackupCategories detects which categories are available in the backup func AnalyzeBackupCategories(archivePath string, logger *logging.Logger) (categories []Category, err error) { - done := logging.DebugStart(logger, "analyze backup categories", "archive=%s", archivePath) - defer func() { done(err) }() - logger.Info("Analyzing backup categories...") - - // Open the archive and read all entry names - file, err := restoreFS.Open(archivePath) - if err != nil { - return nil, fmt.Errorf("open archive: %w", err) - } - defer file.Close() - - // Create appropriate reader based on compression - reader, err := createDecompressionReader(context.Background(), file, archivePath) - if err != nil { - return nil, err - } - defer func() { - if closer, ok := reader.(interface{ Close() error }); ok { - closer.Close() - } - }() - - tarReader := tar.NewReader(reader) - - archivePaths := collectArchivePaths(tarReader) - logger.Debug("Found %d entries in archive", len(archivePaths)) - - availableCategories := AnalyzeArchivePaths(archivePaths, GetAllCategories()) - for _, cat := range availableCategories { - logger.Debug("Category available: %s (%s)", cat.ID, cat.Name) - } - - logger.Info("Detected %d available categories", len(availableCategories)) - return availableCategories, nil + availableCategories, _, err := AnalyzeRestoreArchive(archivePath, logger) + return availableCategories, err } // AnalyzeArchivePaths determines available categories from the provided archive entries. diff --git a/internal/pbs/namespaces.go b/internal/pbs/namespaces.go index 345ac56..d168bef 100644 --- a/internal/pbs/namespaces.go +++ b/internal/pbs/namespaces.go @@ -46,6 +46,12 @@ func ListNamespaces(ctx context.Context, datastoreName, datastorePath string, io return namespaces, true, nil } +// DiscoverNamespacesFromFilesystem skips the PBS CLI and infers namespaces +// directly from the datastore filesystem layout. +func DiscoverNamespacesFromFilesystem(ctx context.Context, datastorePath string, ioTimeout time.Duration) ([]Namespace, error) { + return discoverNamespacesFromFilesystem(ctx, datastorePath, ioTimeout) +} + func listNamespacesViaCLI(ctx context.Context, datastore string) ([]Namespace, error) { if err := ctx.Err(); err != nil { return nil, err diff --git a/internal/pbs/namespaces_test.go b/internal/pbs/namespaces_test.go index ac358b9..f151cae 100644 --- a/internal/pbs/namespaces_test.go +++ b/internal/pbs/namespaces_test.go @@ -130,6 +130,21 @@ func TestDiscoverNamespacesFromFilesystem_Errors(t *testing.T) { } } +func TestDiscoverNamespacesFromFilesystemExportedHelper(t *testing.T) { + tmpDir := t.TempDir() + mustMkdirAll(t, filepath.Join(tmpDir, "prod", "vm")) + + namespaces, err := DiscoverNamespacesFromFilesystem(context.Background(), tmpDir, 0) + if err != nil { + t.Fatalf("DiscoverNamespacesFromFilesystem failed: %v", err) + } + + got := namespacesToMap(namespaces) + if _, ok := got["prod"]; !ok { + t.Fatalf("expected exported helper to discover namespace, got %+v", namespaces) + } +} + func TestListNamespaces_CLISuccess(t *testing.T) { setExecCommandStub(t, "cli-success") diff --git a/internal/safefs/safefs.go b/internal/safefs/safefs.go index 36b001a..1ca2bed 100644 --- a/internal/safefs/safefs.go +++ b/internal/safefs/safefs.go @@ -14,6 +14,7 @@ var ( osStat = os.Stat osReadDir = os.ReadDir syscallStatfs = syscall.Statfs + fsOpLimiter = newOperationLimiter(32) ) // ErrTimeout is a sentinel error used to classify filesystem operations that did not @@ -56,100 +57,118 @@ func effectiveTimeout(ctx context.Context, timeout time.Duration) time.Duration return timeout } -func Stat(ctx context.Context, path string, timeout time.Duration) (fs.FileInfo, error) { - if err := ctx.Err(); err != nil { - return nil, err +func normalizeContextErr(ctx context.Context, deadlineErr error) error { + err := ctx.Err() + if err == nil { + return nil } - timeout = effectiveTimeout(ctx, timeout) - if timeout <= 0 { - return osStat(path) + if errors.Is(err, context.DeadlineExceeded) { + return deadlineErr } + return err +} - type result struct { - info fs.FileInfo - err error - } - ch := make(chan result, 1) - go func() { - info, err := osStat(path) - ch <- result{info: info, err: err} - }() +// operationLimiter bounds the number of in-flight filesystem goroutines whose +// callers may already have returned due to timeout/cancellation. +type operationLimiter struct { + slots chan struct{} +} - timer := time.NewTimer(timeout) - defer timer.Stop() +func newOperationLimiter(capacity int) *operationLimiter { + if capacity < 1 { + capacity = 1 + } + return &operationLimiter{ + slots: make(chan struct{}, capacity), + } +} +func (l *operationLimiter) acquire(ctx context.Context, timer <-chan time.Time) error { select { - case r := <-ch: - return r.info, r.err + case l.slots <- struct{}{}: + return nil case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - return nil, &TimeoutError{Op: "stat", Path: path, Timeout: timeout} + return normalizeContextErr(ctx, ErrTimeout) + case <-timer: + return ErrTimeout } } -func ReadDir(ctx context.Context, path string, timeout time.Duration) ([]os.DirEntry, error) { - if err := ctx.Err(); err != nil { - return nil, err +func (l *operationLimiter) release() { + select { + case <-l.slots: + default: + } +} + +func (l *operationLimiter) inflight() int { + return len(l.slots) +} + +func runLimited[T any](ctx context.Context, timeout time.Duration, timeoutErr *TimeoutError, run func() (T, error)) (T, error) { + var zero T + if err := normalizeContextErr(ctx, timeoutErr); err != nil { + return zero, err } timeout = effectiveTimeout(ctx, timeout) if timeout <= 0 { - return osReadDir(path) - } - - type result struct { - entries []os.DirEntry - err error + if err := normalizeContextErr(ctx, timeoutErr); err != nil { + return zero, err + } + return run() } - ch := make(chan result, 1) - go func() { - entries, err := osReadDir(path) - ch <- result{entries: entries, err: err} - }() timer := time.NewTimer(timeout) defer timer.Stop() - select { - case r := <-ch: - return r.entries, r.err - case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - return nil, &TimeoutError{Op: "readdir", Path: path, Timeout: timeout} - } -} - -func Statfs(ctx context.Context, path string, timeout time.Duration) (syscall.Statfs_t, error) { - if err := ctx.Err(); err != nil { - return syscall.Statfs_t{}, err - } - timeout = effectiveTimeout(ctx, timeout) - if timeout <= 0 { - var stat syscall.Statfs_t - return stat, syscallStatfs(path, &stat) + limiter := fsOpLimiter + if err := limiter.acquire(ctx, timer.C); err != nil { + if errors.Is(err, ErrTimeout) { + return zero, timeoutErr + } + return zero, err } type result struct { - stat syscall.Statfs_t - err error + value T + err error } ch := make(chan result, 1) go func() { - var stat syscall.Statfs_t - err := syscallStatfs(path, &stat) - ch <- result{stat: stat, err: err} + defer limiter.release() + value, err := run() + ch <- result{value: value, err: err} }() - timer := time.NewTimer(timeout) - defer timer.Stop() - select { case r := <-ch: - return r.stat, r.err + return r.value, r.err case <-ctx.Done(): - return syscall.Statfs_t{}, ctx.Err() + return zero, normalizeContextErr(ctx, timeoutErr) case <-timer.C: - return syscall.Statfs_t{}, &TimeoutError{Op: "statfs", Path: path, Timeout: timeout} + return zero, timeoutErr } } + +func Stat(ctx context.Context, path string, timeout time.Duration) (fs.FileInfo, error) { + stat := osStat + return runLimited(ctx, timeout, &TimeoutError{Op: "stat", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() (fs.FileInfo, error) { + return stat(path) + }) +} + +func ReadDir(ctx context.Context, path string, timeout time.Duration) ([]os.DirEntry, error) { + readDir := osReadDir + return runLimited(ctx, timeout, &TimeoutError{Op: "readdir", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() ([]os.DirEntry, error) { + return readDir(path) + }) +} + +func Statfs(ctx context.Context, path string, timeout time.Duration) (syscall.Statfs_t, error) { + statfs := syscallStatfs + return runLimited(ctx, timeout, &TimeoutError{Op: "statfs", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() (syscall.Statfs_t, error) { + var stat syscall.Statfs_t + err := statfs(path, &stat) + return stat, err + }) +} diff --git a/internal/safefs/safefs_test.go b/internal/safefs/safefs_test.go index 30646ae..27d87b8 100644 --- a/internal/safefs/safefs_test.go +++ b/internal/safefs/safefs_test.go @@ -4,17 +4,95 @@ import ( "context" "errors" "os" + "sync/atomic" "syscall" "testing" "time" ) +type stagedDeadlineContext struct { + deadline time.Time + done <-chan struct{} + errCalls int +} + +func (c *stagedDeadlineContext) Deadline() (time.Time, bool) { + return c.deadline, true +} + +func (c *stagedDeadlineContext) Done() <-chan struct{} { + return c.done +} + +func (c *stagedDeadlineContext) Err() error { + c.errCalls++ + if c.errCalls == 1 { + return nil + } + return context.DeadlineExceeded +} + +func (c *stagedDeadlineContext) Value(any) any { + return nil +} + +type fixedErrContext struct { + done <-chan struct{} + err error +} + +func (c *fixedErrContext) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +func (c *fixedErrContext) Done() <-chan struct{} { + return c.done +} + +func (c *fixedErrContext) Err() error { + select { + case <-c.done: + return c.err + default: + return nil + } +} + +func (c *fixedErrContext) Value(any) any { + return nil +} + +func waitForSignal(t *testing.T, ch <-chan struct{}, name string) { + t.Helper() + select { + case <-ch: + case <-time.After(500 * time.Millisecond): + t.Fatalf("timeout waiting for %s", name) + } +} + +func registerBlockedOpCleanup(t *testing.T, name string, unblock chan struct{}, finished <-chan struct{}, restore func()) { + t.Helper() + + t.Cleanup(restore) + t.Cleanup(func() { + close(unblock) + waitForSignal(t, finished, name) + }) +} + func TestStat_ReturnsTimeoutError(t *testing.T) { prev := osStat - defer func() { osStat = prev }() + unblock := make(chan struct{}) + finished := make(chan struct{}) + registerBlockedOpCleanup(t, "stat completion", unblock, finished, func() { + osStat = prev + }) osStat = func(string) (os.FileInfo, error) { - select {} + <-unblock + close(finished) + return nil, os.ErrNotExist } start := time.Now() @@ -29,10 +107,16 @@ func TestStat_ReturnsTimeoutError(t *testing.T) { func TestReadDir_ReturnsTimeoutError(t *testing.T) { prev := osReadDir - defer func() { osReadDir = prev }() + unblock := make(chan struct{}) + finished := make(chan struct{}) + registerBlockedOpCleanup(t, "readdir completion", unblock, finished, func() { + osReadDir = prev + }) osReadDir = func(string) ([]os.DirEntry, error) { - select {} + <-unblock + close(finished) + return nil, nil } start := time.Now() @@ -47,10 +131,16 @@ func TestReadDir_ReturnsTimeoutError(t *testing.T) { func TestStatfs_ReturnsTimeoutError(t *testing.T) { prev := syscallStatfs - defer func() { syscallStatfs = prev }() + unblock := make(chan struct{}) + finished := make(chan struct{}) + registerBlockedOpCleanup(t, "statfs completion", unblock, finished, func() { + syscallStatfs = prev + }) syscallStatfs = func(string, *syscall.Statfs_t) error { - select {} + <-unblock + close(finished) + return nil } start := time.Now() @@ -72,3 +162,165 @@ func TestStat_PropagatesContextCancellation(t *testing.T) { t.Fatalf("Stat err = %v; want context.Canceled", err) } } + +func TestRunLimited_ReturnsTimeoutErrorWhenDeadlineExpiresBeforeNoTimeoutPath(t *testing.T) { + done := make(chan struct{}) + close(done) + ctx := &stagedDeadlineContext{ + deadline: time.Now().Add(-time.Millisecond), + done: done, + } + + called := false + _, err := runLimited(ctx, 50*time.Millisecond, &TimeoutError{Op: "stat", Path: "/does/not/matter"}, func() (int, error) { + called = true + return 1, nil + }) + + if called { + t.Fatal("run called; want timeout before execution") + } + if err == nil || !errors.Is(err, ErrTimeout) { + t.Fatalf("runLimited err = %v; want timeout", err) + } +} + +func TestRunLimited_NormalizesExpiredDeadlineAtEntry(t *testing.T) { + done := make(chan struct{}) + close(done) + ctx := &fixedErrContext{ + done: done, + err: context.DeadlineExceeded, + } + + called := false + _, err := runLimited(ctx, 50*time.Millisecond, &TimeoutError{Op: "stat", Path: "/does/not/matter"}, func() (int, error) { + called = true + return 1, nil + }) + + if called { + t.Fatal("run called; want timeout before execution") + } + if err == nil || !errors.Is(err, ErrTimeout) { + t.Fatalf("runLimited err = %v; want timeout", err) + } +} + +func TestRunLimited_NormalizesDeadlineFromDoneBranch(t *testing.T) { + done := make(chan struct{}) + ctx := &fixedErrContext{ + done: done, + err: context.DeadlineExceeded, + } + + unblock := make(chan struct{}) + finished := make(chan struct{}) + t.Cleanup(func() { + close(unblock) + waitForSignal(t, finished, "runLimited completion") + }) + + go func() { + time.Sleep(10 * time.Millisecond) + close(done) + }() + + _, err := runLimited(ctx, time.Second, &TimeoutError{Op: "stat", Path: "/does/not/matter"}, func() (int, error) { + defer close(finished) + <-unblock + return 1, nil + }) + + if err == nil || !errors.Is(err, ErrTimeout) { + t.Fatalf("runLimited err = %v; want timeout", err) + } +} + +func TestOperationLimiterAcquire_NormalizesDeadlineExceeded(t *testing.T) { + limiter := newOperationLimiter(1) + limiter.slots <- struct{}{} + + done := make(chan struct{}) + ctx := &fixedErrContext{ + done: done, + err: context.DeadlineExceeded, + } + + timer := time.NewTimer(time.Second) + defer timer.Stop() + + go func() { + time.Sleep(10 * time.Millisecond) + close(done) + }() + + err := limiter.acquire(ctx, timer.C) + if err == nil || !errors.Is(err, ErrTimeout) { + t.Fatalf("acquire err = %v; want timeout", err) + } +} + +func TestOperationLimiterAcquire_PropagatesCancellation(t *testing.T) { + limiter := newOperationLimiter(1) + limiter.slots <- struct{}{} + + done := make(chan struct{}) + ctx := &fixedErrContext{ + done: done, + err: context.Canceled, + } + + timer := time.NewTimer(time.Second) + defer timer.Stop() + + go func() { + time.Sleep(10 * time.Millisecond) + close(done) + }() + + err := limiter.acquire(ctx, timer.C) + if !errors.Is(err, context.Canceled) { + t.Fatalf("acquire err = %v; want context.Canceled", err) + } +} + +func TestStat_DoesNotSpawnPastLimiterCapacity(t *testing.T) { + prevStat := osStat + prevLimiter := fsOpLimiter + unblock := make(chan struct{}) + finished := make(chan struct{}) + registerBlockedOpCleanup(t, "limited stat completion", unblock, finished, func() { + osStat = prevStat + fsOpLimiter = prevLimiter + }) + + fsOpLimiter = newOperationLimiter(1) + + var calls atomic.Int32 + osStat = func(string) (os.FileInfo, error) { + calls.Add(1) + <-unblock + close(finished) + return nil, os.ErrNotExist + } + + _, err := Stat(context.Background(), "/first", 25*time.Millisecond) + if err == nil || !errors.Is(err, ErrTimeout) { + t.Fatalf("first Stat err = %v; want timeout", err) + } + if got := calls.Load(); got != 1 { + t.Fatalf("calls after first timeout = %d; want 1", got) + } + if got := fsOpLimiter.inflight(); got != 1 { + t.Fatalf("inflight after first timeout = %d; want 1", got) + } + + _, err = Stat(context.Background(), "/second", 25*time.Millisecond) + if err == nil || !errors.Is(err, ErrTimeout) { + t.Fatalf("second Stat err = %v; want timeout", err) + } + if got := calls.Load(); got != 1 { + t.Fatalf("calls after limiter saturation = %d; want 1", got) + } +}