diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 73f72f2..d4c1577 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -70,7 +70,7 @@ jobs:
# GORELEASER
########################################
- name: Run GoReleaser
- uses: goreleaser/goreleaser-action@v6
+ uses: goreleaser/goreleaser-action@v7
with:
version: latest
workdir: ${{ github.workspace }}
@@ -82,6 +82,6 @@ jobs:
# BUILD PROVENANCE ATTESTATION
########################################
- name: Attest Build Provenance
- uses: actions/attest-build-provenance@v3
+ uses: actions/attest-build-provenance@v4
with:
subject-path: build/proxsave_*
diff --git a/Makefile b/Makefile
index e3041d4..6d17a9d 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,8 @@
.PHONY: build test clean run build-release test-coverage lint fmt deps help coverage coverage-check
COVERAGE_THRESHOLD ?= 50.0
+TOOLCHAIN_FROM_MOD := $(shell awk '/^toolchain /{print $$2}' go.mod 2>/dev/null)
+COVER_GOTOOLCHAIN := $(if $(TOOLCHAIN_FROM_MOD),$(TOOLCHAIN_FROM_MOD)+auto,auto)
# Project build
build:
@@ -60,20 +62,20 @@ test:
# Test con coverage
test-coverage:
- go test -coverprofile=coverage.out ./...
- go tool cover -html=coverage.out
+ GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go test -coverprofile=coverage.out ./...
+ GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go tool cover -html=coverage.out
# Full coverage report (all packages)
coverage:
@echo "Running coverage across all packages..."
- @go test -coverpkg=./... -coverprofile=coverage.out ./...
- @go tool cover -func=coverage.out | tail -n 1
+ @GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go test -coverpkg=./... -coverprofile=coverage.out ./...
+ @GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go tool cover -func=coverage.out | tail -n 1
# Enforce minimum coverage threshold
coverage-check:
@echo "Running coverage check (threshold $(COVERAGE_THRESHOLD)% )..."
- @go test -coverpkg=./... -coverprofile=coverage.out ./...
- @total=$$(go tool cover -func=coverage.out | grep total: | awk '{print $$3}' | sed 's/%//'); \
+ @GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go test -coverpkg=./... -coverprofile=coverage.out ./...
+ @total=$$(GOTOOLCHAIN=$(COVER_GOTOOLCHAIN) go tool cover -func=coverage.out | grep total: | awk '{print $$3}' | sed 's/%//'); \
echo "Total coverage: $$total%"; \
if awk -v total="$$total" -v threshold="$(COVERAGE_THRESHOLD)" 'BEGIN { exit !(total+0 >= threshold+0) }'; then \
echo "Coverage check passed."; \
diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md
index a3ef784..e0f3d82 100644
--- a/docs/CONFIGURATION.md
+++ b/docs/CONFIGURATION.md
@@ -1053,6 +1053,8 @@ VZDUMP_CONFIG_PATH=/etc/vzdump.conf
# PBS datastore paths (comma/space separated)
PBS_DATASTORE_PATH= # e.g., "/mnt/pbs1,/mnt/pbs2"
+# Extra filesystem scan roots for datastore/PXAR discovery; these do not create
+# real PBS datastore definitions and may use path-derived output keys.
# System root override (testing/chroot)
SYSTEM_ROOT_PREFIX= # Optional alternate root for system collection. Empty or "/" = real root.
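
The "comma/space separated" format documented above could be parsed with a splitter along these lines (a sketch of the assumed semantics, not the actual loader code):

```go
package main

import (
	"fmt"
	"strings"
)

// splitDatastorePaths splits a PBS_DATASTORE_PATH value on commas and
// whitespace, dropping empty entries.
func splitDatastorePaths(raw string) []string {
	return strings.FieldsFunc(raw, func(r rune) bool {
		return r == ',' || r == ' ' || r == '\t'
	})
}

func main() {
	fmt.Println(splitDatastorePaths("/mnt/pbs1,/mnt/pbs2 /srv/pbs3"))
	// [/mnt/pbs1 /mnt/pbs2 /srv/pbs3]
}
```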
diff --git a/docs/RESTORE_GUIDE.md b/docs/RESTORE_GUIDE.md
index 4c04c4c..53312fb 100644
--- a/docs/RESTORE_GUIDE.md
+++ b/docs/RESTORE_GUIDE.md
@@ -117,7 +117,7 @@ API apply is automatic for supported PBS staged categories; ProxSave may fall ba
|----------|------|-------------|-------|
| `pbs_config` | PBS Config Export | **Export-only** copy of /etc/proxmox-backup (never written to system) | `./etc/proxmox-backup/` |
| `pbs_host` | PBS Host & Integrations | **Staged** node settings, ACME, proxy, metric servers and traffic control (API/file apply) | `./etc/proxmox-backup/node.cfg`<br>`./etc/proxmox-backup/proxy.cfg`<br>`./etc/proxmox-backup/acme/accounts.cfg`<br>`./etc/proxmox-backup/acme/plugins.cfg`<br>`./etc/proxmox-backup/metricserver.cfg`<br>`./etc/proxmox-backup/traffic-control.cfg`<br>`./var/lib/proxsave-info/commands/pbs/node_config.json`<br>`./var/lib/proxsave-info/commands/pbs/acme_accounts.json`<br>`./var/lib/proxsave-info/commands/pbs/acme_plugins.json`<br>`./var/lib/proxsave-info/commands/pbs/acme_account_*_info.json`<br>`./var/lib/proxsave-info/commands/pbs/acme_plugin_*_config.json`<br>`./var/lib/proxsave-info/commands/pbs/traffic_control.json` |
-| `datastore_pbs` | PBS Datastore Configuration | **Staged** datastore definitions (incl. S3 endpoints) (API/file apply) | `./etc/proxmox-backup/datastore.cfg`<br>`./etc/proxmox-backup/s3.cfg`<br>`./var/lib/proxsave-info/commands/pbs/datastore_list.json`<br>`./var/lib/proxsave-info/commands/pbs/datastore_*_status.json`<br>`./var/lib/proxsave-info/commands/pbs/s3_endpoints.json`<br>`./var/lib/proxsave-info/commands/pbs/s3_endpoint_*_buckets.json`<br>`./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json` |
+| `datastore_pbs` | PBS Datastore Configuration | **Staged** datastore definitions (incl. S3 endpoints) (API/file apply) | `./etc/proxmox-backup/datastore.cfg`<br>`./etc/proxmox-backup/s3.cfg`<br>`./var/lib/proxsave-info/commands/pbs/datastore_list.json`<br>`./var/lib/proxsave-info/commands/pbs/datastore_*_status.json`<br>`./var/lib/proxsave-info/commands/pbs/s3_endpoints.json`<br>`./var/lib/proxsave-info/commands/pbs/s3_endpoint_*_buckets.json`<br>`./var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json`<br>Note: `PBS_DATASTORE_PATH` override scan roots are inventory context only and are not recreated as datastore definitions during restore. |
| `maintenance_pbs` | PBS Maintenance | Maintenance settings | `./etc/proxmox-backup/maintenance.cfg` |
| `pbs_jobs` | PBS Jobs | **Staged** sync/verify/prune jobs (API/file apply) | `./etc/proxmox-backup/sync.cfg`<br>`./etc/proxmox-backup/verification.cfg`<br>`./etc/proxmox-backup/prune.cfg`<br>`./var/lib/proxsave-info/commands/pbs/sync_jobs.json`<br>`./var/lib/proxsave-info/commands/pbs/verification_jobs.json`<br>`./var/lib/proxsave-info/commands/pbs/prune_jobs.json`<br>`./var/lib/proxsave-info/commands/pbs/gc_jobs.json` |
| `pbs_remotes` | PBS Remotes | **Staged** remotes for sync/verify (may include credentials) (API/file apply) | `./etc/proxmox-backup/remote.cfg`<br>`./var/lib/proxsave-info/commands/pbs/remote_list.json` |
@@ -2384,6 +2384,7 @@ systemctl restart proxmox-backup proxmox-backup-proxy
**Restore behavior**:
- ProxSave detects this condition during staged apply.
- If `var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json` is available in the backup, ProxSave will use its embedded snapshot of the original `datastore.cfg` to recover a valid configuration.
+- Inventory entries that came only from `PBS_DATASTORE_PATH` scan roots are treated as diagnostic context and are excluded from regenerated `datastore.cfg`.
- If recovery is not possible, ProxSave will **leave the existing** `/etc/proxmox-backup/datastore.cfg` unchanged to avoid breaking PBS.
**Manual diagnosis**:
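
A rough sketch of the exclusion rule described above, filtering override-origin inventory entries out of a regenerated `datastore.cfg` (types simplified and hypothetical; the real apply logic lives in ProxSave's restore path):

```go
package main

import "fmt"

// InventoryEntry mirrors the relevant fields of
// pbs_datastore_inventory.json entries (simplified).
type InventoryEntry struct {
	Name   string `json:"name"`
	Path   string `json:"path"`
	Origin string `json:"origin"` // "cli", "config", "override", or "merged"
}

// regenerable keeps only entries that represent real PBS datastore
// definitions; override scan roots are diagnostic context only.
func regenerable(entries []InventoryEntry) []InventoryEntry {
	var out []InventoryEntry
	for _, e := range entries {
		if e.Origin == "override" {
			continue
		}
		out = append(out, e)
	}
	return out
}

func main() {
	entries := []InventoryEntry{
		{Name: "backup", Path: "/real/backup", Origin: "merged"},
		{Name: "backup", Path: "/mnt/a/backup", Origin: "override"},
	}
	fmt.Printf("%+v\n", regenerable(entries)) // only the merged entry survives
}
```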
diff --git a/go.mod b/go.mod
index 4a36bb8..f99518e 100644
--- a/go.mod
+++ b/go.mod
@@ -2,7 +2,7 @@ module github.com/tis24dev/proxsave
go 1.25
-toolchain go1.25.7
+toolchain go1.25.8
require (
filippo.io/age v1.3.1
diff --git a/internal/backup/checksum.go b/internal/backup/checksum.go
index fc917ad..0e08b39 100644
--- a/internal/backup/checksum.go
+++ b/internal/backup/checksum.go
@@ -35,6 +35,30 @@ type Manifest struct {
ClusterMode string `json:"cluster_mode,omitempty"`
}
+// NormalizeChecksum validates and normalizes a SHA256 checksum string.
+func NormalizeChecksum(value string) (string, error) {
+ checksum := strings.ToLower(strings.TrimSpace(value))
+ if checksum == "" {
+ return "", fmt.Errorf("checksum is empty")
+ }
+ if len(checksum) != sha256.Size*2 {
+ return "", fmt.Errorf("checksum must be %d hex characters, got %d", sha256.Size*2, len(checksum))
+ }
+ if _, err := hex.DecodeString(checksum); err != nil {
+ return "", fmt.Errorf("checksum is not valid hex: %w", err)
+ }
+ return checksum, nil
+}
+
+// ParseChecksumData extracts a SHA256 checksum from checksum file contents.
+func ParseChecksumData(data []byte) (string, error) {
+ fields := strings.Fields(string(data))
+ if len(fields) == 0 {
+ return "", fmt.Errorf("checksum file is empty")
+ }
+ return NormalizeChecksum(fields[0])
+}
+
// GenerateChecksum calculates SHA256 checksum of a file
func GenerateChecksum(ctx context.Context, logger *logging.Logger, filePath string) (string, error) {
logger.Debug("Generating SHA256 checksum for: %s", filePath)
@@ -105,16 +129,21 @@ func CreateManifest(ctx context.Context, logger *logging.Logger, manifest *Manif
func VerifyChecksum(ctx context.Context, logger *logging.Logger, filePath, expectedChecksum string) (bool, error) {
logger.Debug("Verifying checksum for: %s", filePath)
+ normalizedExpected, err := NormalizeChecksum(expectedChecksum)
+ if err != nil {
+ return false, fmt.Errorf("invalid expected checksum: %w", err)
+ }
+
actualChecksum, err := GenerateChecksum(ctx, logger, filePath)
if err != nil {
return false, fmt.Errorf("failed to generate checksum: %w", err)
}
- matches := actualChecksum == expectedChecksum
+ matches := actualChecksum == normalizedExpected
if matches {
logger.Debug("Checksum verification passed")
} else {
- logger.Warning("Checksum mismatch! Expected: %s, Got: %s", expectedChecksum, actualChecksum)
+ logger.Warning("Checksum mismatch! Expected: %s, Got: %s", normalizedExpected, actualChecksum)
}
return matches, nil
@@ -205,9 +234,8 @@ func parseLegacyMetadata(scanner *bufio.Scanner, legacy *Manifest) {
func loadLegacyChecksum(archivePath string, legacy *Manifest) {
// Attempt to load checksum from legacy .sha256 file
if shaData, err := os.ReadFile(archivePath + ".sha256"); err == nil {
- fields := strings.Fields(string(shaData))
- if len(fields) > 0 {
- legacy.SHA256 = fields[0]
+ if checksum, parseErr := ParseChecksumData(shaData); parseErr == nil {
+ legacy.SHA256 = checksum
}
}
}
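
For reference, the validation rules introduced by `NormalizeChecksum` can be exercised standalone; the sketch below restates the same function outside the package to show which inputs pass (and why the tests were moved off the 8-character `deadbeef` fixture):

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

// normalizeChecksum restates NormalizeChecksum from the diff:
// trim, lower-case, require exactly 64 valid hex characters.
func normalizeChecksum(value string) (string, error) {
	checksum := strings.ToLower(strings.TrimSpace(value))
	if checksum == "" {
		return "", fmt.Errorf("checksum is empty")
	}
	if len(checksum) != sha256.Size*2 {
		return "", fmt.Errorf("checksum must be %d hex characters, got %d", sha256.Size*2, len(checksum))
	}
	if _, err := hex.DecodeString(checksum); err != nil {
		return "", fmt.Errorf("checksum is not valid hex: %w", err)
	}
	return checksum, nil
}

func main() {
	if got, err := normalizeChecksum(strings.Repeat("A", 64)); err == nil {
		fmt.Println("upper-case input normalized:", got[:12]+"...")
	}
	if _, err := normalizeChecksum("deadbeef"); err != nil {
		fmt.Println("short value rejected:", err)
	}
}
```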
diff --git a/internal/backup/checksum_legacy_test.go b/internal/backup/checksum_legacy_test.go
index 1013ae8..c6738cf 100644
--- a/internal/backup/checksum_legacy_test.go
+++ b/internal/backup/checksum_legacy_test.go
@@ -28,7 +28,8 @@ func TestLoadLegacyManifestWithShaAndFallbackEncryption(t *testing.T) {
t.Fatalf("write metadata: %v", err)
}
- shaLine := "deadbeef " + filepath.Base(archive) + "\n"
+ expectedSHA := strings.Repeat("a", 64)
+ shaLine := expectedSHA + " " + filepath.Base(archive) + "\n"
if err := os.WriteFile(archive+".sha256", []byte(shaLine), 0o640); err != nil {
t.Fatalf("write sha256: %v", err)
}
@@ -50,8 +51,8 @@ func TestLoadLegacyManifestWithShaAndFallbackEncryption(t *testing.T) {
if m.EncryptionMode != "plain" {
t.Fatalf("expected fallback encryption mode plain, got %s", m.EncryptionMode)
}
- if m.SHA256 != "deadbeef" {
- t.Fatalf("expected sha256 deadbeef, got %s", m.SHA256)
+ if m.SHA256 != expectedSHA {
+ t.Fatalf("expected sha256 %s, got %s", expectedSHA, m.SHA256)
}
if time.Since(m.CreatedAt) > time.Minute {
t.Fatalf("unexpected CreatedAt too old: %v", m.CreatedAt)
diff --git a/internal/backup/checksum_test.go b/internal/backup/checksum_test.go
index bc12bd0..676506f 100644
--- a/internal/backup/checksum_test.go
+++ b/internal/backup/checksum_test.go
@@ -145,8 +145,9 @@ ENCRYPTION_MODE=age
if err := os.WriteFile(metadataPath, []byte(metadata), 0644); err != nil {
t.Fatalf("failed to write metadata: %v", err)
}
+ expectedSHA := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
shaPath := archivePath + ".sha256"
- if err := os.WriteFile(shaPath, []byte("deadbeef "+filepath.Base(archivePath)), 0644); err != nil {
+ if err := os.WriteFile(shaPath, []byte(expectedSHA+" "+filepath.Base(archivePath)), 0644); err != nil {
t.Fatalf("failed to write sha file: %v", err)
}
@@ -163,7 +164,7 @@ ENCRYPTION_MODE=age
if manifest.Hostname != "legacy-host" || manifest.ScriptVersion != "legacy-1.0" {
t.Fatalf("legacy metadata not parsed correctly: %+v", manifest)
}
- if manifest.SHA256 != "deadbeef" {
+ if manifest.SHA256 != expectedSHA {
t.Fatalf("expected SHA256 from sidecar, got %q", manifest.SHA256)
}
if manifest.EncryptionMode != "age" {
diff --git a/internal/backup/collector.go b/internal/backup/collector.go
index 466fed2..e3de49f 100644
--- a/internal/backup/collector.go
+++ b/internal/backup/collector.go
@@ -3,6 +3,8 @@ package backup
import (
"bytes"
"context"
+ "crypto/sha256"
+ "encoding/hex"
"errors"
"fmt"
"io"
@@ -1180,6 +1182,21 @@ func sanitizeFilename(name string) string {
return clean
}
+func collectorPathKey(name string) string {
+ trimmed := strings.TrimSpace(name)
+ if trimmed == "" {
+ return "entry"
+ }
+
+ safe := sanitizeFilename(trimmed)
+ if safe == trimmed {
+ return safe
+ }
+
+ sum := sha256.Sum256([]byte(trimmed))
+ return fmt.Sprintf("%s_%s", safe, hex.EncodeToString(sum[:4]))
+}
+
// GetStats returns current collection statistics
func (c *Collector) GetStats() *CollectionStats {
c.statsMu.Lock()
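
`collectorPathKey` keeps names that survive `sanitizeFilename` unchanged and otherwise appends a 4-byte SHA-256 digest of the original, so sanitized forms cannot collide. A standalone sketch of the scheme, with a trivial stand-in for the package's `sanitizeFilename`:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
)

// sanitize is a stand-in for the package's sanitizeFilename; replacing path
// separators and dots is enough here to illustrate the scheme.
func sanitize(name string) string {
	return strings.NewReplacer("/", "_", ".", "_").Replace(name)
}

// pathKey mirrors collectorPathKey: unchanged names pass through, rewritten
// names gain a short digest of the original so "a/b" and "a_b" cannot collide.
func pathKey(name string) string {
	trimmed := strings.TrimSpace(name)
	if trimmed == "" {
		return "entry"
	}
	safe := sanitize(trimmed)
	if safe == trimmed {
		return safe
	}
	sum := sha256.Sum256([]byte(trimmed))
	return fmt.Sprintf("%s_%s", safe, hex.EncodeToString(sum[:4]))
}

func main() {
	fmt.Println(pathKey("store1")) // store1
	fmt.Println(pathKey("a/b"))    // a_b_<digest of "a/b">
	fmt.Println(pathKey("a_b"))    // a_b — distinct from the sanitized form above
}
```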
diff --git a/internal/backup/collector_helpers_extra_test.go b/internal/backup/collector_helpers_extra_test.go
index ae58f2a..a0aa38b 100644
--- a/internal/backup/collector_helpers_extra_test.go
+++ b/internal/backup/collector_helpers_extra_test.go
@@ -1,6 +1,9 @@
package backup
-import "testing"
+import (
+ "strings"
+ "testing"
+)
func TestSummarizeCommandOutputText(t *testing.T) {
if got := summarizeCommandOutputText(""); got != "(no stdout/stderr)" {
@@ -38,3 +41,25 @@ func TestSanitizeFilenameExtra(t *testing.T) {
}
}
}
+
+func TestCollectorPathKey(t *testing.T) {
+ if got := collectorPathKey("store1"); got != "store1" {
+ t.Fatalf("collectorPathKey(store1)=%q want %q", got, "store1")
+ }
+
+ unsafe := "../evil"
+ got := collectorPathKey(unsafe)
+ if got == unsafe {
+ t.Fatalf("collectorPathKey(%q) should not keep unsafe value", unsafe)
+ }
+ if got == sanitizeFilename(unsafe) {
+ t.Fatalf("collectorPathKey(%q) should add a disambiguating suffix", unsafe)
+ }
+ if !strings.HasPrefix(got, "__evil") {
+ t.Fatalf("collectorPathKey(%q)=%q should start with sanitized prefix", unsafe, got)
+ }
+
+ if a, b := collectorPathKey("a/b"), collectorPathKey("a_b"); a == b {
+ t.Fatalf("collectorPathKey should avoid collisions: %q == %q", a, b)
+ }
+}
diff --git a/internal/backup/collector_pbs.go b/internal/backup/collector_pbs.go
index 98212e9..8b23b7d 100644
--- a/internal/backup/collector_pbs.go
+++ b/internal/backup/collector_pbs.go
@@ -347,6 +347,11 @@ func (c *Collector) collectPBSDirectories(ctx context.Context, root string) erro
// collectPBSCommands collects output from PBS commands
func (c *Collector) collectPBSCommands(ctx context.Context, datastores []pbsDatastore) error {
+ if len(datastores) > 0 {
+ datastores = clonePBSDatastores(datastores)
+ assignUniquePBSDatastoreOutputKeys(datastores)
+ }
+
commandsDir := c.proxsaveCommandsDir("pbs")
if err := c.ensureDir(commandsDir); err != nil {
return fmt.Errorf("failed to create commands directory: %w", err)
@@ -383,9 +388,19 @@ func (c *Collector) collectPBSCommands(ctx context.Context, datastores []pbsData
// Datastore usage details
if c.config.BackupDatastoreConfigs && len(datastores) > 0 {
for _, ds := range datastores {
+ if ds.isOverride() {
+ c.logger.Debug("Skipping datastore status for %s (path=%s): no PBS datastore identity", ds.Name, ds.Path)
+ continue
+ }
+ cliName := ds.cliName()
+ if cliName == "" {
+ c.logger.Debug("Skipping datastore status for %s (path=%s): empty PBS datastore identity", ds.Name, ds.Path)
+ continue
+ }
+ dsKey := ds.pathKey()
c.safeCmdOutput(ctx,
- fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", ds.Name),
- filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", ds.Name)),
+ fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", cliName),
+ filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", dsKey)),
fmt.Sprintf("Datastore %s status", ds.Name),
false)
}
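
The `clonePBSDatastores` call before `assignUniquePBSDatastoreOutputKeys` is a defensive copy: key assignment writes through the slice elements, and without the copy it would also mutate the caller's backing array. A minimal illustration of that aliasing hazard:

```go
package main

import "fmt"

type item struct{ Key string }

func assignKeys(items []item) {
	for i := range items {
		items[i].Key = "assigned" // writes through the shared backing array
	}
}

func main() {
	caller := []item{{Key: "original"}}

	shared := caller // copies the slice header, not the backing array
	assignKeys(shared)
	fmt.Println(caller[0].Key) // "assigned" — the caller was mutated

	caller[0].Key = "original"
	cloned := make([]item, len(caller))
	copy(cloned, caller)
	assignKeys(cloned)
	fmt.Println(caller[0].Key) // "original" — the clone protected the caller
}
```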
diff --git a/internal/backup/collector_pbs_commands_coverage_test.go b/internal/backup/collector_pbs_commands_coverage_test.go
index b96c44f..65c37a1 100644
--- a/internal/backup/collector_pbs_commands_coverage_test.go
+++ b/internal/backup/collector_pbs_commands_coverage_test.go
@@ -92,6 +92,42 @@ func TestCollectPBSCommandsWritesExpectedOutputs(t *testing.T) {
}
}
+func TestCollectPBSCommandsSkipsStatusForOverrideOnlyEntries(t *testing.T) {
+ pbsRoot := t.TempDir()
+ if err := os.WriteFile(filepath.Join(pbsRoot, "tape.cfg"), []byte("ok"), 0o640); err != nil {
+ t.Fatalf("write tape.cfg: %v", err)
+ }
+
+ cfg := GetDefaultCollectorConfig()
+ cfg.PBSConfigPath = pbsRoot
+
+ collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{
+ LookPath: func(name string) (string, error) {
+ return "/bin/" + name, nil
+ },
+ RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) {
+ return []byte(fmt.Sprintf("%s %s", name, strings.Join(args, " "))), nil
+ },
+ })
+
+ override := pbsDatastore{
+ Name: "backup",
+ Path: "/mnt/a/backup",
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath("/mnt/a/backup"),
+ OutputKey: buildPBSOverrideOutputKey("/mnt/a/backup"),
+ }
+ if err := collector.collectPBSCommands(context.Background(), []pbsDatastore{override}); err != nil {
+ t.Fatalf("collectPBSCommands error: %v", err)
+ }
+
+ commandsDir := filepath.Join(collector.tempDir, "var/lib/proxsave-info", "commands", "pbs")
+ statusPath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", override.pathKey()))
+ if _, err := os.Stat(statusPath); !os.IsNotExist(err) {
+ t.Fatalf("override datastore status file should not exist (%s), got err=%v", statusPath, err)
+ }
+}
+
func TestCollectPBSCommandsReturnsErrorWhenCriticalVersionFails(t *testing.T) {
cfg := GetDefaultCollectorConfig()
cfg.PBSConfigPath = t.TempDir()
@@ -189,6 +225,50 @@ func TestCollectPBSPxarMetadataProcessesMultipleDatastores(t *testing.T) {
}
}
+func TestCollectPBSPxarMetadataSeparatesOverrideBasenameCollisions(t *testing.T) {
+ tmp := t.TempDir()
+ cfg := GetDefaultCollectorConfig()
+ cfg.PxarDatastoreConcurrency = 2
+
+ collector := NewCollector(newTestLogger(), cfg, tmp, types.ProxmoxBS, false)
+
+ makeOverride := func(path string) pbsDatastore {
+ for _, sub := range []string{"vm", "ct"} {
+ if err := os.MkdirAll(filepath.Join(path, sub), 0o755); err != nil {
+ t.Fatalf("mkdir %s: %v", sub, err)
+ }
+ }
+ if err := os.WriteFile(filepath.Join(path, "vm", "backup.pxar"), []byte("data"), 0o640); err != nil {
+ t.Fatalf("write pxar: %v", err)
+ }
+ return pbsDatastore{
+ Name: "backup",
+ Path: path,
+ Comment: "configured via PBS_DATASTORE_PATH",
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath(path),
+ OutputKey: buildPBSOverrideOutputKey(path),
+ }
+ }
+
+ ds1 := makeOverride(filepath.Join(tmp, "mnt", "a", "backup"))
+ ds2 := makeOverride(filepath.Join(tmp, "srv", "b", "backup"))
+ if ds1.pathKey() == ds2.pathKey() {
+ t.Fatalf("expected distinct path keys, got %q", ds1.pathKey())
+ }
+
+ if err := collector.collectPBSPxarMetadata(context.Background(), []pbsDatastore{ds1, ds2}); err != nil {
+ t.Fatalf("collectPBSPxarMetadata error: %v", err)
+ }
+
+ for _, ds := range []pbsDatastore{ds1, ds2} {
+ base := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "metadata", ds.pathKey())
+ if _, err := os.Stat(filepath.Join(base, "metadata.json")); err != nil {
+ t.Fatalf("expected metadata for %s: %v", ds.pathKey(), err)
+ }
+ }
+}
+
func TestCollectPBSPxarMetadataReturnsErrorWhenTempVarIsFile(t *testing.T) {
tmp := t.TempDir()
if err := os.WriteFile(filepath.Join(tmp, "var"), []byte("not-a-dir"), 0o640); err != nil {
@@ -234,6 +314,42 @@ func TestCollectDatastoreConfigsCreatesConfigAndNamespaceFiles(t *testing.T) {
}
}
+func TestCollectDatastoreConfigsCreatesDistinctNamespaceFilesForOverrideCollisions(t *testing.T) {
+ cfg := GetDefaultCollectorConfig()
+ tmp := t.TempDir()
+ collector := NewCollectorWithDeps(newTestLogger(), cfg, tmp, types.ProxmoxBS, false, CollectorDeps{})
+
+ makeOverride := func(path string) pbsDatastore {
+ if err := os.MkdirAll(filepath.Join(path, "local", "vm"), 0o755); err != nil {
+ t.Fatalf("mkdir override namespace fixture: %v", err)
+ }
+ return pbsDatastore{
+ Name: "backup",
+ Path: path,
+ Comment: "configured via PBS_DATASTORE_PATH",
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath(path),
+ OutputKey: buildPBSOverrideOutputKey(path),
+ }
+ }
+
+ ds1 := makeOverride(filepath.Join(tmp, "mnt", "a", "backup"))
+ ds2 := makeOverride(filepath.Join(tmp, "srv", "b", "backup"))
+ if err := collector.collectDatastoreConfigs(context.Background(), []pbsDatastore{ds1, ds2}); err != nil {
+ t.Fatalf("collectDatastoreConfigs error: %v", err)
+ }
+
+ datastoreDir := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "datastores")
+ for _, ds := range []pbsDatastore{ds1, ds2} {
+ if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey()))); err != nil {
+ t.Fatalf("expected namespaces file for %s: %v", ds.pathKey(), err)
+ }
+ if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", ds.pathKey()))); !os.IsNotExist(err) {
+ t.Fatalf("override config file should not exist for %s, got err=%v", ds.pathKey(), err)
+ }
+ }
+}
+
func TestCollectUserTokensSkipsInvalidUserListJSON(t *testing.T) {
tmp := t.TempDir()
collector := NewCollector(newTestLogger(), GetDefaultCollectorConfig(), tmp, types.ProxmoxBS, false)
diff --git a/internal/backup/collector_pbs_datastore.go b/internal/backup/collector_pbs_datastore.go
index 0b5c2ba..fa75ed5 100644
--- a/internal/backup/collector_pbs_datastore.go
+++ b/internal/backup/collector_pbs_datastore.go
@@ -2,12 +2,15 @@ package backup
import (
"context"
+ "crypto/sha256"
+ "encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
+ "sort"
"strings"
"sync"
"time"
@@ -16,13 +19,284 @@ import (
"github.com/tis24dev/proxsave/internal/safefs"
)
+const (
+ pbsDatastoreSourceCLI = "cli"
+ pbsDatastoreSourceOverride = "override"
+ pbsDatastoreSourceConfig = "config"
+ pbsDatastoreOriginMerged = "merged"
+)
+
+var (
+ pbsDatastoreNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
+ listNamespacesFunc = pbs.ListNamespaces
+ discoverNamespacesFunc = pbs.DiscoverNamespacesFromFilesystem
+)
+
type pbsDatastore struct {
- Name string
- Path string
- Comment string
+ Name string
+ Path string
+ Comment string
+ Source string
+ CLIName string
+ NormalizedPath string
+ OutputKey string
+}
+
+func normalizePBSDatastorePath(path string) string {
+ trimmed := strings.TrimSpace(path)
+ if trimmed == "" {
+ return ""
+ }
+ return filepath.Clean(trimmed)
+}
+
+func buildPBSOverrideDisplayName(path string, idx int) string {
+ name := filepath.Base(normalizePBSDatastorePath(path))
+ if name == "" || name == "." || name == string(os.PathSeparator) || !pbsDatastoreNamePattern.MatchString(name) {
+ return fmt.Sprintf("datastore_%d", idx+1)
+ }
+ return name
+}
+
+func buildPBSOverrideOutputKey(path string) string {
+ normalized := normalizePBSDatastorePath(path)
+ if normalized == "" {
+ return "entry"
+ }
+
+ label := filepath.Base(normalized)
+ if label == "" || label == "." || label == string(os.PathSeparator) || !pbsDatastoreNamePattern.MatchString(label) {
+ label = "datastore"
+ }
+
+ sum := sha256.Sum256([]byte(normalized))
+ return fmt.Sprintf("%s_%s", sanitizeFilename(label), hex.EncodeToString(sum[:4]))
+}
+
+func pbsOutputKeyDigest(seed string) string {
+ sum := sha256.Sum256([]byte(seed))
+ return hex.EncodeToString(sum[:4])
+}
+
+func pbsDatastoreIdentityKey(ds pbsDatastore) string {
+ if ds.isOverride() {
+ if normalized := ds.normalizedPath(); normalized != "" {
+ return "override-path:" + normalized
+ }
+ if path := strings.TrimSpace(ds.Path); path != "" {
+ return "override-path:" + path
+ }
+ return ""
+ }
+
+ name := strings.TrimSpace(ds.Name)
+ if name == "" {
+ return ""
+ }
+ return "name:" + name
+}
+
+func pbsDatastoreDefinitionIdentityKey(def pbsDatastoreDefinition) string {
+ if strings.TrimSpace(def.Origin) == pbsDatastoreSourceOverride {
+ if normalized := normalizePBSDatastorePath(def.Path); normalized != "" {
+ return "override-path:" + normalized
+ }
+ if path := strings.TrimSpace(def.Path); path != "" {
+ return "override-path:" + path
+ }
+ return ""
+ }
+
+ name := strings.TrimSpace(def.Name)
+ if name == "" {
+ name = strings.TrimSpace(def.CLIName)
+ }
+ if name == "" {
+ return ""
+ }
+ return "name:" + name
+}
+
+func pbsDatastoreCandidateOutputKey(ds pbsDatastore) string {
+ if ds.isOverride() {
+ if normalized := ds.normalizedPath(); normalized != "" {
+ return buildPBSOverrideOutputKey(normalized)
+ }
+ return buildPBSOverrideOutputKey(ds.Path)
+ }
+ return collectorPathKey(ds.Name)
+}
+
+func pbsDatastoreDefinitionCandidateOutputKey(def pbsDatastoreDefinition) string {
+ if strings.TrimSpace(def.Origin) == pbsDatastoreSourceOverride {
+ if normalized := normalizePBSDatastorePath(def.Path); normalized != "" {
+ return buildPBSOverrideOutputKey(normalized)
+ }
+ return buildPBSOverrideOutputKey(def.Path)
+ }
+
+ name := strings.TrimSpace(def.Name)
+ if name == "" {
+ name = strings.TrimSpace(def.CLIName)
+ }
+ return collectorPathKey(name)
+}
+
+func pbsOutputKeyPriority(origin string) int {
+ if strings.TrimSpace(origin) == pbsDatastoreSourceOverride {
+ return 1
+ }
+ return 0
+}
+
+type pbsOutputKeyAssignment struct {
+ Index int
+ Identity string
+ BaseKey string
+ Priority int
+}
+
+func assignUniquePBSOutputKeys[T any](items []T, identityFn func(T) string, baseKeyFn func(T) string, priorityFn func(T) int, assignFn func(*T, string)) {
+ if len(items) == 0 {
+ return
+ }
+
+ grouped := make(map[string][]pbsOutputKeyAssignment, len(items))
+ baseKeys := make([]string, 0, len(items))
+ for idx, item := range items {
+ baseKey := strings.TrimSpace(baseKeyFn(item))
+ if baseKey == "" {
+ baseKey = "entry"
+ }
+
+ identity := strings.TrimSpace(identityFn(item))
+ if identity == "" {
+ identity = fmt.Sprintf("anonymous:%s:%d", baseKey, idx)
+ }
+
+ if _, ok := grouped[baseKey]; !ok {
+ baseKeys = append(baseKeys, baseKey)
+ }
+ grouped[baseKey] = append(grouped[baseKey], pbsOutputKeyAssignment{
+ Index: idx,
+ Identity: identity,
+ BaseKey: baseKey,
+ Priority: priorityFn(item),
+ })
+ }
+
+ sort.Strings(baseKeys)
+
+ usedKeys := make(map[string]string, len(items))
+ identityKeys := make(map[string]string, len(items))
+
+ for _, baseKey := range baseKeys {
+ assignments := grouped[baseKey]
+ sort.SliceStable(assignments, func(i, j int) bool {
+ if assignments[i].Priority != assignments[j].Priority {
+ return assignments[i].Priority < assignments[j].Priority
+ }
+ if assignments[i].Identity != assignments[j].Identity {
+ return assignments[i].Identity < assignments[j].Identity
+ }
+ return assignments[i].Index < assignments[j].Index
+ })
+
+ for pos, assignment := range assignments {
+ if existing := strings.TrimSpace(identityKeys[assignment.Identity]); existing != "" {
+ assignFn(&items[assignment.Index], existing)
+ continue
+ }
+
+ preferBase := pos == 0
+ for attempt := 0; ; attempt++ {
+ candidate := assignment.BaseKey
+ if !preferBase || attempt > 0 {
+ seed := assignment.Identity
+ if attempt > 0 {
+ seed = fmt.Sprintf("%s#%d", assignment.Identity, attempt)
+ }
+ candidate = fmt.Sprintf("%s_%s", assignment.BaseKey, pbsOutputKeyDigest(seed))
+ }
+
+ if owner, ok := usedKeys[candidate]; ok && owner != assignment.Identity {
+ continue
+ }
+
+ usedKeys[candidate] = assignment.Identity
+ identityKeys[assignment.Identity] = candidate
+ assignFn(&items[assignment.Index], candidate)
+ break
+ }
+ }
+ }
+}
+
+func assignUniquePBSDatastoreOutputKeys(datastores []pbsDatastore) {
+ assignUniquePBSOutputKeys(datastores,
+ pbsDatastoreIdentityKey,
+ pbsDatastoreCandidateOutputKey,
+ func(ds pbsDatastore) int {
+ return pbsOutputKeyPriority(ds.Source)
+ },
+ func(ds *pbsDatastore, key string) {
+ ds.OutputKey = key
+ })
+}
+
+func assignUniquePBSDatastoreDefinitionOutputKeys(defs []pbsDatastoreDefinition) {
+ assignUniquePBSOutputKeys(defs,
+ pbsDatastoreDefinitionIdentityKey,
+ pbsDatastoreDefinitionCandidateOutputKey,
+ func(def pbsDatastoreDefinition) int {
+ return pbsOutputKeyPriority(def.Origin)
+ },
+ func(def *pbsDatastoreDefinition, key string) {
+ def.OutputKey = key
+ })
+}
+
+func clonePBSDatastores(in []pbsDatastore) []pbsDatastore {
+ if len(in) == 0 {
+ return nil
+ }
+
+ out := make([]pbsDatastore, len(in))
+ copy(out, in)
+ return out
+}
+
+func (ds pbsDatastore) normalizedPath() string {
+ if path := strings.TrimSpace(ds.NormalizedPath); path != "" {
+ return path
+ }
+ return normalizePBSDatastorePath(ds.Path)
+}
+
+func (ds pbsDatastore) pathKey() string {
+ if key := strings.TrimSpace(ds.OutputKey); key != "" {
+ return key
+ }
+ return pbsDatastoreCandidateOutputKey(ds)
}
-var listNamespacesFunc = pbs.ListNamespaces
+func (ds pbsDatastore) cliName() string {
+ if name := strings.TrimSpace(ds.CLIName); name != "" {
+ return name
+ }
+ return strings.TrimSpace(ds.Name)
+}
+
+func (ds pbsDatastore) isOverride() bool {
+ return strings.TrimSpace(ds.Source) == pbsDatastoreSourceOverride
+}
+
+func (ds pbsDatastore) inventoryOrigin() string {
+ if origin := strings.TrimSpace(ds.Source); origin != "" {
+ return origin
+ }
+ return pbsDatastoreSourceCLI
+}
// collectDatastoreConfigs collects detailed datastore configurations
func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pbsDatastore) error {
@@ -30,6 +304,8 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb
c.logger.Debug("No datastores found")
return nil
}
+ datastores = clonePBSDatastores(datastores)
+ assignUniquePBSDatastoreOutputKeys(datastores)
c.logger.Debug("Collecting datastore details for %d datastores", len(datastores))
datastoreDir := c.proxsaveInfoDir("pbs", "datastores")
@@ -38,12 +314,18 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb
}
for _, ds := range datastores {
- // Get datastore configuration details
- c.safeCmdOutput(ctx,
- fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", ds.Name),
- filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", ds.Name)),
- fmt.Sprintf("Datastore %s configuration", ds.Name),
- false)
+ dsKey := ds.pathKey()
+
+ if cliName := ds.cliName(); cliName != "" && !ds.isOverride() {
+ // Get datastore configuration details for CLI-backed datastores only.
+ c.safeCmdOutput(ctx,
+ fmt.Sprintf("proxmox-backup-manager datastore show %s --output-format=json", cliName),
+ filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", dsKey)),
+ fmt.Sprintf("Datastore %s configuration", ds.Name),
+ false)
+ } else {
+ c.logger.Debug("Skipping datastore CLI config for %s (path=%s): no PBS datastore identity", ds.Name, ds.Path)
+ }
// Get namespace list using CLI/Filesystem fallback
if err := c.collectDatastoreNamespaces(ctx, ds, datastoreDir); err != nil {
@@ -60,7 +342,7 @@ func (c *Collector) collectDatastoreConfigs(ctx context.Context, datastores []pb
func (c *Collector) collectDatastoreNamespaces(ctx context.Context, ds pbsDatastore, datastoreDir string) error {
c.logger.Debug("Collecting namespaces for datastore %s (path: %s)", ds.Name, ds.Path)
// Write location is deterministic; if excluded, skip the whole operation.
- outputPath := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.Name))
+ outputPath := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey()))
if c.shouldExclude(outputPath) {
c.incFilesSkipped()
return nil
@@ -71,7 +353,17 @@ func (c *Collector) collectDatastoreNamespaces(ctx context.Context, ds pbsDatast
ioTimeout = time.Duration(c.config.FsIoTimeoutSeconds) * time.Second
}
- namespaces, fromFallback, err := listNamespacesFunc(ctx, ds.Name, ds.Path, ioTimeout)
+ var (
+ namespaces []pbs.Namespace
+ fromFallback bool
+ err error
+ )
+ if ds.isOverride() {
+ namespaces, err = discoverNamespacesFunc(ctx, ds.normalizedPath(), ioTimeout)
+ fromFallback = true
+ } else {
+ namespaces, fromFallback, err = listNamespacesFunc(ctx, ds.cliName(), ds.Path, ioTimeout)
+ }
if err != nil {
return err
}
@@ -102,6 +394,8 @@ func (c *Collector) collectPBSPxarMetadata(ctx context.Context, datastores []pbs
if len(datastores) == 0 {
return nil
}
+ datastores = clonePBSDatastores(datastores)
+ assignUniquePBSDatastoreOutputKeys(datastores)
c.logger.Debug("Collecting PXAR metadata for %d datastores", len(datastores))
dsWorkers := c.config.PxarDatastoreConcurrency
if dsWorkers <= 0 {
@@ -214,16 +508,17 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m
start := time.Now()
c.logger.Debug("PXAR: scanning datastore %s at %s", ds.Name, ds.Path)
- dsDir := filepath.Join(metaRoot, ds.Name)
+ dsKey := ds.pathKey()
+ dsDir := filepath.Join(metaRoot, dsKey)
if err := c.ensureDir(dsDir); err != nil {
return fmt.Errorf("failed to create PXAR metadata directory for %s: %w", ds.Name, err)
}
for _, base := range []string{
- filepath.Join(selectedRoot, ds.Name, "vm"),
- filepath.Join(selectedRoot, ds.Name, "ct"),
- filepath.Join(smallRoot, ds.Name, "vm"),
- filepath.Join(smallRoot, ds.Name, "ct"),
+ filepath.Join(selectedRoot, dsKey, "vm"),
+ filepath.Join(selectedRoot, dsKey, "ct"),
+ filepath.Join(smallRoot, dsKey, "vm"),
+ filepath.Join(smallRoot, dsKey, "ct"),
} {
if err := c.ensureDir(base); err != nil {
c.logger.Debug("Failed to prepare PXAR directory %s: %v", base, err)
@@ -278,7 +573,7 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m
return err
}
- if err := c.writePxarSubdirReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_subdirs.txt", ds.Name)), ds, ioTimeout); err != nil {
+ if err := c.writePxarSubdirReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_subdirs.txt", dsKey)), ds, ioTimeout); err != nil {
if errors.Is(err, safefs.ErrTimeout) {
c.logger.Warning("Skipping PXAR metadata for datastore %s (path=%s): subdir report timed out (%v)", ds.Name, ds.Path, err)
return nil
@@ -286,7 +581,7 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m
return err
}
- if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_vm_pxar_list.txt", ds.Name)), ds, "vm", ioTimeout); err != nil {
+ if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_vm_pxar_list.txt", dsKey)), ds, "vm", ioTimeout); err != nil {
if errors.Is(err, safefs.ErrTimeout) {
c.logger.Warning("Skipping PXAR metadata for datastore %s (path=%s): VM list report timed out (%v)", ds.Name, ds.Path, err)
return nil
@@ -294,7 +589,7 @@ func (c *Collector) processPxarDatastore(ctx context.Context, ds pbsDatastore, m
return err
}
- if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_ct_pxar_list.txt", ds.Name)), ds, "ct", ioTimeout); err != nil {
+ if err := c.writePxarListReport(ctx, filepath.Join(dsDir, fmt.Sprintf("%s_ct_pxar_list.txt", dsKey)), ds, "ct", ioTimeout); err != nil {
if errors.Is(err, safefs.ErrTimeout) {
c.logger.Warning("Skipping PXAR metadata for datastore %s (path=%s): CT list report timed out (%v)", ds.Name, ds.Path, err)
return nil
@@ -427,66 +722,82 @@ func (c *Collector) getDatastoreList(ctx context.Context) ([]pbsDatastore, error
}
c.logger.Debug("Enumerating PBS datastores via proxmox-backup-manager")
- if _, err := c.depLookPath("proxmox-backup-manager"); err != nil {
- return nil, nil
- }
-
- output, err := c.depRunCommand(ctx, "proxmox-backup-manager", "datastore", "list", "--output-format=json")
- if err != nil {
- return nil, fmt.Errorf("proxmox-backup-manager datastore list failed: %w", err)
- }
-
type datastoreEntry struct {
Name string `json:"name"`
Path string `json:"path"`
Comment string `json:"comment"`
}
- var entries []datastoreEntry
- if err := json.Unmarshal(output, &entries); err != nil {
- return nil, fmt.Errorf("failed to parse datastore list JSON: %w", err)
- }
-
- datastores := make([]pbsDatastore, 0, len(entries))
- for _, entry := range entries {
- name := strings.TrimSpace(entry.Name)
- if name != "" {
- datastores = append(datastores, pbsDatastore{
- Name: name,
- Path: strings.TrimSpace(entry.Path),
- Comment: strings.TrimSpace(entry.Comment),
- })
+ datastores := make([]pbsDatastore, 0, len(c.config.PBSDatastorePaths))
+ if _, err := c.depLookPath("proxmox-backup-manager"); err != nil {
+ c.logger.Debug("Skipping PBS datastore CLI enumeration: proxmox-backup-manager not available: %v", err)
+ } else {
+ output, err := c.depRunCommand(ctx, "proxmox-backup-manager", "datastore", "list", "--output-format=json")
+ if err != nil {
+ if ctxErr := ctx.Err(); ctxErr != nil {
+ return nil, ctxErr
+ }
+ c.logger.Debug("PBS datastore CLI enumeration failed: %v", err)
+ } else {
+ var entries []datastoreEntry
+ if err := json.Unmarshal(output, &entries); err != nil {
+ c.logger.Debug("Failed to parse PBS datastore list JSON: %v", err)
+ } else {
+ datastores = make([]pbsDatastore, 0, len(entries)+len(c.config.PBSDatastorePaths))
+ for _, entry := range entries {
+ name := strings.TrimSpace(entry.Name)
+ if name == "" {
+ continue
+ }
+ path := strings.TrimSpace(entry.Path)
+ datastores = append(datastores, pbsDatastore{
+ Name: name,
+ Path: path,
+ Comment: strings.TrimSpace(entry.Comment),
+ Source: pbsDatastoreSourceCLI,
+ CLIName: name,
+ NormalizedPath: normalizePBSDatastorePath(path),
+ OutputKey: collectorPathKey(name),
+ })
+ }
+ }
}
}
if len(c.config.PBSDatastorePaths) > 0 {
existing := make(map[string]struct{}, len(datastores))
for _, ds := range datastores {
- if ds.Path != "" {
- existing[ds.Path] = struct{}{}
+ if normalized := ds.normalizedPath(); normalized != "" {
+ existing[normalized] = struct{}{}
}
}
- validName := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
for idx, override := range c.config.PBSDatastorePaths {
override = strings.TrimSpace(override)
if override == "" {
continue
}
- if _, ok := existing[override]; ok {
+ normalized := normalizePBSDatastorePath(override)
+ if normalized == "" {
continue
}
- name := filepath.Base(filepath.Clean(override))
- if name == "" || name == "." || name == string(os.PathSeparator) || !validName.MatchString(name) {
- name = fmt.Sprintf("datastore_%d", idx+1)
+ if _, ok := existing[normalized]; ok {
+ continue
}
+ existing[normalized] = struct{}{}
+ name := buildPBSOverrideDisplayName(normalized, idx)
datastores = append(datastores, pbsDatastore{
- Name: name,
- Path: override,
- Comment: "configured via PBS_DATASTORE_PATH",
+ Name: name,
+ Path: override,
+ Comment: "configured via PBS_DATASTORE_PATH",
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalized,
+ OutputKey: buildPBSOverrideOutputKey(normalized),
})
}
}
+ assignUniquePBSDatastoreOutputKeys(datastores)
+
c.logger.Debug("Detected %d configured datastores", len(datastores))
return datastores, nil
}
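
`buildPBSOverrideOutputKey` gives each override scan root a stable, path-unique key: the basename (when it is already a valid datastore name) plus the first four bytes of the normalized path's SHA-256. A self-contained sketch assuming a label that needs no sanitizing:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"path/filepath"
)

// overrideOutputKey mirrors buildPBSOverrideOutputKey for the simple case of
// a clean, already-safe basename: label plus a 4-byte digest of the full path.
func overrideOutputKey(path string) string {
	normalized := filepath.Clean(path)
	sum := sha256.Sum256([]byte(normalized))
	return fmt.Sprintf("%s_%s", filepath.Base(normalized), hex.EncodeToString(sum[:4]))
}

func main() {
	// Two scan roots with the same basename get distinct keys, so their
	// status/namespace/PXAR outputs can no longer overwrite each other.
	fmt.Println(overrideOutputKey("/mnt/a/backup"))
	fmt.Println(overrideOutputKey("/srv/b/backup"))
}
```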
diff --git a/internal/backup/collector_pbs_datastore_inventory.go b/internal/backup/collector_pbs_datastore_inventory.go
index 72bbb02..1b6ad85 100644
--- a/internal/backup/collector_pbs_datastore_inventory.go
+++ b/internal/backup/collector_pbs_datastore_inventory.go
@@ -66,6 +66,9 @@ type pbsDatastoreInventoryEntry struct {
Path string `json:"path,omitempty"`
Comment string `json:"comment,omitempty"`
Sources []string `json:"sources,omitempty"`
+ Origin string `json:"origin,omitempty"`
+ CLIName string `json:"cli_name,omitempty"`
+ OutputKey string `json:"output_key,omitempty"`
StatPath string `json:"stat_path,omitempty"`
PathOK bool `json:"path_ok,omitempty"`
PathIsDir bool `json:"path_is_dir,omitempty"`
@@ -216,10 +219,13 @@ func (c *Collector) collectPBSDatastoreInventory(ctx context.Context, cliDatasto
for _, def := range merged {
entry := pbsDatastoreInventoryEntry{
- Name: def.Name,
- Path: def.Path,
- Comment: def.Comment,
- Sources: append([]string(nil), def.Sources...),
+ Name: def.Name,
+ Path: def.Path,
+ Comment: def.Comment,
+ Sources: append([]string(nil), def.Sources...),
+ Origin: def.Origin,
+ CLIName: def.CLIName,
+ OutputKey: def.OutputKey,
}
statPath := def.Path
@@ -447,34 +453,71 @@ func (c *Collector) captureInventoryCommand(ctx context.Context, pretty string,
}
type pbsDatastoreDefinition struct {
- Name string
- Path string
- Comment string
- Sources []string
+ Name string
+ Path string
+ Comment string
+ Sources []string
+ Origin string
+ CLIName string
+ OutputKey string
}
func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefinition {
merged := make(map[string]*pbsDatastoreDefinition)
+ defKey := func(ds pbsDatastore) string {
+ return pbsDatastoreIdentityKey(ds)
+ }
+
add := func(ds pbsDatastore, source string) {
- name := strings.TrimSpace(ds.Name)
- if name == "" {
+ key := defKey(ds)
+ if key == "" {
return
}
- entry := merged[name]
+ name := strings.TrimSpace(ds.Name)
+ path := strings.TrimSpace(ds.Path)
+ comment := strings.TrimSpace(ds.Comment)
+ origin := ds.inventoryOrigin()
+ cliName := strings.TrimSpace(ds.cliName())
+ entry := merged[key]
if entry == nil {
- entry = &pbsDatastoreDefinition{Name: name}
- merged[name] = entry
+ entry = &pbsDatastoreDefinition{
+ Name: name,
+ Origin: origin,
+ CLIName: cliName,
+ OutputKey: strings.TrimSpace(ds.pathKey()),
+ }
+ merged[key] = entry
}
entry.Sources = append(entry.Sources, source)
- if entry.Path == "" && strings.TrimSpace(ds.Path) != "" {
- entry.Path = strings.TrimSpace(ds.Path)
+ if entry.Name == "" && name != "" {
+ entry.Name = name
+ }
+ if entry.Path == "" && path != "" {
+ entry.Path = path
}
- if entry.Comment == "" && strings.TrimSpace(ds.Comment) != "" {
- entry.Comment = strings.TrimSpace(ds.Comment)
+ if entry.Comment == "" && comment != "" {
+ entry.Comment = comment
+ }
+ if entry.CLIName == "" && !ds.isOverride() && cliName != "" {
+ entry.CLIName = cliName
+ }
+ if entry.OutputKey == "" {
+ entry.OutputKey = strings.TrimSpace(ds.pathKey())
+ }
+
+ switch {
+ case entry.Origin == "":
+ entry.Origin = origin
+ case entry.Origin == pbsDatastoreSourceOverride || origin == pbsDatastoreSourceOverride:
+ if entry.Origin == "" {
+ entry.Origin = pbsDatastoreSourceOverride
+ }
+ case entry.Origin != origin:
+ entry.Origin = pbsDatastoreOriginMerged
}
}
@@ -482,7 +525,11 @@ func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefi
add(ds, "datastore.cfg")
}
for _, ds := range cli {
- add(ds, "cli")
+ source := "cli"
+ if ds.isOverride() {
+ source = "override"
+ }
+ add(ds, source)
}
out := make([]pbsDatastoreDefinition, 0, len(merged))
@@ -493,13 +540,41 @@ func mergePBSDatastoreDefinitions(cli, config []pbsDatastore) []pbsDatastoreDefi
v.Sources = uniqueSortedStrings(v.Sources)
out = append(out, *v)
}
+ assignUniquePBSDatastoreDefinitionOutputKeys(out)
sort.Slice(out, func(i, j int) bool {
- return out[i].Name < out[j].Name
+ if out[i].Name != out[j].Name {
+ return out[i].Name < out[j].Name
+ }
+ if p1, p2 := pbsDatastoreOriginSortPriority(out[i].Origin), pbsDatastoreOriginSortPriority(out[j].Origin); p1 != p2 {
+ return p1 < p2
+ }
+ if out[i].Path != out[j].Path {
+ return out[i].Path < out[j].Path
+ }
+ if out[i].OutputKey != out[j].OutputKey {
+ return out[i].OutputKey < out[j].OutputKey
+ }
+ return out[i].CLIName < out[j].CLIName
})
return out
}
+func pbsDatastoreOriginSortPriority(origin string) int {
+ switch strings.TrimSpace(origin) {
+ case pbsDatastoreOriginMerged:
+ return 0
+ case pbsDatastoreSourceConfig:
+ return 1
+ case pbsDatastoreSourceCLI:
+ return 2
+ case pbsDatastoreSourceOverride:
+ return 3
+ default:
+ return 4
+ }
+}
+
func parsePBSDatastoreCfg(contents string) []pbsDatastore {
contents = strings.TrimSpace(contents)
if contents == "" {
@@ -536,7 +611,12 @@ func parsePBSDatastoreCfg(contents string) []pbsDatastore {
if name == "" {
continue
}
- current = &pbsDatastore{Name: name}
+ current = &pbsDatastore{
+ Name: name,
+ Source: pbsDatastoreSourceConfig,
+ CLIName: name,
+ OutputKey: collectorPathKey(name),
+ }
continue
}
@@ -560,6 +640,10 @@ func parsePBSDatastoreCfg(contents string) []pbsDatastore {
}
flush()
+ for i := range out {
+ out[i].NormalizedPath = normalizePBSDatastorePath(out[i].Path)
+ }
+
return out
}
diff --git a/internal/backup/collector_pbs_datastore_inventory_test.go b/internal/backup/collector_pbs_datastore_inventory_test.go
index ca11128..044c0b4 100644
--- a/internal/backup/collector_pbs_datastore_inventory_test.go
+++ b/internal/backup/collector_pbs_datastore_inventory_test.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"os"
"path/filepath"
+ "strings"
"testing"
"github.com/tis24dev/proxsave/internal/types"
@@ -295,3 +296,109 @@ func TestCollectPBSDatastoreInventoryCapturesHostCommands(t *testing.T) {
t.Fatalf("expected findmnt output to be captured")
}
}
+
+func TestMergePBSDatastoreDefinitionsKeepsOverridesSeparate(t *testing.T) {
+ config := []pbsDatastore{{
+ Name: "backup",
+ Path: "/real/backup",
+ Comment: "primary",
+ Source: pbsDatastoreSourceConfig,
+ CLIName: "backup",
+ NormalizedPath: normalizePBSDatastorePath("/real/backup"),
+ OutputKey: collectorPathKey("backup"),
+ }}
+ cli := []pbsDatastore{
+ {
+ Name: "backup",
+ Path: "/real/backup",
+ Comment: "runtime",
+ Source: pbsDatastoreSourceCLI,
+ CLIName: "backup",
+ NormalizedPath: normalizePBSDatastorePath("/real/backup"),
+ OutputKey: collectorPathKey("backup"),
+ },
+ {
+ Name: "backup",
+ Path: "/mnt/a/backup",
+ Comment: "configured via PBS_DATASTORE_PATH",
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath("/mnt/a/backup"),
+ OutputKey: buildPBSOverrideOutputKey("/mnt/a/backup"),
+ },
+ {
+ Name: "backup",
+ Path: "/srv/b/backup",
+ Comment: "configured via PBS_DATASTORE_PATH",
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath("/srv/b/backup"),
+ OutputKey: buildPBSOverrideOutputKey("/srv/b/backup"),
+ },
+ }
+
+ merged := mergePBSDatastoreDefinitions(cli, config)
+ if len(merged) != 3 {
+ t.Fatalf("expected 3 merged entries, got %d: %+v", len(merged), merged)
+ }
+
+ if merged[0].Origin != pbsDatastoreOriginMerged || merged[0].Path != "/real/backup" {
+ t.Fatalf("expected real datastore entry first, got %+v", merged[0])
+ }
+ if merged[1].Origin != pbsDatastoreSourceOverride || merged[2].Origin != pbsDatastoreSourceOverride {
+ t.Fatalf("expected override entries after real datastore, got %+v", merged)
+ }
+ if merged[1].OutputKey == merged[2].OutputKey {
+ t.Fatalf("override output keys should differ, got %+v", merged)
+ }
+}
+
+func TestMergePBSDatastoreDefinitionsDisambiguatesCLIAndOverrideOutputKeyCollisions(t *testing.T) {
+ overridePath := "/mnt/a/backup"
+ collidingKey := buildPBSOverrideOutputKey(overridePath)
+
+ cli := []pbsDatastore{
+ {
+ Name: collidingKey,
+ Path: "/real/runtime",
+ Comment: "runtime",
+ Source: pbsDatastoreSourceCLI,
+ CLIName: collidingKey,
+ NormalizedPath: normalizePBSDatastorePath("/real/runtime"),
+ OutputKey: collidingKey,
+ },
+ {
+ Name: "backup",
+ Path: overridePath,
+ Comment: "configured via PBS_DATASTORE_PATH",
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath(overridePath),
+ OutputKey: buildPBSOverrideOutputKey(overridePath),
+ },
+ }
+
+ merged := mergePBSDatastoreDefinitions(cli, nil)
+ if len(merged) != 2 {
+ t.Fatalf("expected 2 merged entries, got %d: %+v", len(merged), merged)
+ }
+
+ var cliEntry, overrideEntry *pbsDatastoreDefinition
+ for i := range merged {
+ switch merged[i].Origin {
+ case pbsDatastoreSourceCLI:
+ cliEntry = &merged[i]
+ case pbsDatastoreSourceOverride:
+ overrideEntry = &merged[i]
+ }
+ }
+ if cliEntry == nil || overrideEntry == nil {
+ t.Fatalf("expected one CLI and one override entry, got %+v", merged)
+ }
+ if cliEntry.OutputKey != collidingKey {
+ t.Fatalf("CLI datastore should keep base key %q, got %+v", collidingKey, merged)
+ }
+ if overrideEntry.OutputKey == collidingKey {
+ t.Fatalf("override output key should be disambiguated, got %+v", merged)
+ }
+ if !strings.HasPrefix(overrideEntry.OutputKey, collidingKey+"_") {
+ t.Fatalf("override output key should extend colliding base key, got %+v", merged)
+ }
+}
diff --git a/internal/backup/collector_pbs_extra_test.go b/internal/backup/collector_pbs_extra_test.go
index bdb822f..da250ac 100644
--- a/internal/backup/collector_pbs_extra_test.go
+++ b/internal/backup/collector_pbs_extra_test.go
@@ -13,20 +13,29 @@ import (
"github.com/tis24dev/proxsave/internal/types"
)
-func TestGetDatastoreListNoBinary(t *testing.T) {
- collector := NewCollectorWithDeps(newTestLogger(), GetDefaultCollectorConfig(), t.TempDir(), types.ProxmoxBS, false, CollectorDeps{
+func TestGetDatastoreListNoBinaryStillIncludesOverrides(t *testing.T) {
+ cfg := GetDefaultCollectorConfig()
+ cfg.PBSDatastorePaths = []string{"/override/no-cli"}
+
+ collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{
LookPath: func(string) (string, error) { return "", errors.New("not found") },
})
ds, err := collector.getDatastoreList(context.Background())
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
- if len(ds) != 0 {
- t.Fatalf("expected empty datastores when binary missing")
+ if len(ds) != 1 {
+ t.Fatalf("expected override datastore when binary missing, got %+v", ds)
+ }
+ if ds[0].Name != "no-cli" || ds[0].Path != "/override/no-cli" || ds[0].Source != pbsDatastoreSourceOverride {
+ t.Fatalf("unexpected override datastore when binary missing: %+v", ds[0])
}
}
-func TestGetDatastoreListCommandErrorAndParseError(t *testing.T) {
+func TestGetDatastoreListCommandErrorAndParseErrorStillIncludeOverrides(t *testing.T) {
+ cfg := GetDefaultCollectorConfig()
+ cfg.PBSDatastorePaths = []string{"/override/fallback"}
+
deps := CollectorDeps{
LookPath: func(string) (string, error) { return "/bin/true", nil },
RunCommand: func(context.Context, string, ...string) ([]byte, error) {
@@ -34,17 +43,25 @@ func TestGetDatastoreListCommandErrorAndParseError(t *testing.T) {
},
}
- c := NewCollectorWithDeps(newTestLogger(), GetDefaultCollectorConfig(), t.TempDir(), types.ProxmoxBS, false, deps)
- if _, err := c.getDatastoreList(context.Background()); err == nil {
- t.Fatalf("expected error when command fails")
+ c := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, deps)
+ ds, err := c.getDatastoreList(context.Background())
+ if err != nil {
+ t.Fatalf("expected nil error on command failure, got %v", err)
+ }
+ if len(ds) != 1 || ds[0].Path != "/override/fallback" || ds[0].Source != pbsDatastoreSourceOverride {
+ t.Fatalf("expected override fallback on command failure, got %+v", ds)
}
// Now simulate parse error
c.deps.RunCommand = func(context.Context, string, ...string) ([]byte, error) {
return []byte("{invalid"), nil
}
- if _, err := c.getDatastoreList(context.Background()); err == nil {
- t.Fatalf("expected parse error for invalid JSON")
+ ds, err = c.getDatastoreList(context.Background())
+ if err != nil {
+ t.Fatalf("expected nil error on parse failure, got %v", err)
+ }
+ if len(ds) != 1 || ds[0].Path != "/override/fallback" || ds[0].Source != pbsDatastoreSourceOverride {
+ t.Fatalf("expected override fallback on parse failure, got %+v", ds)
}
}
diff --git a/internal/backup/collector_pbs_test.go b/internal/backup/collector_pbs_test.go
index e44fddf..317245f 100644
--- a/internal/backup/collector_pbs_test.go
+++ b/internal/backup/collector_pbs_test.go
@@ -12,6 +12,7 @@ import (
"time"
"github.com/tis24dev/proxsave/internal/pbs"
+ "github.com/tis24dev/proxsave/internal/types"
)
func TestGetDatastoreListSuccessWithOverrides(t *testing.T) {
@@ -62,6 +63,96 @@ func TestGetDatastoreListSuccessWithOverrides(t *testing.T) {
if datastores[2].Comment != "configured via PBS_DATASTORE_PATH" {
t.Fatalf("expected override comment, got %q", datastores[2].Comment)
}
+ if datastores[0].Source != pbsDatastoreSourceCLI || datastores[0].CLIName != "primary" || datastores[0].OutputKey != "primary" {
+ t.Fatalf("expected CLI datastore metadata, got %+v", datastores[0])
+ }
+ if datastores[1].Source != pbsDatastoreSourceOverride || datastores[1].CLIName != "" || datastores[1].OutputKey == "" {
+ t.Fatalf("expected override datastore metadata, got %+v", datastores[1])
+ }
+}
+
+func TestGetDatastoreListOverrideCollisionsUseDistinctOutputKeys(t *testing.T) {
+ collector := newTestCollectorWithDeps(t, CollectorDeps{
+ LookPath: func(cmd string) (string, error) {
+ return "/usr/bin/" + cmd, nil
+ },
+ RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) {
+ return []byte(`[{"name":"primary","path":"/data/primary/","comment":"main"}]`), nil
+ },
+ })
+ collector.config.PBSDatastorePaths = []string{
+ "/mnt/a/backup",
+ "/srv/b/backup",
+ "/srv/b/backup/",
+ "/data/primary",
+ }
+
+ datastores, err := collector.getDatastoreList(context.Background())
+ if err != nil {
+ t.Fatalf("getDatastoreList failed: %v", err)
+ }
+ if len(datastores) != 3 {
+ t.Fatalf("expected 3 datastores after normalized dedupe, got %d: %+v", len(datastores), datastores)
+ }
+
+ if datastores[1].Name != "backup" || datastores[2].Name != "backup" {
+ t.Fatalf("expected colliding override display names, got %+v", datastores)
+ }
+ if datastores[1].OutputKey == datastores[2].OutputKey {
+ t.Fatalf("override output keys should differ, got %q", datastores[1].OutputKey)
+ }
+ if datastores[1].NormalizedPath == datastores[2].NormalizedPath {
+ t.Fatalf("override normalized paths should differ, got %+v", datastores)
+ }
+}
+
+func TestGetDatastoreListDisambiguatesCLIAndOverrideOutputKeyCollisions(t *testing.T) {
+ overridePath := "/mnt/a/backup"
+ collidingKey := buildPBSOverrideOutputKey(overridePath)
+
+ collector := newTestCollectorWithDeps(t, CollectorDeps{
+ LookPath: func(cmd string) (string, error) {
+ return "/usr/bin/" + cmd, nil
+ },
+ RunCommand: func(ctx context.Context, name string, args ...string) ([]byte, error) {
+ return []byte(fmt.Sprintf(`[{"name":%q,"path":"/data/runtime","comment":"main"}]`, collidingKey)), nil
+ },
+ })
+ collector.config.PBSDatastorePaths = []string{overridePath}
+
+ datastores, err := collector.getDatastoreList(context.Background())
+ if err != nil {
+ t.Fatalf("getDatastoreList failed: %v", err)
+ }
+ if len(datastores) != 2 {
+ t.Fatalf("expected 2 datastores, got %d: %+v", len(datastores), datastores)
+ }
+
+ cli := datastores[0]
+ override := datastores[1]
+ if cli.OutputKey != collidingKey {
+ t.Fatalf("CLI datastore should keep base key %q, got %+v", collidingKey, cli)
+ }
+ if override.OutputKey == collidingKey {
+ t.Fatalf("override key should be disambiguated away from %q, got %+v", collidingKey, override)
+ }
+ if override.OutputKey == cli.OutputKey {
+ t.Fatalf("datastore output keys should differ, got %+v", datastores)
+ }
+}
+
+func TestPBSDatastorePathKeyUsesOverridePathFallback(t *testing.T) {
+ dsPath := "/mnt/a/backup"
+ ds := pbsDatastore{
+ Name: "backup",
+ Path: dsPath,
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath(dsPath),
+ }
+
+ if got, want := ds.pathKey(), buildPBSOverrideOutputKey(dsPath); got != want {
+ t.Fatalf("override pathKey()=%q want %q", got, want)
+ }
}
func TestGetDatastoreListContextCanceled(t *testing.T) {
@@ -98,10 +189,17 @@ func TestGetDatastoreListCommandError(t *testing.T) {
return nil, fmt.Errorf("command failed")
},
})
+ collector.config.PBSDatastorePaths = []string{"/override/from-error"}
- _, err := collector.getDatastoreList(context.Background())
- if err == nil || !strings.Contains(err.Error(), "datastore list failed") {
- t.Fatalf("expected datastore list error, got %v", err)
+ datastores, err := collector.getDatastoreList(context.Background())
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(datastores) != 1 {
+ t.Fatalf("expected only override datastore, got %+v", datastores)
+ }
+ if datastores[0].Name != "from-error" || datastores[0].Path != "/override/from-error" || datastores[0].Source != pbsDatastoreSourceOverride {
+ t.Fatalf("unexpected override datastore after command failure: %+v", datastores[0])
}
}
@@ -114,10 +212,17 @@ func TestGetDatastoreListBadJSON(t *testing.T) {
return []byte("not-json"), nil
},
})
+ collector.config.PBSDatastorePaths = []string{"/override/from-parse"}
- _, err := collector.getDatastoreList(context.Background())
- if err == nil || !strings.Contains(err.Error(), "failed to parse datastore list JSON") {
- t.Fatalf("expected parse error, got %v", err)
+ datastores, err := collector.getDatastoreList(context.Background())
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(datastores) != 1 {
+ t.Fatalf("expected only override datastore, got %+v", datastores)
+ }
+ if datastores[0].Name != "from-parse" || datastores[0].Path != "/override/from-parse" || datastores[0].Source != pbsDatastoreSourceOverride {
+ t.Fatalf("unexpected override datastore after parse failure: %+v", datastores[0])
}
}
@@ -305,6 +410,50 @@ func TestCollectDatastoreNamespacesError(t *testing.T) {
}
}
+func TestCollectDatastoreNamespacesOverrideUsesFilesystemOnly(t *testing.T) {
+ origList := listNamespacesFunc
+ t.Cleanup(func() { listNamespacesFunc = origList })
+ listNamespacesFunc = func(context.Context, string, string, time.Duration) ([]pbs.Namespace, bool, error) {
+ t.Fatal("CLI namespace discovery should not be used for override paths")
+ return nil, false, nil
+ }
+
+ collector := newTestCollectorWithDeps(t, CollectorDeps{})
+ dsDir := filepath.Join(collector.tempDir, "datastores")
+ if err := os.MkdirAll(dsDir, 0o755); err != nil {
+ t.Fatalf("failed to create datastore dir: %v", err)
+ }
+
+ dsPath := filepath.Join(collector.tempDir, "override")
+ if err := os.MkdirAll(filepath.Join(dsPath, "local", "vm"), 0o755); err != nil {
+ t.Fatalf("failed to create override namespace fixture: %v", err)
+ }
+
+ ds := pbsDatastore{
+ Name: "backup",
+ Path: dsPath,
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath(dsPath),
+ OutputKey: buildPBSOverrideOutputKey(dsPath),
+ }
+ if err := collector.collectDatastoreNamespaces(context.Background(), ds, dsDir); err != nil {
+ t.Fatalf("collectDatastoreNamespaces failed: %v", err)
+ }
+
+ data, err := os.ReadFile(filepath.Join(dsDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey())))
+ if err != nil {
+ t.Fatalf("namespaces file not created: %v", err)
+ }
+
+ var namespaces []pbs.Namespace
+ if err := json.Unmarshal(data, &namespaces); err != nil {
+ t.Fatalf("failed to decode namespaces: %v", err)
+ }
+ if len(namespaces) != 2 || namespaces[1].Ns != "local" {
+ t.Fatalf("unexpected override namespaces: %+v", namespaces)
+ }
+}
+
func TestCollectDatastoreConfigsDryRun(t *testing.T) {
stubListNamespaces(t, func(context.Context, string, string, time.Duration) ([]pbs.Namespace, bool, error) {
return []pbs.Namespace{{Ns: ""}}, false, nil
@@ -333,6 +482,348 @@ func TestCollectDatastoreConfigsDryRun(t *testing.T) {
}
}
+func TestCollectDatastoreConfigs_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) {
+ unsafeName := "../escape"
+ expectedKey := collectorPathKey(unsafeName)
+
+ stubListNamespaces(t, func(_ context.Context, name, path string, _ time.Duration) ([]pbs.Namespace, bool, error) {
+ if name != unsafeName || path != "/fake" {
+ t.Fatalf("unexpected datastore args name=%q path=%q", name, path)
+ }
+ return []pbs.Namespace{{Ns: ""}}, false, nil
+ })
+
+ var seenArgs []string
+ collector := newTestCollectorWithDeps(t, CollectorDeps{
+ LookPath: func(cmd string) (string, error) {
+ return "/usr/bin/" + cmd, nil
+ },
+ RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) {
+ if name != "proxmox-backup-manager" {
+ return nil, fmt.Errorf("unexpected command %s", name)
+ }
+ seenArgs = append([]string(nil), args...)
+ return []byte(`{"ok":true}`), nil
+ },
+ })
+
+ ds := pbsDatastore{Name: unsafeName, Path: "/fake"}
+ if err := collector.collectDatastoreConfigs(context.Background(), []pbsDatastore{ds}); err != nil {
+ t.Fatalf("collectDatastoreConfigs failed: %v", err)
+ }
+
+ if len(seenArgs) < 3 || seenArgs[0] != "datastore" || seenArgs[1] != "show" || seenArgs[2] != unsafeName {
+ t.Fatalf("expected raw datastore name in command args, got %v", seenArgs)
+ }
+
+ datastoreDir := filepath.Join(collector.tempDir, "var", "lib", "proxsave-info", "pbs", "datastores")
+ safeConfig := filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", expectedKey))
+ safeNamespaces := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", expectedKey))
+ for _, path := range []string{safeConfig, safeNamespaces} {
+ if _, err := os.Stat(path); err != nil {
+ t.Fatalf("expected safe output %s: %v", path, err)
+ }
+ }
+
+ rawConfig := filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", unsafeName))
+ rawNamespaces := filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", unsafeName))
+ for _, path := range []string{rawConfig, rawNamespaces} {
+ if path == safeConfig || path == safeNamespaces {
+ continue
+ }
+ if _, err := os.Stat(path); !os.IsNotExist(err) {
+ t.Fatalf("raw output path should not exist (%s), got err=%v", path, err)
+ }
+ }
+}
+
+func TestCollectDatastoreConfigsSkipsCLIConfigForOverridePaths(t *testing.T) {
+ origList := listNamespacesFunc
+ t.Cleanup(func() { listNamespacesFunc = origList })
+ listNamespacesFunc = func(context.Context, string, string, time.Duration) ([]pbs.Namespace, bool, error) {
+ t.Fatal("CLI namespace discovery should not be used for override paths")
+ return nil, false, nil
+ }
+
+ dsPath := t.TempDir()
+ if err := os.MkdirAll(filepath.Join(dsPath, "vm"), 0o755); err != nil {
+ t.Fatalf("mkdir vm: %v", err)
+ }
+
+ var runCalls int
+ collector := newTestCollectorWithDeps(t, CollectorDeps{
+ LookPath: func(cmd string) (string, error) {
+ return "/usr/bin/" + cmd, nil
+ },
+ RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) {
+ runCalls++
+ return []byte(`{"ok":true}`), nil
+ },
+ })
+
+ ds := pbsDatastore{
+ Name: "backup",
+ Path: dsPath,
+ Comment: "configured via PBS_DATASTORE_PATH",
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath(dsPath),
+ OutputKey: buildPBSOverrideOutputKey(dsPath),
+ }
+ if err := collector.collectDatastoreConfigs(context.Background(), []pbsDatastore{ds}); err != nil {
+ t.Fatalf("collectDatastoreConfigs failed: %v", err)
+ }
+
+ datastoreDir := filepath.Join(collector.tempDir, "var", "lib", "proxsave-info", "pbs", "datastores")
+ if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", ds.pathKey()))); err != nil {
+ t.Fatalf("expected override namespaces file, got %v", err)
+ }
+ if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", ds.pathKey()))); !os.IsNotExist(err) {
+ t.Fatalf("override config file should not exist, got err=%v", err)
+ }
+ if runCalls != 0 {
+ t.Fatalf("expected no CLI datastore show calls for override, got %d", runCalls)
+ }
+}
+
+func TestCollectDatastoreConfigsDisambiguatesManualCLIAndOverrideKeyCollisions(t *testing.T) {
+ overridePath := t.TempDir()
+ for _, sub := range []string{"vm", "ct"} {
+ if err := os.MkdirAll(filepath.Join(overridePath, sub), 0o755); err != nil {
+ t.Fatalf("mkdir %s: %v", sub, err)
+ }
+ }
+
+ collidingKey := buildPBSOverrideOutputKey(overridePath)
+ stubListNamespaces(t, func(_ context.Context, name, path string, _ time.Duration) ([]pbs.Namespace, bool, error) {
+ return []pbs.Namespace{{Ns: name, Path: path}}, false, nil
+ })
+
+ origDiscover := discoverNamespacesFunc
+ t.Cleanup(func() { discoverNamespacesFunc = origDiscover })
+ discoverNamespacesFunc = func(_ context.Context, path string, _ time.Duration) ([]pbs.Namespace, error) {
+ return []pbs.Namespace{{Ns: "override", Path: path}}, nil
+ }
+
+ collector := newTestCollectorWithDeps(t, CollectorDeps{
+ LookPath: func(cmd string) (string, error) {
+ return "/usr/bin/" + cmd, nil
+ },
+ RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) {
+ return []byte(`{"ok":true}`), nil
+ },
+ })
+
+ datastores := []pbsDatastore{
+ {Name: collidingKey, Path: "/data/runtime"},
+ {
+ Name: "backup",
+ Path: overridePath,
+ Source: pbsDatastoreSourceOverride,
+ NormalizedPath: normalizePBSDatastorePath(overridePath),
+ },
+ }
+ if err := collector.collectDatastoreConfigs(context.Background(), datastores); err != nil {
+ t.Fatalf("collectDatastoreConfigs failed: %v", err)
+ }
+
+ resolved := clonePBSDatastores(datastores)
+ assignUniquePBSDatastoreOutputKeys(resolved)
+ if resolved[0].OutputKey == resolved[1].OutputKey {
+ t.Fatalf("expected resolved keys to differ, got %+v", resolved)
+ }
+
+ datastoreDir := filepath.Join(collector.tempDir, "var", "lib", "proxsave-info", "pbs", "datastores")
+ for _, suffix := range []string{"config.json", "namespaces.json"} {
+ if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_%s", resolved[0].OutputKey, suffix))); err != nil {
+ t.Fatalf("expected CLI output for %s: %v", suffix, err)
+ }
+ }
+ if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_namespaces.json", resolved[1].OutputKey))); err != nil {
+ t.Fatalf("expected override namespaces file: %v", err)
+ }
+ if _, err := os.Stat(filepath.Join(datastoreDir, fmt.Sprintf("%s_config.json", resolved[1].OutputKey))); !os.IsNotExist(err) {
+ t.Fatalf("override config file should not exist, got err=%v", err)
+ }
+}
+
+func TestCollectPBSPxarMetadata_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) {
+ tmp := t.TempDir()
+ cfg := GetDefaultCollectorConfig()
+ collector := NewCollector(newTestLogger(), cfg, tmp, types.ProxmoxBS, false)
+
+ dsPath := filepath.Join(tmp, "datastore")
+ for _, sub := range []string{"vm", "ct"} {
+ if err := os.MkdirAll(filepath.Join(dsPath, sub), 0o755); err != nil {
+ t.Fatalf("mkdir %s: %v", sub, err)
+ }
+ }
+ if err := os.WriteFile(filepath.Join(dsPath, "vm", "backup1.pxar"), []byte("data"), 0o640); err != nil {
+ t.Fatalf("write vm pxar: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(dsPath, "ct", "backup2.pxar"), []byte("data"), 0o640); err != nil {
+ t.Fatalf("write ct pxar: %v", err)
+ }
+
+ ds := pbsDatastore{Name: "../escape", Path: dsPath, Comment: "unsafe"}
+ if err := collector.collectPBSPxarMetadata(context.Background(), []pbsDatastore{ds}); err != nil {
+ t.Fatalf("collectPBSPxarMetadata failed: %v", err)
+ }
+
+ dsKey := collectorPathKey(ds.Name)
+ base := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "metadata", dsKey)
+ for _, path := range []string{
+ filepath.Join(base, "metadata.json"),
+ filepath.Join(base, fmt.Sprintf("%s_subdirs.txt", dsKey)),
+ filepath.Join(base, fmt.Sprintf("%s_vm_pxar_list.txt", dsKey)),
+ filepath.Join(base, fmt.Sprintf("%s_ct_pxar_list.txt", dsKey)),
+ } {
+ if _, err := os.Stat(path); err != nil {
+ t.Fatalf("expected safe PXAR output %s: %v", path, err)
+ }
+ }
+
+ metaBytes, err := os.ReadFile(filepath.Join(base, "metadata.json"))
+ if err != nil {
+ t.Fatalf("read metadata.json: %v", err)
+ }
+ if !strings.Contains(string(metaBytes), ds.Name) {
+ t.Fatalf("metadata should keep raw datastore name, got %s", string(metaBytes))
+ }
+
+ selectedVM := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "selected", dsKey, "vm")
+ smallVM := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "small", dsKey, "vm")
+ for _, path := range []string{selectedVM, smallVM} {
+ if info, err := os.Stat(path); err != nil || !info.IsDir() {
+ t.Fatalf("expected safe PXAR directory %s, got err=%v", path, err)
+ }
+ }
+
+ rawBase := filepath.Join(tmp, "var/lib/proxsave-info", "pbs", "pxar", "metadata", ds.Name)
+ if rawBase != base {
+ if _, err := os.Stat(rawBase); !os.IsNotExist(err) {
+ t.Fatalf("raw PXAR directory should not exist (%s), got err=%v", rawBase, err)
+ }
+ }
+}
+
+func TestCollectPBSCommands_UsesPathSafeKeyForUnsafeDatastoreName(t *testing.T) {
+ pbsRoot := t.TempDir()
+ if err := os.WriteFile(filepath.Join(pbsRoot, "tape.cfg"), []byte("ok"), 0o640); err != nil {
+ t.Fatalf("write tape.cfg: %v", err)
+ }
+
+ cfg := GetDefaultCollectorConfig()
+ cfg.PBSConfigPath = pbsRoot
+
+ collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{
+ LookPath: func(name string) (string, error) {
+ return "/bin/" + name, nil
+ },
+ RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) {
+ return []byte(fmt.Sprintf("%s %s", name, strings.Join(args, " "))), nil
+ },
+ })
+
+ ds := pbsDatastore{Name: "../escape", Path: "/data/escape"}
+ if err := collector.collectPBSCommands(context.Background(), []pbsDatastore{ds}); err != nil {
+ t.Fatalf("collectPBSCommands error: %v", err)
+ }
+
+ key := collectorPathKey(ds.Name)
+ commandsDir := filepath.Join(collector.tempDir, "var/lib/proxsave-info", "commands", "pbs")
+ safePath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", key))
+ if _, err := os.Stat(safePath); err != nil {
+ t.Fatalf("expected safe datastore status file: %v", err)
+ }
+ data, err := os.ReadFile(safePath)
+ if err != nil {
+ t.Fatalf("read datastore status file: %v", err)
+ }
+ if !strings.Contains(string(data), ds.Name) {
+ t.Fatalf("status file should reflect raw datastore name in command output, got %s", string(data))
+ }
+
+ rawPath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", ds.Name))
+ if rawPath != safePath {
+ if _, err := os.Stat(rawPath); !os.IsNotExist(err) {
+ t.Fatalf("raw datastore status path should not exist (%s), got err=%v", rawPath, err)
+ }
+ }
+}
+
+func TestCollectPBSCommands_DisambiguatesStatusFilesForCollidingDatastoreKeys(t *testing.T) {
+ pbsRoot := t.TempDir()
+ if err := os.WriteFile(filepath.Join(pbsRoot, "tape.cfg"), []byte("ok"), 0o640); err != nil {
+ t.Fatalf("write tape.cfg: %v", err)
+ }
+
+ cfg := GetDefaultCollectorConfig()
+ cfg.PBSConfigPath = pbsRoot
+
+ collector := NewCollectorWithDeps(newTestLogger(), cfg, t.TempDir(), types.ProxmoxBS, false, CollectorDeps{
+ LookPath: func(name string) (string, error) {
+ return "/bin/" + name, nil
+ },
+ RunCommand: func(_ context.Context, name string, args ...string) ([]byte, error) {
+ return []byte(fmt.Sprintf("%s %s", name, strings.Join(args, " "))), nil
+ },
+ })
+
+ unsafeName := "../escape"
+ baseKey := collectorPathKey(unsafeName)
+ datastores := []pbsDatastore{
+ {Name: unsafeName, Path: "/data/unsafe"},
+ {Name: baseKey, Path: "/data/colliding"},
+ }
+ if got := pbsDatastoreCandidateOutputKey(datastores[1]); got != baseKey {
+ t.Fatalf("expected second datastore to preserve colliding base key %q, got %q", baseKey, got)
+ }
+
+ resolved := clonePBSDatastores(datastores)
+ assignUniquePBSDatastoreOutputKeys(resolved)
+ if resolved[0].OutputKey == resolved[1].OutputKey {
+ t.Fatalf("expected distinct output keys for colliding datastores, got %+v", resolved)
+ }
+
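+ // Exactly one datastore may keep the shared base key; the collision
+ // resolver must suffix the other.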
+ baseCount := 0
+ suffixedCount := 0
+ for _, ds := range resolved {
+ switch {
+ case ds.OutputKey == baseKey:
+ baseCount++
+ case strings.HasPrefix(ds.OutputKey, baseKey+"_"):
+ suffixedCount++
+ }
+ }
+ if baseCount != 1 || suffixedCount != 1 {
+ t.Fatalf("expected one base key and one suffixed key from collision, got %+v", resolved)
+ }
+
+ if err := collector.collectPBSCommands(context.Background(), datastores); err != nil {
+ t.Fatalf("collectPBSCommands error: %v", err)
+ }
+
+ commandsDir := filepath.Join(collector.tempDir, "var/lib/proxsave-info", "commands", "pbs")
+ statusFiles, err := filepath.Glob(filepath.Join(commandsDir, "datastore_*_status.json"))
+ if err != nil {
+ t.Fatalf("glob status files: %v", err)
+ }
+ if len(statusFiles) != 2 {
+ t.Fatalf("expected 2 datastore status files, got %d: %v", len(statusFiles), statusFiles)
+ }
+
+ for _, ds := range resolved {
+ statusPath := filepath.Join(commandsDir, fmt.Sprintf("datastore_%s_status.json", ds.OutputKey))
+ data, err := os.ReadFile(statusPath)
+ if err != nil {
+ t.Fatalf("read datastore status file %s: %v", statusPath, err)
+ }
+ if !strings.Contains(string(data), ds.Name) {
+ t.Fatalf("status file %s should contain datastore name %q, got %s", statusPath, ds.Name, string(data))
+ }
+ }
+}
+
func TestCollectUserConfigsWithTokens(t *testing.T) {
collector := newTestCollectorWithDeps(t, CollectorDeps{
LookPath: func(cmd string) (string, error) {
diff --git a/internal/backup/collector_pve.go b/internal/backup/collector_pve.go
index ed0a8a6..8fe16eb 100644
--- a/internal/backup/collector_pve.go
+++ b/internal/backup/collector_pve.go
@@ -30,6 +30,10 @@ type pveStorageEntry struct {
Status string
}
+func (s pveStorageEntry) pathKey() string {
+ return collectorPathKey(s.Name)
+}
+
type pveRuntimeInfo struct {
Nodes []string
Storages []pveStorageEntry
@@ -1029,7 +1033,7 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv
var summary strings.Builder
summary.WriteString("# PVE datastores detected on ")
summary.WriteString(time.Now().Format(time.RFC3339))
- summary.WriteString("\n# Format: NAME|PATH|TYPE|CONTENT\n\n")
+ summary.WriteString("\n# Format: TYPE|NAME|PATH|CONTENT\n\n")
ioTimeout := time.Duration(0)
if c.config != nil && c.config.FsIoTimeoutSeconds > 0 {
@@ -1073,6 +1077,45 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv
return ""
}
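+ // describeOptionalBool renders a tri-state runtime flag for the debug
+ // skip details below (nil means the field was not reported).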
+ describeOptionalBool := func(v *bool) string {
+ if v == nil {
+ return "unknown"
+ }
+ if *v {
+ return "true"
+ }
+ return "false"
+ }
+
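+ // logSkipDetails emits a single DEBUG line with the full storage state
+ // plus the machine-readable reason (and error, when present).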
+ logSkipDetails := func(storage pveStorageEntry, reason string, err error) {
+ details := fmt.Sprintf(
+ "PVE datastore skip details: name=%s path=%s type=%s content=%s active=%s enabled=%s status=%s reason=%s",
+ storage.Name,
+ storage.Path,
+ storage.Type,
+ storage.Content,
+ describeOptionalBool(storage.Active),
+ describeOptionalBool(storage.Enabled),
+ strings.TrimSpace(storage.Status),
+ reason,
+ )
+ if err != nil {
+ c.logger.Debug("%s err=%v", details, err)
+ return
+ }
+ c.logger.Debug("%s", details)
+ }
+
processed := 0
for _, storage := range storages {
if storage.Path == "" {
@@ -1083,7 +1126,32 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv
}
if reason := unavailableReason(storage); reason != "" {
- c.logger.Warning("Skipping datastore %s (path=%s)%s: not available (%s)", storage.Name, storage.Path, formatRuntime(storage), reason)
+ status := strings.TrimPrefix(reason, "status=")
+ switch {
+ case reason == "enabled=false" || status == "disabled":
+ c.logger.Skip(
+ "PVE datastore %s ignored: disabled in Proxmox. Not scanning %s for vzdump backup files (PVE config backup continues).",
+ storage.Name, storage.Path,
+ )
+ case reason == "active=false" || status == "inactive" || status == "unavailable":
+ c.logger.Warning(
+ "PVE datastore %s skipped: storage is offline. Not scanning %s for vzdump backup files. Mount/activate it, or disable it in Proxmox if unused.",
+ storage.Name, storage.Path,
+ )
+ default:
+ if status != reason {
+ c.logger.Warning(
+ "PVE datastore %s skipped: Proxmox reports storage status %q. Not scanning %s for vzdump backup files. Fix storage status, or disable it in Proxmox if unused.",
+ storage.Name, status, storage.Path,
+ )
+ } else {
+ c.logger.Warning(
+ "PVE datastore %s skipped: Proxmox reports the storage is not available. Not scanning %s for vzdump backup files. Fix storage status, or disable it in Proxmox if unused.",
+ storage.Name, storage.Path,
+ )
+ }
+ }
+ logSkipDetails(storage, reason, nil)
continue
}
@@ -1091,14 +1159,17 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv
stat, err := safefs.Stat(ctx, storage.Path, ioTimeout)
if err != nil {
if errors.Is(err, safefs.ErrTimeout) {
- c.logger.Warning("Skipping datastore %s (path=%s)%s: filesystem probe timed out (%v)", storage.Name, storage.Path, formatRuntime(storage), err)
+ c.logger.Warning("PVE datastore %s skipped: filesystem probe timed out for %s (%v). Not scanning for vzdump backup files.", storage.Name, storage.Path, err)
+ logSkipDetails(storage, "filesystem_probe_timeout", err)
} else {
- c.logger.Debug("Skipping datastore %s (path not accessible: %s): %v", storage.Name, storage.Path, err)
+ c.logger.Warning("PVE datastore %s skipped: path %s not accessible (%v). Not scanning for vzdump backup files.", storage.Name, storage.Path, err)
+ logSkipDetails(storage, "path_not_accessible", err)
}
continue
}
if !stat.IsDir() {
- c.logger.Debug("Skipping datastore %s (path not a directory: %s)", storage.Name, storage.Path)
+ c.logger.Warning("PVE datastore %s skipped: path %s is not a directory. Not scanning for vzdump backup files.", storage.Name, storage.Path)
+ logSkipDetails(storage, "path_not_directory", nil)
continue
}
@@ -1109,7 +1180,7 @@ func (c *Collector) collectPVEStorageMetadata(ctx context.Context, storages []pv
storage.Path,
storage.Content))
- metaDir := filepath.Join(baseDir, storage.Name)
+ metaDir := filepath.Join(baseDir, storage.pathKey())
if err := c.ensureDir(metaDir); err != nil {
c.logger.Warning("Failed to create metadata directory for %s: %v", storage.Name, err)
continue
@@ -1251,9 +1322,11 @@ func (c *Collector) collectDetailedPVEBackups(ctx context.Context, storage pveSt
var totalFiles int64
var totalSize int64
+ storageKey := storage.pathKey()
+
var smallDir string
if c.config.BackupSmallPVEBackups && c.config.MaxPVEBackupSizeBytes > 0 {
- smallDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/small_backups", storage.Name)
+ smallDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/small_backups", storageKey)
if err := c.ensureDir(smallDir); err != nil {
c.logger.Warning("Cannot create small backups directory %s: %v", smallDir, err)
smallDir = ""
@@ -1263,7 +1336,7 @@ func (c *Collector) collectDetailedPVEBackups(ctx context.Context, storage pveSt
includePattern := strings.TrimSpace(c.config.PVEBackupIncludePattern)
var includeDir string
if includePattern != "" {
- includeDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/selected_backups", storage.Name)
+ includeDir = filepath.Join(c.tempDir, "var/lib/pve-cluster/selected_backups", storageKey)
if err := c.ensureDir(includeDir); err != nil {
c.logger.Warning("Cannot create selected backups directory %s: %v", includeDir, err)
includeDir = ""
@@ -1400,7 +1473,7 @@ type patternWriter struct {
func newPatternWriter(storageName, storagePath, analysisDir, pattern string, dryRun bool) (*patternWriter, error) {
clean := cleanPatternName(pattern)
- filename := fmt.Sprintf("%s_%s_list.txt", storageName, clean)
+ filename := fmt.Sprintf("%s_%s_list.txt", collectorPathKey(storageName), clean)
filePath := filepath.Join(analysisDir, filename)
// In dry-run mode, create a writer without an actual file
@@ -1526,7 +1599,7 @@ func (c *Collector) writePatternSummary(storage pveStorageEntry, analysisDir str
return nil
}
- summaryPath := filepath.Join(analysisDir, fmt.Sprintf("%s_backup_summary.txt", storage.Name))
+ summaryPath := filepath.Join(analysisDir, fmt.Sprintf("%s_backup_summary.txt", storage.pathKey()))
file, err := os.OpenFile(summaryPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0640)
if err != nil {
return err
diff --git a/internal/backup/collector_pve_parse_test.go b/internal/backup/collector_pve_parse_test.go
index aaea755..5f957e3 100644
--- a/internal/backup/collector_pve_parse_test.go
+++ b/internal/backup/collector_pve_parse_test.go
@@ -8,6 +8,11 @@ import (
// TestParseNodeStorageList tests parsing PVE storage entries from JSON
func TestParseNodeStorageList(t *testing.T) {
+ boolPtr := func(v bool) *bool {
+ b := v
+ return &b
+ }
+
tests := []struct {
name string
input string
@@ -26,6 +31,24 @@ func TestParseNodeStorageList(t *testing.T) {
{Name: "local-lvm", Path: "", Type: "lvmthin", Content: "images,rootdir"},
},
},
+ {
+ name: "storage with runtime flags",
+ input: `[
+ {"storage": "HDD1", "path": "/mnt/pve/HDD1", "type": "dir", "content": "backup", "active": 0, "enabled": 0, "status": "disabled"},
+ {"storage": "HDD2", "path": "/mnt/pve/HDD2", "type": "nfs", "content": "backup", "active": 0, "enabled": 1, "status": "unavailable"},
+ {"storage": "HDD3", "path": "/mnt/pve/HDD3", "type": "dir", "content": "backup", "active": true, "enabled": true, "status": "ok"},
+ {"storage": "HDD4", "path": "/mnt/pve/HDD4", "type": "dir", "content": "backup", "active": "0", "enabled": "1", "status": "unknown"},
+ {"storage": "HDD5", "path": "/mnt/pve/HDD5", "type": "dir", "content": "backup", "active": null, "enabled": null, "status": ""}
+ ]`,
+ expectError: false,
+ expected: []pveStorageEntry{
+ {Name: "HDD1", Path: "/mnt/pve/HDD1", Type: "dir", Content: "backup", Active: boolPtr(false), Enabled: boolPtr(false), Status: "disabled"},
+ {Name: "HDD2", Path: "/mnt/pve/HDD2", Type: "nfs", Content: "backup", Active: boolPtr(false), Enabled: boolPtr(true), Status: "unavailable"},
+ {Name: "HDD3", Path: "/mnt/pve/HDD3", Type: "dir", Content: "backup", Active: boolPtr(true), Enabled: boolPtr(true), Status: "ok"},
+ {Name: "HDD4", Path: "/mnt/pve/HDD4", Type: "dir", Content: "backup", Active: boolPtr(false), Enabled: boolPtr(true), Status: "unknown"},
+ {Name: "HDD5", Path: "/mnt/pve/HDD5", Type: "dir", Content: "backup", Active: nil, Enabled: nil, Status: ""},
+ },
+ },
{
name: "storage with name field instead of storage",
input: `[
@@ -145,6 +168,19 @@ func TestParseNodeStorageList(t *testing.T) {
if entry.Content != tt.expected[i].Content {
t.Errorf("entry[%d].Content = %q, want %q", i, entry.Content, tt.expected[i].Content)
}
+ if (entry.Active == nil) != (tt.expected[i].Active == nil) {
+ t.Errorf("entry[%d].Active nilness = %v, want %v", i, entry.Active == nil, tt.expected[i].Active == nil)
+ } else if entry.Active != nil && tt.expected[i].Active != nil && *entry.Active != *tt.expected[i].Active {
+ t.Errorf("entry[%d].Active = %v, want %v", i, *entry.Active, *tt.expected[i].Active)
+ }
+ if (entry.Enabled == nil) != (tt.expected[i].Enabled == nil) {
+ t.Errorf("entry[%d].Enabled nilness = %v, want %v", i, entry.Enabled == nil, tt.expected[i].Enabled == nil)
+ } else if entry.Enabled != nil && tt.expected[i].Enabled != nil && *entry.Enabled != *tt.expected[i].Enabled {
+ t.Errorf("entry[%d].Enabled = %v, want %v", i, *entry.Enabled, *tt.expected[i].Enabled)
+ }
+ if entry.Status != tt.expected[i].Status {
+ t.Errorf("entry[%d].Status = %q, want %q", i, entry.Status, tt.expected[i].Status)
+ }
}
})
}
diff --git a/internal/backup/collector_pve_test.go b/internal/backup/collector_pve_test.go
index 6d4a808..1a5e127 100644
--- a/internal/backup/collector_pve_test.go
+++ b/internal/backup/collector_pve_test.go
@@ -1,11 +1,13 @@
package backup
import (
+ "bytes"
"context"
"errors"
"fmt"
"os"
"path/filepath"
+ "strings"
"testing"
"github.com/tis24dev/proxsave/internal/logging"
@@ -814,15 +816,148 @@ func TestCollectPVEStorageMetadata(t *testing.T) {
collector := newPVECollector(t)
ctx := context.Background()
+ storageDir := t.TempDir()
storages := []pveStorageEntry{
- {Name: "local", Path: "/var/lib/vz", Type: "dir"},
+ {Name: "local", Path: storageDir, Type: "dir"},
}
- err := collector.collectPVEStorageMetadata(ctx, storages)
+ if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil {
+ t.Fatalf("collectPVEStorageMetadata returned error: %v", err)
+ }
+
+ metaPath := filepath.Join(collector.tempDir, "var/lib/pve-cluster/info/datastores", "local", "metadata.json")
+ if _, err := os.Stat(metaPath); err != nil {
+ t.Fatalf("expected %s created, got %v", metaPath, err)
+ }
+}
+
+func TestCollectPVEStorageMetadata_UsesPathSafeKeyForUnsafeStorageName(t *testing.T) {
+ collector := newPVECollector(t)
+ collector.config.BackupSmallPVEBackups = true
+ collector.config.MaxPVEBackupSizeBytes = 1024 * 1024
+ collector.config.PVEBackupIncludePattern = "vm-100"
+
+ ctx := context.Background()
+ storageDir := t.TempDir()
+ if err := os.WriteFile(filepath.Join(storageDir, "vm-100-backup.vma"), []byte("content"), 0o644); err != nil {
+ t.Fatalf("write backup file: %v", err)
+ }
+
+ storage := pveStorageEntry{Name: "../escape", Path: storageDir, Type: "dir"}
+ if err := collector.collectPVEStorageMetadata(ctx, []pveStorageEntry{storage}); err != nil {
+ t.Fatalf("collectPVEStorageMetadata returned error: %v", err)
+ }
+
+ key := collectorPathKey(storage.Name)
+ baseDir := filepath.Join(collector.tempDir, "var/lib/pve-cluster/info/datastores", key)
+ for _, path := range []string{
+ filepath.Join(baseDir, "metadata.json"),
+ filepath.Join(baseDir, "backup_analysis", fmt.Sprintf("%s_backup_summary.txt", key)),
+ filepath.Join(baseDir, "backup_analysis", fmt.Sprintf("%s__vma_list.txt", key)),
+ filepath.Join(collector.tempDir, "var/lib/pve-cluster/small_backups", key, "vm-100-backup.vma"),
+ filepath.Join(collector.tempDir, "var/lib/pve-cluster/selected_backups", key, "vm-100-backup.vma"),
+ } {
+ if _, err := os.Stat(path); err != nil {
+ t.Fatalf("expected safe output %s: %v", path, err)
+ }
+ }
+
+ metaPath := filepath.Join(baseDir, "metadata.json")
+ metaBytes, err := os.ReadFile(metaPath)
if err != nil {
- t.Logf("collectPVEStorageMetadata returned error: %v", err)
+ t.Fatalf("read metadata.json: %v", err)
+ }
+ if !strings.Contains(string(metaBytes), storage.Name) {
+ t.Fatalf("metadata should keep raw storage name, got %s", string(metaBytes))
+ }
+
+ rawMetaPath := filepath.Join(collector.tempDir, "var/lib/pve-cluster/info/datastores", storage.Name, "metadata.json")
+ if rawMetaPath != metaPath {
+ if _, err := os.Stat(rawMetaPath); !os.IsNotExist(err) {
+ t.Fatalf("raw storage metadata path should not exist (%s), got err=%v", rawMetaPath, err)
+ }
}
}
+func TestCollectPVEStorageMetadata_SkipReasonsAreUserFriendly(t *testing.T) {
+ boolPtr := func(v bool) *bool {
+ b := v
+ return &b
+ }
+ ctx := context.Background()
+
+ t.Run("disabled storage uses SKIP with debug details", func(t *testing.T) {
+ collector := newPVECollector(t)
+ var out bytes.Buffer
+ collector.logger.SetOutput(&out)
+
+ storages := []pveStorageEntry{
+ {Name: "HDD1", Path: "/mnt/pve/HDD1", Active: boolPtr(false), Enabled: boolPtr(false)},
+ }
+ if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil {
+ t.Fatalf("collectPVEStorageMetadata returned error: %v", err)
+ }
+
+ logText := out.String()
+ if collector.logger.WarningCount() != 0 {
+ t.Fatalf("expected 0 warnings, got %d", collector.logger.WarningCount())
+ }
+ if !strings.Contains(logText, "SKIP") || !strings.Contains(logText, "disabled in Proxmox") {
+ t.Fatalf("expected user-friendly SKIP message for disabled storage, got logs:\n%s", logText)
+ }
+ if !strings.Contains(logText, "PVE datastore skip details:") {
+ t.Fatalf("expected debug details for skipped storage, got logs:\n%s", logText)
+ }
+ })
+
+ t.Run("offline storage uses WARNING with debug details", func(t *testing.T) {
+ collector := newPVECollector(t)
+ var out bytes.Buffer
+ collector.logger.SetOutput(&out)
+
+ storages := []pveStorageEntry{
+ {Name: "HDD2", Path: "/mnt/pve/HDD2", Active: boolPtr(false), Enabled: boolPtr(true)},
+ }
+ if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil {
+ t.Fatalf("collectPVEStorageMetadata returned error: %v", err)
+ }
+
+ logText := out.String()
+ if collector.logger.WarningCount() != 1 {
+ t.Fatalf("expected 1 warning, got %d", collector.logger.WarningCount())
+ }
+ if !strings.Contains(logText, "WARNING") || !strings.Contains(logText, "storage is offline") {
+ t.Fatalf("expected warning message for offline storage, got logs:\n%s", logText)
+ }
+ if !strings.Contains(logText, "PVE datastore skip details:") {
+ t.Fatalf("expected debug details for skipped storage, got logs:\n%s", logText)
+ }
+ })
+
+ t.Run("non-existent path uses WARNING with debug details", func(t *testing.T) {
+ collector := newPVECollector(t)
+ var out bytes.Buffer
+ collector.logger.SetOutput(&out)
+
+ storages := []pveStorageEntry{
+ {Name: "MISSING", Path: filepath.Join(t.TempDir(), "does-not-exist"), Type: "dir"},
+ }
+ if err := collector.collectPVEStorageMetadata(ctx, storages); err != nil {
+ t.Fatalf("collectPVEStorageMetadata returned error: %v", err)
+ }
+
+ logText := out.String()
+ if collector.logger.WarningCount() != 1 {
+ t.Fatalf("expected 1 warning, got %d", collector.logger.WarningCount())
+ }
+ if !strings.Contains(logText, "not accessible") {
+ t.Fatalf("expected warning for missing path, got logs:\n%s", logText)
+ }
+ if !strings.Contains(logText, "reason=path_not_accessible") {
+ t.Fatalf("expected debug details including reason, got logs:\n%s", logText)
+ }
+ })
+}
+
// Test collectPVECephInfo function
func TestCollectPVECephInfo(t *testing.T) {
t.Run("no ceph configured", func(t *testing.T) {
diff --git a/internal/config/templates/backup.env b/internal/config/templates/backup.env
index 2dbf810..ad5a9b0 100644
--- a/internal/config/templates/backup.env
+++ b/internal/config/templates/backup.env
@@ -318,7 +318,7 @@ PVE_CLUSTER_PATH=/var/lib/pve-cluster
COROSYNC_CONFIG_PATH=${PVE_CONFIG_PATH}/corosync.conf
VZDUMP_CONFIG_PATH=/etc/vzdump.conf
PBS_CONFIG_PATH=/etc/proxmox-backup
-PBS_DATASTORE_PATH= # Extra PBS paths separated by comma/space (e.g. /mnt/pbs1,/mnt/pbs2). Empty = auto-detect.
+PBS_DATASTORE_PATH= # Extra PBS filesystem scan roots separated by comma/space (e.g. /mnt/pbs1,/mnt/pbs2). Not real datastore definitions; output keys may be path-derived. Empty = auto-detect.
# System
BACKUP_NETWORK_CONFIGS=true
diff --git a/internal/input/input.go b/internal/input/input.go
index 15c87b1..3df8280 100644
--- a/internal/input/input.go
+++ b/internal/input/input.go
@@ -7,6 +7,7 @@ import (
"io"
"os"
"strings"
+ "sync"
)
// ErrInputAborted signals that interactive input was interrupted (typically via Ctrl+C
@@ -15,6 +16,41 @@ import (
// Callers should translate this into the appropriate workflow-level abort error.
var ErrInputAborted = errors.New("input aborted")
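+// lineResult pairs the text and error produced by one background line read.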
+type lineResult struct {
+ line string
+ err error
+}
+
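+// lineState serializes ReadLineWithContext callers for one *bufio.Reader
+// and holds any in-flight background read between calls.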
+type lineState struct {
+ mu sync.Mutex
+ inflight *lineInflight
+}
+
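+// lineInflight is one background read: done buffers its result for the
+// next caller and completed is closed once the read has returned.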
+type lineInflight struct {
+ done chan lineResult
+ completed chan struct{}
+}
+
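+// passwordResult, passwordState, and passwordInflight mirror the line
+// variants above for no-echo password reads, keyed by file descriptor.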
+type passwordResult struct {
+ b []byte
+ err error
+}
+
+type passwordState struct {
+ mu sync.Mutex
+ inflight *passwordInflight
+}
+
+type passwordInflight struct {
+ done chan passwordResult
+ completed chan struct{}
+}
+
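+// lineStates and passwordStates map each *bufio.Reader / fd to its state.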
+var (
+ lineStates sync.Map
+ passwordStates sync.Map
+)
+
// IsAborted reports whether an operation was aborted by the user (typically via Ctrl+C),
// by checking for ErrInputAborted and context cancellation.
func IsAborted(err error) bool {
@@ -41,28 +77,77 @@ func MapInputError(err error) error {
return err
}
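+// mapContextInputError translates an already-done context into the
+// package's input errors (deadline exceeded vs. user abort).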
+func mapContextInputError(ctx context.Context) error {
+ if ctx == nil || ctx.Err() == nil {
+ return nil
+ }
+ if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+ return context.DeadlineExceeded
+ }
+ return ErrInputAborted
+}
+
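+// getLineState returns the state for reader, creating it on first use.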
+func getLineState(reader *bufio.Reader) *lineState {
+ if state, ok := lineStates.Load(reader); ok {
+ return state.(*lineState)
+ }
+ state := &lineState{}
+ actual, _ := lineStates.LoadOrStore(reader, state)
+ return actual.(*lineState)
+}
+
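+// getPasswordState returns the state for fd, creating it on first use.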
+func getPasswordState(fd int) *passwordState {
+ if state, ok := passwordStates.Load(fd); ok {
+ return state.(*passwordState)
+ }
+ state := &passwordState{}
+ actual, _ := passwordStates.LoadOrStore(fd, state)
+ return actual.(*passwordState)
+}
+
// ReadLineWithContext reads a single line and supports cancellation. On ctx cancellation
// or stdin closure it returns ErrInputAborted. On ctx deadline it returns context.DeadlineExceeded.
+// Cancellation stops waiting but does not interrupt an already-started reader.ReadString call;
+// at most one in-flight read is kept per reader to avoid goroutine buildup across retries.
+// A completed in-flight read remains attached to the reader until a later caller consumes it.
func ReadLineWithContext(ctx context.Context, reader *bufio.Reader) (string, error) {
if ctx == nil {
ctx = context.Background()
}
- type result struct {
- line string
- err error
+ if err := mapContextInputError(ctx); err != nil {
+ return "", err
+ }
+ if reader == nil {
+ return "", errors.New("reader is nil")
+ }
+ state := getLineState(reader)
+
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ if state.inflight == nil {
+ inflight := &lineInflight{
+ done: make(chan lineResult, 1),
+ completed: make(chan struct{}),
+ }
+ state.inflight = inflight
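+ // The blocking read runs in its own goroutine so callers can abandon
+ // it on cancellation; done is buffered so the result survives until a
+ // later caller consumes it.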
+ go func() {
+ line, err := reader.ReadString('\n')
+ inflight.done <- lineResult{line: line, err: MapInputError(err)}
+ close(inflight.completed)
+ }()
}
- ch := make(chan result, 1)
- go func() {
- line, err := reader.ReadString('\n')
- ch <- result{line: line, err: MapInputError(err)}
- }()
+ inflight := state.inflight
+
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return "", context.DeadlineExceeded
}
return "", ErrInputAborted
- case res := <-ch:
+ case res := <-inflight.done:
+ if state.inflight == inflight {
+ state.inflight = nil
+ }
return res.line, res.err
}
}
@@ -70,29 +155,47 @@ func ReadLineWithContext(ctx context.Context, reader *bufio.Reader) (string, err
// ReadPasswordWithContext reads a password (no echo) and supports cancellation. On ctx
// cancellation or stdin closure it returns ErrInputAborted. On ctx deadline it returns
// context.DeadlineExceeded.
+// Cancellation stops waiting but does not interrupt an already-started password read;
+// at most one in-flight password read is kept per file descriptor to avoid goroutine buildup.
+// A completed in-flight password read remains attached until a later caller consumes it.
func ReadPasswordWithContext(ctx context.Context, readPassword func(int) ([]byte, error), fd int) ([]byte, error) {
if ctx == nil {
ctx = context.Background()
}
+ if err := mapContextInputError(ctx); err != nil {
+ return nil, err
+ }
if readPassword == nil {
return nil, errors.New("readPassword function is nil")
}
- type result struct {
- b []byte
- err error
+ state := getPasswordState(fd)
+
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ if state.inflight == nil {
+ inflight := &passwordInflight{
+ done: make(chan passwordResult, 1),
+ completed: make(chan struct{}),
+ }
+ state.inflight = inflight
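+ // Mirrors ReadLineWithContext: on cancellation the read is abandoned,
+ // not interrupted, and its result is kept for reuse.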
+ go func() {
+ b, err := readPassword(fd)
+ inflight.done <- passwordResult{b: b, err: MapInputError(err)}
+ close(inflight.completed)
+ }()
}
- ch := make(chan result, 1)
- go func() {
- b, err := readPassword(fd)
- ch <- result{b: b, err: MapInputError(err)}
- }()
+ inflight := state.inflight
+
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return nil, context.DeadlineExceeded
}
return nil, ErrInputAborted
- case res := <-ch:
+ case res := <-inflight.done:
+ if state.inflight == inflight {
+ state.inflight = nil
+ }
return res.b, res.err
}
}
diff --git a/internal/input/input_test.go b/internal/input/input_test.go
index 4113024..403d347 100644
--- a/internal/input/input_test.go
+++ b/internal/input/input_test.go
@@ -7,10 +7,143 @@ import (
"io"
"os"
"strings"
+ "sync/atomic"
"testing"
"time"
)
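+// blockingLineReader blocks each Read until release (and, when set,
+// finish) is closed, then serves payload; returned signals completion.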
+type blockingLineReader struct {
+ release chan struct{}
+ finish chan struct{}
+ returned chan struct{}
+ payload string
+ calls atomic.Int32
+}
+
+func (r *blockingLineReader) Read(p []byte) (int, error) {
+ r.calls.Add(1)
+ <-r.release
+ if r.finish != nil {
+ <-r.finish
+ }
+ if r.payload == "" {
+ signalNonBlocking(r.returned)
+ return 0, io.EOF
+ }
+ n := copy(p, r.payload)
+ r.payload = r.payload[n:]
+ signalNonBlocking(r.returned)
+ return n, nil
+}
+
+type lineCallResult struct {
+ line string
+ err error
+}
+
+type passwordCallResult struct {
+ b []byte
+ err error
+}
+
+func signalNonBlocking(ch chan struct{}) {
+ if ch == nil {
+ return
+ }
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+}
+
+func waitForSignal(t *testing.T, ch <-chan struct{}, name string) {
+ t.Helper()
+ select {
+ case <-ch:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatalf("timed out waiting for %s", name)
+ }
+}
+
+func waitForCondition(t *testing.T, name string, cond func() bool) {
+ t.Helper()
+ deadline := time.After(500 * time.Millisecond)
+ ticker := time.NewTicker(time.Millisecond)
+ defer ticker.Stop()
+ for {
+ if cond() {
+ return
+ }
+ select {
+ case <-deadline:
+ t.Fatalf("timed out waiting for %s", name)
+ case <-ticker.C:
+ }
+ }
+}
+
+func currentLineInflight(t *testing.T, reader *bufio.Reader) *lineInflight {
+ t.Helper()
+ state := getLineState(reader)
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ if state.inflight == nil {
+ t.Fatalf("expected line inflight state")
+ }
+ return state.inflight
+}
+
+func currentPasswordInflight(t *testing.T, fd int) *passwordInflight {
+ t.Helper()
+ state := getPasswordState(fd)
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ if state.inflight == nil {
+ t.Fatalf("expected password inflight state")
+ }
+ return state.inflight
+}
+
+func assertSameLineInflight(t *testing.T, reader *bufio.Reader, want *lineInflight) {
+ t.Helper()
+ state := getLineState(reader)
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ if state.inflight != want {
+ t.Fatalf("line inflight=%p; want %p", state.inflight, want)
+ }
+}
+
+func assertSamePasswordInflight(t *testing.T, fd int, want *passwordInflight) {
+ t.Helper()
+ state := getPasswordState(fd)
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ if state.inflight != want {
+ t.Fatalf("password inflight=%p; want %p", state.inflight, want)
+ }
+}
+
+func assertLineInflightCleared(t *testing.T, reader *bufio.Reader) {
+ t.Helper()
+ state := getLineState(reader)
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ if state.inflight != nil {
+ t.Fatalf("line inflight=%p; want nil", state.inflight)
+ }
+}
+
+func assertPasswordInflightCleared(t *testing.T, fd int) {
+ t.Helper()
+ state := getPasswordState(fd)
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ if state.inflight != nil {
+ t.Fatalf("password inflight=%p; want nil", state.inflight)
+ }
+}
+
func TestMapInputError(t *testing.T) {
if MapInputError(nil) != nil {
t.Fatalf("expected nil")
@@ -208,3 +341,271 @@ func TestReadPasswordWithContext_DeadlineReturnsDeadlineExceeded(t *testing.T) {
t.Fatalf("err=%v; want %v", err, context.DeadlineExceeded)
}
}
+
+func TestReadLineWithContext_ReusesInflightReadWhilePendingAfterTimeout(t *testing.T) {
+ src := &blockingLineReader{
+ release: make(chan struct{}),
+ finish: make(chan struct{}),
+ returned: make(chan struct{}, 1),
+ payload: "hello\n",
+ }
+ reader := bufio.NewReader(src)
+
+ ctx1, cancel1 := context.WithTimeout(context.Background(), 25*time.Millisecond)
+ defer cancel1()
+ _, err := ReadLineWithContext(ctx1, reader)
+ if !errors.Is(err, context.DeadlineExceeded) {
+ t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded)
+ }
+
+ ctx2, cancel2 := context.WithTimeout(context.Background(), 25*time.Millisecond)
+ defer cancel2()
+ _, err = ReadLineWithContext(ctx2, reader)
+ if !errors.Is(err, context.DeadlineExceeded) {
+ t.Fatalf("second err=%v; want %v", err, context.DeadlineExceeded)
+ }
+
+ if got := src.calls.Load(); got != 1 {
+ t.Fatalf("underlying Read calls=%d; want 1", got)
+ }
+
+ resultCh := make(chan lineCallResult, 1)
+ go func() {
+ line, err := ReadLineWithContext(context.Background(), reader)
+ resultCh <- lineCallResult{line: line, err: err}
+ }()
+
+ state := getLineState(reader)
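+ // ReadLineWithContext holds state.mu while it waits, so a failed
+ // TryLock proves the retry is parked on the in-flight read.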
+ waitForCondition(t, "line retry to block on inflight read", func() bool {
+ if state.mu.TryLock() {
+ state.mu.Unlock()
+ return false
+ }
+ return true
+ })
+
+ close(src.release)
+ close(src.finish)
+ waitForSignal(t, src.returned, "underlying line read completion")
+
+ res := <-resultCh
+ if res.err != nil {
+ t.Fatalf("retry ReadLineWithContext error: %v", res.err)
+ }
+ if res.line != "hello\n" {
+ t.Fatalf("line=%q; want %q", res.line, "hello\n")
+ }
+ if got := src.calls.Load(); got != 1 {
+ t.Fatalf("underlying Read calls after pending retry=%d; want 1", got)
+ }
+}
+
+func TestReadLineWithContext_PreservesCompletedReadForNextRetryAfterTimeout(t *testing.T) {
+ src := &blockingLineReader{
+ release: make(chan struct{}),
+ returned: make(chan struct{}, 1),
+ payload: "hello\n",
+ }
+ reader := bufio.NewReader(src)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
+ defer cancel()
+ _, err := ReadLineWithContext(ctx, reader)
+ if !errors.Is(err, context.DeadlineExceeded) {
+ t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded)
+ }
+
+ inflight := currentLineInflight(t, reader)
+ close(src.release)
+ waitForSignal(t, src.returned, "underlying line read return")
+ waitForSignal(t, inflight.completed, "line inflight completion")
+ assertSameLineInflight(t, reader, inflight)
+
+ line, err := ReadLineWithContext(context.Background(), reader)
+ if err != nil {
+ t.Fatalf("retry ReadLineWithContext error: %v", err)
+ }
+ if line != "hello\n" {
+ t.Fatalf("line=%q; want %q", line, "hello\n")
+ }
+ if got := src.calls.Load(); got != 1 {
+ t.Fatalf("underlying Read calls after completed retry=%d; want 1", got)
+ }
+}
+
+func TestReadPasswordWithContext_ReusesInflightReadWhilePendingAfterTimeout(t *testing.T) {
+ release := make(chan struct{})
+ finish := make(chan struct{})
+ returned := make(chan struct{}, 1)
+ var calls atomic.Int32
+ readPassword := func(fd int) ([]byte, error) {
+ calls.Add(1)
+ <-release
+ <-finish
+ signalNonBlocking(returned)
+ return []byte("secret"), nil
+ }
+
+ ctx1, cancel1 := context.WithTimeout(context.Background(), 25*time.Millisecond)
+ defer cancel1()
+ got, err := ReadPasswordWithContext(ctx1, readPassword, 42)
+ if got != nil {
+ t.Fatalf("expected nil bytes on first deadline")
+ }
+ if !errors.Is(err, context.DeadlineExceeded) {
+ t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded)
+ }
+
+ ctx2, cancel2 := context.WithTimeout(context.Background(), 25*time.Millisecond)
+ defer cancel2()
+ got, err = ReadPasswordWithContext(ctx2, readPassword, 42)
+ if got != nil {
+ t.Fatalf("expected nil bytes on second deadline")
+ }
+ if !errors.Is(err, context.DeadlineExceeded) {
+ t.Fatalf("second err=%v; want %v", err, context.DeadlineExceeded)
+ }
+
+ if gotCalls := calls.Load(); gotCalls != 1 {
+ t.Fatalf("readPassword calls=%d; want 1", gotCalls)
+ }
+
+ resultCh := make(chan passwordCallResult, 1)
+ go func() {
+ got, err := ReadPasswordWithContext(context.Background(), readPassword, 42)
+ resultCh <- passwordCallResult{b: got, err: err}
+ }()
+
+ state := getPasswordState(42)
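+ // As above, a failed TryLock proves the retry goroutine is parked
+ // inside ReadPasswordWithContext holding state.mu.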
+ waitForCondition(t, "password retry to block on inflight read", func() bool {
+ if state.mu.TryLock() {
+ state.mu.Unlock()
+ return false
+ }
+ return true
+ })
+
+ close(release)
+ close(finish)
+ waitForSignal(t, returned, "underlying password read completion")
+
+ res := <-resultCh
+ if res.err != nil {
+ t.Fatalf("retry ReadPasswordWithContext error: %v", res.err)
+ }
+ if string(res.b) != "secret" {
+ t.Fatalf("got=%q; want %q", string(res.b), "secret")
+ }
+ if gotCalls := calls.Load(); gotCalls != 1 {
+ t.Fatalf("readPassword calls after pending retry=%d; want 1", gotCalls)
+ }
+}
+
+func TestReadPasswordWithContext_PreservesCompletedReadForNextRetryAfterTimeout(t *testing.T) {
+ release := make(chan struct{})
+ returned := make(chan struct{}, 1)
+ var calls atomic.Int32
+ readPassword := func(fd int) ([]byte, error) {
+ calls.Add(1)
+ <-release
+ signalNonBlocking(returned)
+ return []byte("secret"), nil
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
+ defer cancel()
+ got, err := ReadPasswordWithContext(ctx, readPassword, 42)
+ if got != nil {
+ t.Fatalf("expected nil bytes on first deadline")
+ }
+ if !errors.Is(err, context.DeadlineExceeded) {
+ t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded)
+ }
+
+ inflight := currentPasswordInflight(t, 42)
+ close(release)
+ waitForSignal(t, returned, "underlying password read return")
+ waitForSignal(t, inflight.completed, "password inflight completion")
+ assertSamePasswordInflight(t, 42, inflight)
+
+ got, err = ReadPasswordWithContext(context.Background(), readPassword, 42)
+ if err != nil {
+ t.Fatalf("retry ReadPasswordWithContext error: %v", err)
+ }
+ if string(got) != "secret" {
+ t.Fatalf("got=%q; want %q", string(got), "secret")
+ }
+ if gotCalls := calls.Load(); gotCalls != 1 {
+ t.Fatalf("readPassword calls after completed retry=%d; want 1", gotCalls)
+ }
+}
+
+func TestReadLineWithContext_ClearsInflightAfterCompletedRetryConsumesResult(t *testing.T) {
+ src := &blockingLineReader{
+ release: make(chan struct{}),
+ returned: make(chan struct{}, 1),
+ payload: "hello\n",
+ }
+ reader := bufio.NewReader(src)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
+ defer cancel()
+ _, err := ReadLineWithContext(ctx, reader)
+ if !errors.Is(err, context.DeadlineExceeded) {
+ t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded)
+ }
+
+ inflight := currentLineInflight(t, reader)
+ close(src.release)
+ waitForSignal(t, src.returned, "underlying line read return")
+ waitForSignal(t, inflight.completed, "line inflight completion")
+ assertSameLineInflight(t, reader, inflight)
+
+ line, err := ReadLineWithContext(context.Background(), reader)
+ if err != nil {
+ t.Fatalf("retry ReadLineWithContext error: %v", err)
+ }
+ if line != "hello\n" {
+ t.Fatalf("line=%q; want %q", line, "hello\n")
+ }
+ assertLineInflightCleared(t, reader)
+}
+
+func TestReadPasswordWithContext_ClearsInflightAfterCompletedRetryConsumesResult(t *testing.T) {
+ release := make(chan struct{})
+ returned := make(chan struct{}, 1)
+ var calls atomic.Int32
+ readPassword := func(fd int) ([]byte, error) {
+ calls.Add(1)
+ <-release
+ signalNonBlocking(returned)
+ return []byte("secret"), nil
+ }
+
+ const fd = 43
+
+ ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
+ defer cancel()
+ got, err := ReadPasswordWithContext(ctx, readPassword, fd)
+ if got != nil {
+ t.Fatalf("expected nil bytes on first deadline")
+ }
+ if !errors.Is(err, context.DeadlineExceeded) {
+ t.Fatalf("first err=%v; want %v", err, context.DeadlineExceeded)
+ }
+
+ inflight := currentPasswordInflight(t, fd)
+ close(release)
+ waitForSignal(t, returned, "underlying password read return")
+ waitForSignal(t, inflight.completed, "password inflight completion")
+ assertSamePasswordInflight(t, fd, inflight)
+
+ got, err = ReadPasswordWithContext(context.Background(), readPassword, fd)
+ if err != nil {
+ t.Fatalf("retry ReadPasswordWithContext error: %v", err)
+ }
+ if string(got) != "secret" {
+ t.Fatalf("got=%q; want %q", string(got), "secret")
+ }
+ assertPasswordInflightCleared(t, fd)
+}
diff --git a/internal/orchestrator/backup_safety.go b/internal/orchestrator/backup_safety.go
index 5402897..b992052 100644
--- a/internal/orchestrator/backup_safety.go
+++ b/internal/orchestrator/backup_safety.go
@@ -23,38 +23,8 @@ type safetyBackupSpec struct {
WriteLocationFile bool
}
-// resolveAndCheckPath cleans and resolves symlinks for candidate extraction paths
-// and verifies the resolved path is still within destRoot.
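+// resolveAndCheckPath resolves candidate against destRoot, following
+// symlinks via safetyFS, and rejects results that escape the root.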
func resolveAndCheckPath(destRoot, candidate string) (string, error) {
- joined := candidate
- if !filepath.IsAbs(candidate) {
- joined = filepath.Join(destRoot, candidate)
- }
-
- resolved, err := filepath.EvalSymlinks(joined)
- if err != nil {
- // If the path doesn't exist yet, EvalSymlinks will fail; fallback to the cleaned path.
- resolved = filepath.Clean(joined)
- }
-
- absDestRoot, err := filepath.Abs(destRoot)
- if err != nil {
- return "", fmt.Errorf("cannot resolve destination root: %w", err)
- }
- absResolved, err := filepath.Abs(resolved)
- if err != nil {
- return "", fmt.Errorf("cannot resolve candidate path: %w", err)
- }
-
- rel, err := filepath.Rel(absDestRoot, absResolved)
- if err != nil {
- return "", fmt.Errorf("cannot compute relative path: %w", err)
- }
- if strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." || filepath.IsAbs(rel) {
- return "", fmt.Errorf("resolved path escapes destination: %s", absResolved)
- }
-
- return absResolved, nil
+ return resolvePathWithinRootFS(safetyFS, destRoot, candidate)
}
// SafetyBackupResult contains information about the safety backup
@@ -391,7 +361,7 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str
return fmt.Errorf("read tar entry: %w", err)
}
- target, _, err := sanitizeRestoreEntryTarget(absDestRoot, header.Name)
+ target, _, err := sanitizeRestoreEntryTargetWithFS(safetyFS, absDestRoot, header.Name)
if err != nil {
logger.Warning("Skipping archive entry %s: %v", header.Name, err)
continue
@@ -425,14 +395,7 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str
if header.Typeflag == tar.TypeSymlink {
linkTarget := header.Linkname
- // Resolve intended target relative to the sanitized symlink directory inside the archive
- sanitizedDir := filepath.Dir(relTarget)
- resolvedLinkPath := linkTarget
- if !filepath.IsAbs(linkTarget) {
- resolvedLinkPath = filepath.Join(sanitizedDir, linkTarget)
- }
-
- if _, pathErr := resolveAndCheckPath(destRoot, resolvedLinkPath); pathErr != nil {
+ if _, pathErr := resolvePathRelativeToBaseWithinRootFS(safetyFS, absDestRoot, filepath.Dir(target), linkTarget); pathErr != nil {
logger.Warning("Skipping symlink %s -> %s: target escapes root: %v", target, linkTarget, pathErr)
continue
}
@@ -454,33 +417,9 @@ func RestoreSafetyBackup(logger *logging.Logger, backupPath string, destRoot str
continue
}
- // Resolve the symlink target relative to the symlink's directory
- symlinkTargetDir := filepath.Dir(target)
- resolvedTarget := actualTarget
- if !filepath.IsAbs(actualTarget) {
- resolvedTarget = filepath.Join(symlinkTargetDir, actualTarget)
- }
-
- // Validate the resolved target stays within destRoot
- absDestRoot, err := filepath.Abs(destRoot)
- if err != nil {
- logger.Warning("Cannot resolve destination root: %v", err)
- safetyFS.Remove(target)
- continue
- }
-
- absResolvedTarget, err := filepath.Abs(resolvedTarget)
- if err != nil {
- logger.Warning("Cannot resolve symlink target: %v", err)
- safetyFS.Remove(target)
- continue
- }
-
- // Check if resolved target is within destRoot
- rel, err := filepath.Rel(absDestRoot, absResolvedTarget)
- if err != nil || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." {
- logger.Warning("Removing symlink %s -> %s: target escapes root after creation (resolves to %s)",
- target, actualTarget, absResolvedTarget)
+ if _, err := resolvePathRelativeToBaseWithinRootFS(safetyFS, absDestRoot, filepath.Dir(target), actualTarget); err != nil {
+ logger.Warning("Removing symlink %s -> %s: target escapes root after creation: %v",
+ target, actualTarget, err)
safetyFS.Remove(target)
continue
}
diff --git a/internal/orchestrator/backup_safety_test.go b/internal/orchestrator/backup_safety_test.go
index 2529e2e..be0cc02 100644
--- a/internal/orchestrator/backup_safety_test.go
+++ b/internal/orchestrator/backup_safety_test.go
@@ -400,6 +400,152 @@ func TestRestoreSafetyBackup_DoesNotFollowExistingSymlinkTargetPath(t *testing.T
}
}
+func TestRestoreSafetyBackup_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) {
+ logger := logging.New(types.LogLevelInfo, false)
+ var logBuf bytes.Buffer
+ logger.SetOutput(&logBuf)
+
+ tmpDir := t.TempDir()
+ backupPath := filepath.Join(tmpDir, "broken-escape.tar.gz")
+ restoreDir := filepath.Join(tmpDir, "restore")
+ outside := t.TempDir()
+
+ if err := os.MkdirAll(restoreDir, 0o755); err != nil {
+ t.Fatalf("mkdir restore dir: %v", err)
+ }
+ if err := os.Symlink(outside, filepath.Join(restoreDir, "escape-link")); err != nil {
+ t.Fatalf("create escape symlink: %v", err)
+ }
+
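+ // The archive's only entry, "escape-link/missing/", routes through a
+ // pre-existing symlink that points outside restoreDir; extraction must
+ // refuse to create anything through it.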
+ var buf bytes.Buffer
+ gzw := gzip.NewWriter(&buf)
+ tw := tar.NewWriter(gzw)
+ if err := tw.WriteHeader(&tar.Header{Name: "escape-link/missing/", Mode: 0o755, Typeflag: tar.TypeDir}); err != nil {
+ t.Fatalf("write dir header failed: %v", err)
+ }
+ if err := tw.Close(); err != nil {
+ t.Fatalf("tar close failed: %v", err)
+ }
+ if err := gzw.Close(); err != nil {
+ t.Fatalf("gzip close failed: %v", err)
+ }
+ if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil {
+ t.Fatalf("write archive failed: %v", err)
+ }
+
+ if err := RestoreSafetyBackup(logger, backupPath, restoreDir); err != nil {
+ t.Fatalf("RestoreSafetyBackup failed: %v", err)
+ }
+
+ if _, err := os.Stat(filepath.Join(outside, "missing")); !os.IsNotExist(err) {
+ t.Fatalf("outside path should not be created, got err=%v", err)
+ }
+ if !strings.Contains(logBuf.String(), "Skipping archive entry") {
+ t.Fatalf("expected archive entry skip warning, logs=%q", logBuf.String())
+ }
+}
+
+func TestRestoreSafetyBackup_RejectsEscapeWhenParentPathIsSymlink(t *testing.T) {
+ logger := logging.New(types.LogLevelInfo, false)
+ var logBuf bytes.Buffer
+ logger.SetOutput(&logBuf)
+
+ tmpDir := t.TempDir()
+ backupPath := filepath.Join(tmpDir, "symlinked-parent-escape.tar.gz")
+ restoreDir := filepath.Join(tmpDir, "restore")
+
+ var buf bytes.Buffer
+ gzw := gzip.NewWriter(&buf)
+ tw := tar.NewWriter(gzw)
+
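+ // "linkdir" resolves to the restore root itself, so following it would
+ // let "linkdir/escape" -> "../outside" land outside the root.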
+ if err := tw.WriteHeader(&tar.Header{Name: "linkdir", Typeflag: tar.TypeSymlink, Linkname: "."}); err != nil {
+ t.Fatalf("write parent symlink header failed: %v", err)
+ }
+ if err := tw.WriteHeader(&tar.Header{Name: "linkdir/escape", Typeflag: tar.TypeSymlink, Linkname: "../outside"}); err != nil {
+ t.Fatalf("write escaping symlink header failed: %v", err)
+ }
+
+ if err := tw.Close(); err != nil {
+ t.Fatalf("tar close failed: %v", err)
+ }
+ if err := gzw.Close(); err != nil {
+ t.Fatalf("gzip close failed: %v", err)
+ }
+ if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil {
+ t.Fatalf("write archive failed: %v", err)
+ }
+
+ if err := RestoreSafetyBackup(logger, backupPath, restoreDir); err != nil {
+ t.Fatalf("RestoreSafetyBackup failed: %v", err)
+ }
+
+ linkTarget, err := os.Readlink(filepath.Join(restoreDir, "linkdir"))
+ if err != nil {
+ t.Fatalf("parent symlink should exist: %v", err)
+ }
+ if linkTarget != "." {
+ t.Fatalf("parent symlink target = %q, want %q", linkTarget, ".")
+ }
+
+ if _, err := os.Lstat(filepath.Join(restoreDir, "escape")); !os.IsNotExist(err) {
+ t.Fatalf("escaping symlink should not be created, got err=%v", err)
+ }
+ if !strings.Contains(logBuf.String(), "Skipping symlink") {
+ t.Fatalf("expected symlink skip warning, logs=%q", logBuf.String())
+ }
+}
+
+func TestRestoreSafetyBackup_AllowsSafeTargetWhenParentPathIsSymlink(t *testing.T) {
+ logger := logging.New(types.LogLevelInfo, false)
+
+ tmpDir := t.TempDir()
+ backupPath := filepath.Join(tmpDir, "symlinked-parent-safe.tar.gz")
+ restoreDir := filepath.Join(tmpDir, "restore")
+
+ var buf bytes.Buffer
+ gzw := gzip.NewWriter(&buf)
+ tw := tar.NewWriter(gzw)
+
+ if err := tw.WriteHeader(&tar.Header{Name: "subdir/", Mode: 0o755, Typeflag: tar.TypeDir}); err != nil {
+ t.Fatalf("write subdir header failed: %v", err)
+ }
+ content := []byte("ok")
+ if err := tw.WriteHeader(&tar.Header{Name: "subdir/file.txt", Mode: 0o644, Size: int64(len(content))}); err != nil {
+ t.Fatalf("write file header failed: %v", err)
+ }
+ if _, err := tw.Write(content); err != nil {
+ t.Fatalf("write file content failed: %v", err)
+ }
+ if err := tw.WriteHeader(&tar.Header{Name: "linkdir", Typeflag: tar.TypeSymlink, Linkname: "subdir"}); err != nil {
+ t.Fatalf("write parent symlink header failed: %v", err)
+ }
+ if err := tw.WriteHeader(&tar.Header{Name: "linkdir/ok", Typeflag: tar.TypeSymlink, Linkname: "file.txt"}); err != nil {
+ t.Fatalf("write safe symlink header failed: %v", err)
+ }
+
+ if err := tw.Close(); err != nil {
+ t.Fatalf("tar close failed: %v", err)
+ }
+ if err := gzw.Close(); err != nil {
+ t.Fatalf("gzip close failed: %v", err)
+ }
+ if err := os.WriteFile(backupPath, buf.Bytes(), 0o644); err != nil {
+ t.Fatalf("write archive failed: %v", err)
+ }
+
+ if err := RestoreSafetyBackup(logger, backupPath, restoreDir); err != nil {
+ t.Fatalf("RestoreSafetyBackup failed: %v", err)
+ }
+
+ linkTarget, err := os.Readlink(filepath.Join(restoreDir, "subdir", "ok"))
+ if err != nil {
+ t.Fatalf("safe symlink should exist: %v", err)
+ }
+ if linkTarget != "file.txt" {
+ t.Fatalf("safe symlink target = %q, want %q", linkTarget, "file.txt")
+ }
+}
+
func TestCleanupOldSafetyBackups(t *testing.T) {
logger := logging.New(types.LogLevelInfo, false)
@@ -667,6 +813,18 @@ func TestResolveAndCheckPathRejectsSymlinkEscape(t *testing.T) {
}
}
+func TestResolveAndCheckPathRejectsBrokenIntermediateSymlinkEscape(t *testing.T) {
+ root := t.TempDir()
+ outside := t.TempDir()
+ if err := os.Symlink(outside, filepath.Join(root, "escape-link")); err != nil {
+ t.Fatalf("create symlink: %v", err)
+ }
+
+ if _, err := resolveAndCheckPath(root, filepath.Join("escape-link", "missing", "data.txt")); err == nil {
+ t.Fatalf("expected broken symlink escape to be rejected")
+ }
+}
+
// =====================================
// walkFS / walkFSRecursive tests
// =====================================
diff --git a/internal/orchestrator/backup_sources.go b/internal/orchestrator/backup_sources.go
index aac895e..b985ee4 100644
--- a/internal/orchestrator/backup_sources.go
+++ b/internal/orchestrator/backup_sources.go
@@ -130,6 +130,7 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s
emptyEntries := 0
nonCandidateEntries := 0
manifestErrors := 0
+ integrityMissing := 0
logDebug(logger, "Cloud (rclone): scanned %d entries from rclone lsf output", totalEntries)
snapshot := make(map[string]struct{}, len(lines))
@@ -243,8 +244,8 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s
case sourceRaw:
manifest, perr := inspectRcloneMetadataManifest(itemCtx, item.remoteMetadata, item.remoteArchive, logger)
- cancel()
if perr != nil {
+ cancel()
if errors.Is(perr, context.DeadlineExceeded) {
return nil, fmt.Errorf("timed out while inspecting %s (timeout=%s). Increase RCLONE_TIMEOUT_CONNECTION if needed: %w", item.filename, timeout, perr)
}
@@ -255,6 +256,36 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s
logWarning(logger, "Skipping rclone metadata %s: %v", item.filename, perr)
continue
}
+ manifestChecksum, perr := normalizeCandidateManifestChecksum(manifest)
+ if perr != nil {
+ cancel()
+ integrityMissing++
+ logWarning(logger, "Skipping rclone backup %s: invalid manifest checksum: %v", item.filename, perr)
+ continue
+ }
+ checksumFromFile := ""
+ if item.remoteChecksum != "" {
+ checksumFromFile, perr = inspectRcloneChecksumFile(itemCtx, item.remoteChecksum, logger)
+ if perr != nil {
+ cancel()
+ if errors.Is(perr, context.DeadlineExceeded) {
+ return nil, fmt.Errorf("timed out while inspecting %s checksum (timeout=%s). Increase RCLONE_TIMEOUT_CONNECTION if needed: %w", item.filename, timeout, perr)
+ }
+ if errors.Is(perr, context.Canceled) {
+ return nil, perr
+ }
+ integrityMissing++
+ logWarning(logger, "Skipping rclone backup %s: invalid checksum file: %v", item.filename, perr)
+ continue
+ }
+ }
+ expectation, perr := resolveIntegrityExpectationValues(checksumFromFile, manifestChecksum)
+ cancel()
+ if perr != nil {
+ integrityMissing++
+ logWarning(logger, "Skipping rclone backup %s: %v", item.filename, perr)
+ continue
+ }
displayBase := filepath.Base(manifest.ArchivePath)
if strings.TrimSpace(displayBase) == "" {
displayBase = filepath.Base(baseNameFromRemoteRef(item.remoteArchive))
@@ -265,6 +296,7 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s
RawArchivePath: item.remoteArchive,
RawMetadataPath: item.remoteMetadata,
RawChecksumPath: item.remoteChecksum,
+ Integrity: expectation,
DisplayBase: displayBase,
IsRclone: true,
})
@@ -292,11 +324,12 @@ func discoverRcloneBackups(ctx context.Context, cfg *config.Config, remotePath s
logging.DebugStep(
logger,
"discover rclone backups",
- "summary entries=%d empty=%d non_candidate=%d manifest_errors=%d accepted=%d elapsed=%s",
+ "summary entries=%d empty=%d non_candidate=%d manifest_errors=%d integrity_missing=%d accepted=%d elapsed=%s",
totalEntries,
emptyEntries,
nonCandidateEntries,
manifestErrors,
+ integrityMissing,
len(candidates),
time.Since(start),
)
@@ -336,6 +369,7 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [
metadataMissingArchive := 0
metadataManifestErrors := 0
checksumMissing := 0
+ integrityUnavailable := 0
for _, entry := range entries {
if entry.IsDir() {
@@ -392,11 +426,26 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [
logWarning(logger, "Skipping metadata %s: %v", name, err)
continue
}
- logging.DebugStep(logger, "discover backup candidates", "raw candidate accepted: %s created_at=%s", name, manifest.CreatedAt.Format(time.RFC3339))
-
- // If checksum is missing from both file and manifest, warn user
- if !hasChecksum && manifest.SHA256 == "" {
- logWarning(logger, "Backup %s has no checksum verification available", baseName)
+ manifestChecksum, err := normalizeCandidateManifestChecksum(manifest)
+ if err != nil {
+ integrityUnavailable++
+ logWarning(logger, "Skipping backup %s: invalid manifest checksum: %v", baseName, err)
+ continue
+ }
+ checksumFromFile := ""
+ if hasChecksum {
+ checksumFromFile, err = parseLocalChecksumFile(checksumPath)
+ if err != nil {
+ integrityUnavailable++
+ logWarning(logger, "Skipping backup %s: invalid checksum file: %v", baseName, err)
+ continue
+ }
+ }
+ expectation, err := resolveIntegrityExpectationValues(checksumFromFile, manifestChecksum)
+ if err != nil {
+ integrityUnavailable++
+ logWarning(logger, "Skipping backup %s: %v", baseName, err)
+ continue
}
rawBases[baseName] = struct{}{}
@@ -406,8 +455,10 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [
RawArchivePath: archivePath,
RawMetadataPath: fullPath,
RawChecksumPath: checksumPath,
+ Integrity: expectation,
DisplayBase: filepath.Base(manifest.ArchivePath),
})
+ logging.DebugStep(logger, "discover backup candidates", "raw candidate accepted: %s created_at=%s", name, manifest.CreatedAt.Format(time.RFC3339))
}
}
@@ -418,7 +469,7 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [
logging.DebugStep(
logger,
"discover backup candidates",
- "summary entries=%d files=%d dirs=%d bundles=%d bundle_manifest_errors=%d metadata=%d metadata_duplicate=%d metadata_missing_archive=%d metadata_manifest_errors=%d checksum_missing=%d candidates=%d",
+ "summary entries=%d files=%d dirs=%d bundles=%d bundle_manifest_errors=%d metadata=%d metadata_duplicate=%d metadata_missing_archive=%d metadata_manifest_errors=%d checksum_missing=%d integrity_unavailable=%d candidates=%d",
len(entries),
filesSeen,
dirsSkipped,
@@ -429,11 +480,53 @@ func discoverBackupCandidates(logger *logging.Logger, root string) (candidates [
metadataMissingArchive,
metadataManifestErrors,
checksumMissing,
+ integrityUnavailable,
len(candidates),
)
return candidates, nil
}
+func normalizeCandidateManifestChecksum(manifest *backup.Manifest) (string, error) {
+ if manifest == nil || strings.TrimSpace(manifest.SHA256) == "" {
+ return "", nil
+ }
+ normalized, err := backup.NormalizeChecksum(manifest.SHA256)
+ if err != nil {
+ return "", err
+ }
+ manifest.SHA256 = normalized
+ return normalized, nil
+}
+
+func parseLocalChecksumFile(checksumPath string) (string, error) {
+ data, err := restoreFS.ReadFile(checksumPath)
+ if err != nil {
+ return "", fmt.Errorf("read checksum file %s: %w", checksumPath, err)
+ }
+ checksum, err := backup.ParseChecksumData(data)
+ if err != nil {
+ return "", fmt.Errorf("parse checksum file %s: %w", checksumPath, err)
+ }
+ return checksum, nil
+}
+
+func inspectRcloneChecksumFile(ctx context.Context, remotePath string, logger *logging.Logger) (checksum string, err error) {
+ done := logging.DebugStart(logger, "inspect rclone checksum", "remote=%s", remotePath)
+ defer func() { done(err) }()
+ logging.DebugStep(logger, "inspect rclone checksum", "executing: rclone cat %s", remotePath)
+
+ cmd := exec.CommandContext(ctx, "rclone", "cat", remotePath)
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ return "", fmt.Errorf("rclone cat %s failed: %w (output: %s)", remotePath, err, strings.TrimSpace(string(output)))
+ }
+ checksum, err = backup.ParseChecksumData(output)
+ if err != nil {
+ return "", fmt.Errorf("parse checksum file %s: %w", remotePath, err)
+ }
+ return checksum, nil
+}
+
// isLocalFilesystemPath returns true if the given value represents an absolute
// local filesystem path (and not an rclone-style "remote:path" reference).
func isLocalFilesystemPath(path string) bool {
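
The body of `isLocalFilesystemPath` falls outside this hunk. A plausible sketch of the distinction its doc comment describes (an assumption about the implementation, not a copy of it): treat any value with a colon before the first path separator as an rclone remote.

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// looksLikeRcloneRemote reports whether value matches rclone's
// "remote:path" form: a colon appears before any path separator.
func looksLikeRcloneRemote(value string) bool {
	i := strings.IndexAny(value, ":/")
	return i >= 0 && value[i] == ':'
}

// isLocalFilesystemPathSketch mirrors the documented contract:
// absolute local path, and not a remote reference.
func isLocalFilesystemPathSketch(value string) bool {
	return filepath.IsAbs(value) && !looksLikeRcloneRemote(value)
}

func main() {
	fmt.Println(isLocalFilesystemPathSketch("/mnt/pbs1"))             // true
	fmt.Println(isLocalFilesystemPathSketch("gdrive:pbs-backups/s1")) // false
}
```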
diff --git a/internal/orchestrator/backup_sources_test.go b/internal/orchestrator/backup_sources_test.go
index 30b9f0d..883ea9e 100644
--- a/internal/orchestrator/backup_sources_test.go
+++ b/internal/orchestrator/backup_sources_test.go
@@ -235,6 +235,7 @@ func TestDiscoverRcloneBackups_IncludesRawMetadata(t *testing.T) {
ProxmoxVersion: "8.1",
CreatedAt: time.Date(2025, 12, 5, 12, 0, 0, 0, time.UTC),
EncryptionMode: "none",
+ SHA256: checksumHexForBytes([]byte("node-backup-20251205")),
}
metaBytes, err := json.Marshal(&manifest)
if err != nil {
@@ -317,6 +318,7 @@ func TestDiscoverRcloneBackups_MixedCandidatesSortedByCreatedAt(t *testing.T) {
CreatedAt: time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC),
EncryptionMode: "none",
ProxmoxType: "pve",
+ SHA256: checksumHexForBytes([]byte("x")),
}
rawNewestData, _ := json.Marshal(&rawNewest)
if err := os.WriteFile(rawNewestMeta, rawNewestData, 0o600); err != nil {
@@ -354,6 +356,7 @@ func TestDiscoverRcloneBackups_MixedCandidatesSortedByCreatedAt(t *testing.T) {
CreatedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EncryptionMode: "none",
ProxmoxType: "pve",
+ SHA256: checksumHexForBytes([]byte("x")),
}
rawOldData, _ := json.Marshal(&rawOld)
if err := os.WriteFile(rawOldMeta, rawOldData, 0o600); err != nil {
@@ -629,7 +632,7 @@ esac
return manifest, cleanup
}
-func TestDiscoverBackupCandidates_NoLoggerStillCollectsRawArtifacts(t *testing.T) {
+func TestDiscoverBackupCandidates_NoLoggerSkipsRawArtifactsWithoutChecksumVerification(t *testing.T) {
tmpDir := t.TempDir()
archivePath := filepath.Join(tmpDir, "config.tar.xz")
if err := os.WriteFile(archivePath, []byte("dummy"), 0o600); err != nil {
@@ -652,6 +655,44 @@ func TestDiscoverBackupCandidates_NoLoggerStillCollectsRawArtifacts(t *testing.T
}
// Intentionally skip checksum: with no verification source the candidate is rejected, exercising the warning path with a nil logger.
+ candidates, err := discoverBackupCandidates(nil, tmpDir)
+ if err != nil {
+ t.Fatalf("discoverBackupCandidates() error = %v", err)
+ }
+ if len(candidates) != 0 {
+ t.Fatalf("discoverBackupCandidates() returned %d candidates; want 0", len(candidates))
+ }
+}
+
+func TestDiscoverBackupCandidates_NormalizesAndStoresIntegrityExpectation(t *testing.T) {
+ tmpDir := t.TempDir()
+ archiveData := []byte("archive")
+ archivePath := filepath.Join(tmpDir, "config.tar.xz")
+ if err := os.WriteFile(archivePath, archiveData, 0o600); err != nil {
+ t.Fatalf("write archive: %v", err)
+ }
+
+ checksum := checksumHexForBytes(archiveData)
+ manifest := backup.Manifest{
+ ArchivePath: "/etc/pve/config.tar.xz",
+ ProxmoxType: "pve",
+ CreatedAt: time.Now().UTC(),
+ EncryptionMode: "none",
+ SHA256: " " + strings.ToUpper(checksum) + " ",
+ }
+ metaPath := archivePath + ".metadata"
+ data, err := json.Marshal(&manifest)
+ if err != nil {
+ t.Fatalf("marshal manifest: %v", err)
+ }
+ if err := os.WriteFile(metaPath, data, 0o600); err != nil {
+ t.Fatalf("write metadata: %v", err)
+ }
+ checksumLine := strings.ToUpper(checksum) + " " + filepath.Base(archivePath) + "\n"
+ if err := os.WriteFile(archivePath+".sha256", []byte(checksumLine), 0o600); err != nil {
+ t.Fatalf("write checksum: %v", err)
+ }
+
candidates, err := discoverBackupCandidates(nil, tmpDir)
if err != nil {
t.Fatalf("discoverBackupCandidates() error = %v", err)
@@ -659,10 +700,271 @@ func TestDiscoverBackupCandidates_NoLoggerStillCollectsRawArtifacts(t *testing.T
if len(candidates) != 1 {
t.Fatalf("discoverBackupCandidates() returned %d candidates; want 1", len(candidates))
}
- if candidates[0].RawArchivePath != archivePath {
- t.Fatalf("RawArchivePath = %q; want %q", candidates[0].RawArchivePath, archivePath)
+ cand := candidates[0]
+ if cand.Manifest == nil {
+ t.Fatal("candidate Manifest is nil")
+ }
+ if cand.Manifest.SHA256 != checksum {
+ t.Fatalf("Manifest.SHA256 = %q; want %q", cand.Manifest.SHA256, checksum)
+ }
+ if cand.Integrity == nil {
+ t.Fatal("candidate Integrity is nil")
+ }
+ if cand.Integrity.Checksum != checksum {
+ t.Fatalf("Integrity.Checksum = %q; want %q", cand.Integrity.Checksum, checksum)
+ }
+ if cand.Integrity.Source != "checksum file and manifest" {
+ t.Fatalf("Integrity.Source = %q; want %q", cand.Integrity.Source, "checksum file and manifest")
+ }
+}
+
+func TestDiscoverBackupCandidates_RejectsMalformedOrConflictingChecksums(t *testing.T) {
+ tests := []struct {
+ name string
+ manifestSHA256 string
+ checksumData string
+ }{
+ {
+ name: "invalid manifest checksum",
+ manifestSHA256: "not-a-checksum",
+ checksumData: string(checksumLineForBytes("config.tar.xz", []byte("archive"))),
+ },
+ {
+ name: "invalid checksum file",
+ manifestSHA256: checksumHexForBytes([]byte("archive")),
+ checksumData: "not-a-checksum config.tar.xz\n",
+ },
+ {
+ name: "conflicting valid checksums",
+ manifestSHA256: checksumHexForBytes([]byte("archive")),
+ checksumData: string(checksumLineForBytes("config.tar.xz", []byte("different"))),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir := t.TempDir()
+ archivePath := filepath.Join(tmpDir, "config.tar.xz")
+ if err := os.WriteFile(archivePath, []byte("archive"), 0o600); err != nil {
+ t.Fatalf("write archive: %v", err)
+ }
+
+ manifest := backup.Manifest{
+ ArchivePath: "/etc/pve/config.tar.xz",
+ ProxmoxType: "pve",
+ CreatedAt: time.Now().UTC(),
+ EncryptionMode: "none",
+ SHA256: tt.manifestSHA256,
+ }
+ metaPath := archivePath + ".metadata"
+ data, err := json.Marshal(&manifest)
+ if err != nil {
+ t.Fatalf("marshal manifest: %v", err)
+ }
+ if err := os.WriteFile(metaPath, data, 0o600); err != nil {
+ t.Fatalf("write metadata: %v", err)
+ }
+ if err := os.WriteFile(archivePath+".sha256", []byte(tt.checksumData), 0o600); err != nil {
+ t.Fatalf("write checksum: %v", err)
+ }
+
+ candidates, err := discoverBackupCandidates(nil, tmpDir)
+ if err != nil {
+ t.Fatalf("discoverBackupCandidates() error = %v", err)
+ }
+ if len(candidates) != 0 {
+ t.Fatalf("discoverBackupCandidates() returned %d candidates; want 0", len(candidates))
+ }
+ })
+ }
+}
+
+func TestDiscoverRcloneBackups_NormalizesAndStoresIntegrityExpectation(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ archiveData := []byte("archive")
+ checksum := checksumHexForBytes(archiveData)
+ manifest := backup.Manifest{
+ ArchivePath: "/var/backups/node-backup.tar.xz",
+ ProxmoxType: "pve",
+ CreatedAt: time.Date(2025, 12, 5, 12, 0, 0, 0, time.UTC),
+ EncryptionMode: "none",
+ SHA256: " " + strings.ToUpper(checksum) + " ",
+ }
+ metaBytes, err := json.Marshal(&manifest)
+ if err != nil {
+ t.Fatalf("marshal manifest: %v", err)
+ }
+ metadataPath := filepath.Join(tmpDir, "node-backup.tar.xz.metadata")
+ if err := os.WriteFile(metadataPath, metaBytes, 0o600); err != nil {
+ t.Fatalf("write metadata: %v", err)
+ }
+ checksumPath := filepath.Join(tmpDir, "node-backup.tar.xz.sha256")
+ checksumLine := strings.ToUpper(checksum) + " node-backup.tar.xz\n"
+ if err := os.WriteFile(checksumPath, []byte(checksumLine), 0o600); err != nil {
+ t.Fatalf("write checksum: %v", err)
+ }
+
+ scriptPath := filepath.Join(tmpDir, "rclone")
+ script := `#!/bin/sh
+subcmd="$1"
+target="$2"
+case "$subcmd" in
+ lsf)
+ printf 'node-backup.tar.xz\n'
+ printf 'node-backup.tar.xz.metadata\n'
+ printf 'node-backup.tar.xz.sha256\n'
+ ;;
+ cat)
+ case "$target" in
+ *node-backup.tar.xz.metadata) cat "$METADATA_PATH" ;;
+ *node-backup.tar.xz.sha256) cat "$CHECKSUM_PATH" ;;
+ *) echo "unexpected cat target: $target" >&2; exit 1 ;;
+ esac
+ ;;
+ *)
+ echo "unexpected subcommand: $subcmd" >&2
+ exit 1
+ ;;
+esac
+`
+ if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
+ t.Fatalf("write fake rclone: %v", err)
+ }
+
+ oldPath := os.Getenv("PATH")
+ if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil {
+ t.Fatalf("set PATH: %v", err)
+ }
+ defer os.Setenv("PATH", oldPath)
+ if err := os.Setenv("METADATA_PATH", metadataPath); err != nil {
+ t.Fatalf("set METADATA_PATH: %v", err)
+ }
+ defer os.Unsetenv("METADATA_PATH")
+ if err := os.Setenv("CHECKSUM_PATH", checksumPath); err != nil {
+ t.Fatalf("set CHECKSUM_PATH: %v", err)
+ }
+ defer os.Unsetenv("CHECKSUM_PATH")
+
+ candidates, err := discoverRcloneBackups(context.Background(), nil, "gdrive:pbs-backups/server1", nil, nil)
+ if err != nil {
+ t.Fatalf("discoverRcloneBackups() error = %v", err)
+ }
+ if len(candidates) != 1 {
+ t.Fatalf("discoverRcloneBackups() returned %d candidates; want 1", len(candidates))
+ }
+ cand := candidates[0]
+ if cand.Manifest == nil {
+ t.Fatal("candidate Manifest is nil")
+ }
+ if cand.Manifest.SHA256 != checksum {
+ t.Fatalf("Manifest.SHA256 = %q; want %q", cand.Manifest.SHA256, checksum)
+ }
+ if cand.Integrity == nil {
+ t.Fatal("candidate Integrity is nil")
+ }
+ if cand.Integrity.Checksum != checksum {
+ t.Fatalf("Integrity.Checksum = %q; want %q", cand.Integrity.Checksum, checksum)
+ }
+ if cand.Integrity.Source != "checksum file and manifest" {
+ t.Fatalf("Integrity.Source = %q; want %q", cand.Integrity.Source, "checksum file and manifest")
+ }
+}
+
+func TestDiscoverRcloneBackups_RejectsMalformedOrConflictingChecksums(t *testing.T) {
+ tests := []struct {
+ name string
+ manifestSHA256 string
+ checksumData string
+ }{
+ {
+ name: "invalid manifest checksum",
+ manifestSHA256: "not-a-checksum",
+ checksumData: string(checksumLineForBytes("node-backup.tar.xz", []byte("archive"))),
+ },
+ {
+ name: "invalid checksum file",
+ manifestSHA256: checksumHexForBytes([]byte("archive")),
+ checksumData: "not-a-checksum node-backup.tar.xz\n",
+ },
+ {
+ name: "conflicting valid checksums",
+ manifestSHA256: checksumHexForBytes([]byte("archive")),
+ checksumData: string(checksumLineForBytes("node-backup.tar.xz", []byte("different"))),
+ },
}
- if candidates[0].RawChecksumPath != "" {
- t.Fatalf("RawChecksumPath should be empty when checksum missing; got %q", candidates[0].RawChecksumPath)
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir := t.TempDir()
+ manifest := backup.Manifest{
+ ArchivePath: "/var/backups/node-backup.tar.xz",
+ ProxmoxType: "pve",
+ CreatedAt: time.Date(2025, 12, 5, 12, 0, 0, 0, time.UTC),
+ EncryptionMode: "none",
+ SHA256: tt.manifestSHA256,
+ }
+ metaBytes, err := json.Marshal(&manifest)
+ if err != nil {
+ t.Fatalf("marshal manifest: %v", err)
+ }
+ metadataPath := filepath.Join(tmpDir, "node-backup.tar.xz.metadata")
+ if err := os.WriteFile(metadataPath, metaBytes, 0o600); err != nil {
+ t.Fatalf("write metadata: %v", err)
+ }
+ checksumPath := filepath.Join(tmpDir, "node-backup.tar.xz.sha256")
+ if err := os.WriteFile(checksumPath, []byte(tt.checksumData), 0o600); err != nil {
+ t.Fatalf("write checksum: %v", err)
+ }
+
+ scriptPath := filepath.Join(tmpDir, "rclone")
+ script := `#!/bin/sh
+subcmd="$1"
+target="$2"
+case "$subcmd" in
+ lsf)
+ printf 'node-backup.tar.xz\n'
+ printf 'node-backup.tar.xz.metadata\n'
+ printf 'node-backup.tar.xz.sha256\n'
+ ;;
+ cat)
+ case "$target" in
+ *node-backup.tar.xz.metadata) cat "$METADATA_PATH" ;;
+ *node-backup.tar.xz.sha256) cat "$CHECKSUM_PATH" ;;
+ *) echo "unexpected cat target: $target" >&2; exit 1 ;;
+ esac
+ ;;
+ *)
+ echo "unexpected subcommand: $subcmd" >&2
+ exit 1
+ ;;
+esac
+`
+ if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
+ t.Fatalf("write fake rclone: %v", err)
+ }
+
+ oldPath := os.Getenv("PATH")
+ if err := os.Setenv("PATH", tmpDir+string(os.PathListSeparator)+oldPath); err != nil {
+ t.Fatalf("set PATH: %v", err)
+ }
+ defer os.Setenv("PATH", oldPath)
+ if err := os.Setenv("METADATA_PATH", metadataPath); err != nil {
+ t.Fatalf("set METADATA_PATH: %v", err)
+ }
+ defer os.Unsetenv("METADATA_PATH")
+ if err := os.Setenv("CHECKSUM_PATH", checksumPath); err != nil {
+ t.Fatalf("set CHECKSUM_PATH: %v", err)
+ }
+ defer os.Unsetenv("CHECKSUM_PATH")
+
+ candidates, err := discoverRcloneBackups(context.Background(), nil, "gdrive:pbs-backups/server1", nil, nil)
+ if err != nil {
+ t.Fatalf("discoverRcloneBackups() error = %v", err)
+ }
+ if len(candidates) != 0 {
+ t.Fatalf("discoverRcloneBackups() returned %d candidates; want 0", len(candidates))
+ }
+ })
}
}
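
Both fake-rclone tests above stage a stub `rclone` script at the front of `PATH` and hand it file locations through environment variables, restoring everything with paired defers. On Go 1.17+ the same setup can be expressed with `t.Setenv`, which registers the restore automatically (and refuses to run under `t.Parallel`). A hypothetical helper, not what this diff uses:

```go
package orchestrator

import (
	"os"
	"testing"
)

// setupFakeRclone is an illustrative equivalent of the manual
// os.Setenv/defer pairs in the tests above.
func setupFakeRclone(t *testing.T, scriptDir, metadataPath, checksumPath string) {
	t.Helper()
	t.Setenv("PATH", scriptDir+string(os.PathListSeparator)+os.Getenv("PATH"))
	t.Setenv("METADATA_PATH", metadataPath)
	t.Setenv("CHECKSUM_PATH", checksumPath)
}
```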
diff --git a/internal/orchestrator/compatibility.go b/internal/orchestrator/compatibility.go
index cc2eb3f..bde7503 100644
--- a/internal/orchestrator/compatibility.go
+++ b/internal/orchestrator/compatibility.go
@@ -41,12 +41,8 @@ func DetectBackupType(manifest *backup.Manifest) SystemType {
// Check ProxmoxType field if present
if manifest.ProxmoxType != "" {
- proxmoxType := strings.ToLower(manifest.ProxmoxType)
- if strings.Contains(proxmoxType, "pve") || strings.Contains(proxmoxType, "proxmox-ve") {
- return SystemTypePVE
- }
- if strings.Contains(proxmoxType, "pbs") || strings.Contains(proxmoxType, "proxmox-backup") {
- return SystemTypePBS
+ if backupType := parseSystemTypeString(manifest.ProxmoxType); backupType != SystemTypeUnknown {
+ return backupType
}
}
@@ -65,12 +61,25 @@ func DetectBackupType(manifest *backup.Manifest) SystemType {
return SystemTypeUnknown
}
-// ValidateCompatibility checks if a backup is compatible with the current system
-func ValidateCompatibility(manifest *backup.Manifest) error {
- currentSystem := DetectCurrentSystem()
- backupType := DetectBackupType(manifest)
+func parseSystemTypeString(value string) SystemType {
+ normalized := strings.ToLower(strings.TrimSpace(value))
+ switch {
+ case strings.Contains(normalized, "pve"),
+ strings.Contains(normalized, "proxmox-ve"),
+ strings.Contains(normalized, "proxmox ve"):
+ return SystemTypePVE
+ case strings.Contains(normalized, "pbs"),
+ strings.Contains(normalized, "proxmox-backup"),
+ strings.Contains(normalized, "proxmox backup"),
+ strings.Contains(normalized, "proxmox backup server"):
+ return SystemTypePBS
+ default:
+ return SystemTypeUnknown
+ }
+}
- // If we can't detect either, issue a warning but allow
+// ValidateCompatibility checks if a backup is compatible with the current system.
+func ValidateCompatibility(currentSystem, backupType SystemType) error {
if currentSystem == SystemTypeUnknown {
return fmt.Errorf("warning: cannot detect current system type - restoration may fail")
}
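
`ValidateCompatibility` no longer performs any detection itself; callers pass both sides in, which is what lets the updated test below assert a mismatch without building a manifest. A hypothetical in-package call site under that assumption:

```go
// validateSelectedBackup is illustrative, not a function from this diff:
// detection happens at the call site, validation stays pure.
func validateSelectedBackup(manifest *backup.Manifest) error {
	currentSystem := DetectCurrentSystem()
	backupType := DetectBackupType(manifest)
	return ValidateCompatibility(currentSystem, backupType)
}
```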
diff --git a/internal/orchestrator/compatibility_test.go b/internal/orchestrator/compatibility_test.go
index 5c7ef98..70bafb1 100644
--- a/internal/orchestrator/compatibility_test.go
+++ b/internal/orchestrator/compatibility_test.go
@@ -3,8 +3,6 @@ package orchestrator
import (
"os"
"testing"
-
- "github.com/tis24dev/proxsave/internal/backup"
)
func TestValidateCompatibility_Mismatch(t *testing.T) {
@@ -18,8 +16,7 @@ func TestValidateCompatibility_Mismatch(t *testing.T) {
t.Fatalf("mkdir: %v", err)
}
- manifest := &backup.Manifest{ProxmoxType: "pbs"}
- if err := ValidateCompatibility(manifest); err == nil {
+ if err := ValidateCompatibility(SystemTypePVE, SystemTypePBS); err == nil {
t.Fatalf("expected incompatibility error")
}
}
@@ -36,6 +33,38 @@ func TestDetectCurrentSystem_Unknown(t *testing.T) {
}
}
+func TestParseSystemTypeString_AcceptsFullNames(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want SystemType
+ }{
+ {
+ name: "pve full name with space",
+ input: "Proxmox VE",
+ want: SystemTypePVE,
+ },
+ {
+ name: "pbs generic full name",
+ input: "Proxmox Backup",
+ want: SystemTypePBS,
+ },
+ {
+ name: "pbs full server name",
+ input: "Proxmox Backup Server",
+ want: SystemTypePBS,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := parseSystemTypeString(tt.input); got != tt.want {
+ t.Fatalf("parseSystemTypeString(%q) = %v; want %v", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
func TestGetSystemInfoDetectsPVE(t *testing.T) {
orig := compatFS
defer func() { compatFS = orig }()
diff --git a/internal/orchestrator/decrypt.go b/internal/orchestrator/decrypt.go
index c995530..83b44f7 100644
--- a/internal/orchestrator/decrypt.go
+++ b/internal/orchestrator/decrypt.go
@@ -43,6 +43,7 @@ type decryptCandidate struct {
RawArchivePath string
RawMetadataPath string
RawChecksumPath string
+ Integrity *stagedIntegrityExpectation
DisplayBase string
IsRclone bool
}
@@ -54,10 +55,11 @@ type stagedFiles struct {
}
type preparedBundle struct {
- ArchivePath string
- Manifest backup.Manifest
- Checksum string
- cleanup func()
+ ArchivePath string
+ Manifest backup.Manifest
+ Checksum string
+ SourceChecksum string
+ cleanup func()
}
func (p *preparedBundle) Cleanup() {
diff --git a/internal/orchestrator/decrypt_integrity.go b/internal/orchestrator/decrypt_integrity.go
new file mode 100644
index 0000000..aba5784
--- /dev/null
+++ b/internal/orchestrator/decrypt_integrity.go
@@ -0,0 +1,120 @@
+package orchestrator
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/tis24dev/proxsave/internal/backup"
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+type stagedIntegrityExpectation struct {
+ Checksum string
+ Source string
+}
+
+func resolveIntegrityExpectationValues(checksumFromFile, checksumFromManifest string) (*stagedIntegrityExpectation, error) {
+ switch {
+ case checksumFromFile != "" && checksumFromManifest != "" && checksumFromFile != checksumFromManifest:
+ return nil, fmt.Errorf("checksum mismatch between checksum file and manifest")
+ case checksumFromFile != "" && checksumFromManifest != "":
+ return &stagedIntegrityExpectation{Checksum: checksumFromFile, Source: "checksum file and manifest"}, nil
+ case checksumFromFile != "":
+ return &stagedIntegrityExpectation{Checksum: checksumFromFile, Source: "checksum file"}, nil
+ case checksumFromManifest != "":
+ return &stagedIntegrityExpectation{Checksum: checksumFromManifest, Source: "manifest"}, nil
+ default:
+ return nil, fmt.Errorf("backup has no checksum verification available")
+ }
+}
+
+func resolveStagedIntegrityExpectation(staged stagedFiles, manifest *backup.Manifest) (*stagedIntegrityExpectation, error) {
+ var (
+ checksumFromFile string
+ checksumFromManifest string
+ )
+
+ if strings.TrimSpace(staged.ChecksumPath) != "" {
+ data, err := restoreFS.ReadFile(staged.ChecksumPath)
+ if err != nil {
+ return nil, fmt.Errorf("read checksum file: %w", err)
+ }
+ checksumFromFile, err = backup.ParseChecksumData(data)
+ if err != nil {
+ return nil, fmt.Errorf("parse checksum file %s: %w", staged.ChecksumPath, err)
+ }
+ }
+
+ if manifest != nil && strings.TrimSpace(manifest.SHA256) != "" {
+ normalized, err := backup.NormalizeChecksum(manifest.SHA256)
+ if err != nil {
+ return nil, fmt.Errorf("parse manifest checksum: %w", err)
+ }
+ checksumFromManifest = normalized
+ }
+
+ return resolveIntegrityExpectationValues(checksumFromFile, checksumFromManifest)
+}
+
+func resolveCandidateIntegrityExpectation(staged stagedFiles, cand *decryptCandidate) (*stagedIntegrityExpectation, error) {
+ if cand != nil && cand.Integrity != nil && strings.TrimSpace(cand.Integrity.Checksum) != "" {
+ normalized, err := backup.NormalizeChecksum(cand.Integrity.Checksum)
+ if err != nil {
+ return nil, fmt.Errorf("parse candidate checksum: %w", err)
+ }
+ expectation := &stagedIntegrityExpectation{
+ Checksum: normalized,
+ Source: strings.TrimSpace(cand.Integrity.Source),
+ }
+ if expectation.Source == "" {
+ expectation.Source = "candidate"
+ }
+
+ if strings.TrimSpace(staged.ChecksumPath) != "" {
+ data, err := restoreFS.ReadFile(staged.ChecksumPath)
+ if err != nil {
+ return nil, fmt.Errorf("read checksum file: %w", err)
+ }
+ checksumFromFile, err := backup.ParseChecksumData(data)
+ if err != nil {
+ return nil, fmt.Errorf("parse checksum file %s: %w", staged.ChecksumPath, err)
+ }
+ if checksumFromFile != expectation.Checksum {
+ return nil, fmt.Errorf("checksum mismatch between checksum file and selected candidate")
+ }
+ }
+
+ return expectation, nil
+ }
+
+ var manifest *backup.Manifest
+ if cand != nil {
+ manifest = cand.Manifest
+ }
+ return resolveStagedIntegrityExpectation(staged, manifest)
+}
+
+func verifyStagedArchiveIntegrity(ctx context.Context, logger *logging.Logger, staged stagedFiles, cand *decryptCandidate) (string, error) {
+ if staged.ArchivePath == "" {
+ return "", fmt.Errorf("staged archive path is empty")
+ }
+ if logger == nil {
+ logger = logging.GetDefaultLogger()
+ }
+
+ expectation, err := resolveCandidateIntegrityExpectation(staged, cand)
+ if err != nil {
+ return "", err
+ }
+
+ logger.Info("Verifying staged archive integrity using %s", expectation.Source)
+ ok, err := backup.VerifyChecksum(ctx, logger, staged.ArchivePath, expectation.Checksum)
+ if err != nil {
+ return "", fmt.Errorf("verify staged archive: %w", err)
+ }
+ if !ok {
+ return "", fmt.Errorf("staged archive checksum mismatch")
+ }
+ return expectation.Checksum, nil
+}
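
`resolveIntegrityExpectationValues` encodes a small precedence table: agreeing sources win with a combined label, a single source wins alone, and both conflict and absence fail. A sketch of a table-style test pinning those branches (hypothetical test in the same package, written against the function exactly as shown above):

```go
package orchestrator

import "testing"

func TestResolveIntegrityExpectationValues_Branches(t *testing.T) {
	if got, err := resolveIntegrityExpectationValues("abc", "abc"); err != nil || got.Source != "checksum file and manifest" {
		t.Fatalf("agreeing sources: got=%+v err=%v", got, err)
	}
	if got, err := resolveIntegrityExpectationValues("abc", ""); err != nil || got.Source != "checksum file" {
		t.Fatalf("file only: got=%+v err=%v", got, err)
	}
	if got, err := resolveIntegrityExpectationValues("", "abc"); err != nil || got.Source != "manifest" {
		t.Fatalf("manifest only: got=%+v err=%v", got, err)
	}
	if _, err := resolveIntegrityExpectationValues("abc", "def"); err == nil {
		t.Fatal("conflicting checksums must be rejected")
	}
	if _, err := resolveIntegrityExpectationValues("", ""); err == nil {
		t.Fatal("missing checksums must be rejected")
	}
}
```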
diff --git a/internal/orchestrator/decrypt_integrity_test.go b/internal/orchestrator/decrypt_integrity_test.go
new file mode 100644
index 0000000..a77b685
--- /dev/null
+++ b/internal/orchestrator/decrypt_integrity_test.go
@@ -0,0 +1,160 @@
+package orchestrator
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/tis24dev/proxsave/internal/backup"
+ "github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/types"
+)
+
+func TestResolveStagedIntegrityExpectation_RejectsConflictingSources(t *testing.T) {
+ origFS := restoreFS
+ restoreFS = osFS{}
+ t.Cleanup(func() { restoreFS = origFS })
+
+ dir := t.TempDir()
+ checksumPath := dir + "/backup.tar.sha256"
+ if err := os.WriteFile(checksumPath, checksumLineForBytes("backup.tar", []byte("archive")), 0o640); err != nil {
+ t.Fatalf("write checksum: %v", err)
+ }
+
+ _, err := resolveStagedIntegrityExpectation(stagedFiles{ChecksumPath: checksumPath}, &backup.Manifest{
+ SHA256: checksumHexForBytes([]byte("different")),
+ })
+ if err == nil {
+ t.Fatalf("expected conflict error")
+ }
+ if !strings.Contains(err.Error(), "checksum mismatch") {
+ t.Fatalf("expected checksum mismatch error, got %v", err)
+ }
+}
+
+func TestPreparePlainBundle_RejectsMissingChecksumVerification(t *testing.T) {
+ origFS := restoreFS
+ restoreFS = osFS{}
+ t.Cleanup(func() { restoreFS = origFS })
+
+ dir := t.TempDir()
+ archivePath := dir + "/backup.tar"
+ if err := os.WriteFile(archivePath, []byte("archive"), 0o640); err != nil {
+ t.Fatalf("write archive: %v", err)
+ }
+ metaPath := archivePath + ".metadata"
+ manifest := &backup.Manifest{
+ ArchivePath: archivePath,
+ CreatedAt: time.Now(),
+ EncryptionMode: "none",
+ }
+ data, err := json.Marshal(manifest)
+ if err != nil {
+ t.Fatalf("marshal metadata: %v", err)
+ }
+ if err := os.WriteFile(metaPath, data, 0o640); err != nil {
+ t.Fatalf("write metadata: %v", err)
+ }
+
+ cand := &decryptCandidate{
+ Manifest: manifest,
+ Source: sourceRaw,
+ RawArchivePath: archivePath,
+ RawMetadataPath: metaPath,
+ DisplayBase: "backup.tar",
+ }
+
+ reader := bufio.NewReader(strings.NewReader(""))
+ logger := logging.New(types.LogLevelError, false)
+
+ _, err = preparePlainBundle(context.Background(), reader, cand, "", logger)
+ if err == nil {
+ t.Fatalf("expected missing checksum verification error")
+ }
+ if !strings.Contains(err.Error(), "no checksum verification available") {
+ t.Fatalf("expected checksum availability error, got %v", err)
+ }
+}
+
+func TestPreparePlainBundle_RejectsChecksumMismatch(t *testing.T) {
+ origFS := restoreFS
+ restoreFS = osFS{}
+ t.Cleanup(func() { restoreFS = origFS })
+
+ dir := t.TempDir()
+ archiveData := []byte("archive")
+ archivePath := dir + "/backup.tar"
+ if err := os.WriteFile(archivePath, archiveData, 0o640); err != nil {
+ t.Fatalf("write archive: %v", err)
+ }
+ metaPath := archivePath + ".metadata"
+ manifest := &backup.Manifest{
+ ArchivePath: archivePath,
+ CreatedAt: time.Now(),
+ EncryptionMode: "none",
+ }
+ data, err := json.Marshal(manifest)
+ if err != nil {
+ t.Fatalf("marshal metadata: %v", err)
+ }
+ if err := os.WriteFile(metaPath, data, 0o640); err != nil {
+ t.Fatalf("write metadata: %v", err)
+ }
+ checksumPath := archivePath + ".sha256"
+ if err := os.WriteFile(checksumPath, checksumLineForBytes("backup.tar", []byte("tampered")), 0o640); err != nil {
+ t.Fatalf("write checksum: %v", err)
+ }
+
+ cand := &decryptCandidate{
+ Manifest: manifest,
+ Source: sourceRaw,
+ RawArchivePath: archivePath,
+ RawMetadataPath: metaPath,
+ RawChecksumPath: checksumPath,
+ DisplayBase: "backup.tar",
+ }
+
+ reader := bufio.NewReader(strings.NewReader(""))
+ logger := logging.New(types.LogLevelError, false)
+
+ _, err = preparePlainBundle(context.Background(), reader, cand, "", logger)
+ if err == nil {
+ t.Fatalf("expected checksum mismatch error")
+ }
+ if !strings.Contains(err.Error(), "checksum mismatch") {
+ t.Fatalf("expected checksum mismatch error, got %v", err)
+ }
+}
+
+func TestVerifyStagedArchiveIntegrity_UsesCandidateIntegrityExpectation(t *testing.T) {
+ origFS := restoreFS
+ restoreFS = osFS{}
+ t.Cleanup(func() { restoreFS = origFS })
+
+ dir := t.TempDir()
+ archiveData := []byte("archive")
+ archivePath := dir + "/backup.tar"
+ if err := os.WriteFile(archivePath, archiveData, 0o640); err != nil {
+ t.Fatalf("write archive: %v", err)
+ }
+
+ got, err := verifyStagedArchiveIntegrity(context.Background(), logging.New(types.LogLevelError, false), stagedFiles{
+ ArchivePath: archivePath,
+ }, &decryptCandidate{
+ Integrity: &stagedIntegrityExpectation{
+ Checksum: strings.ToUpper(checksumHexForBytes(archiveData)),
+ Source: "checksum file",
+ },
+ })
+ if err != nil {
+ t.Fatalf("verifyStagedArchiveIntegrity() error = %v", err)
+ }
+ want := checksumHexForBytes(archiveData)
+ if got != want {
+ t.Fatalf("verifyStagedArchiveIntegrity() = %q; want %q", got, want)
+ }
+}
diff --git a/internal/orchestrator/decrypt_integrity_test_helpers_test.go b/internal/orchestrator/decrypt_integrity_test_helpers_test.go
new file mode 100644
index 0000000..f8f022a
--- /dev/null
+++ b/internal/orchestrator/decrypt_integrity_test_helpers_test.go
@@ -0,0 +1,15 @@
+package orchestrator
+
+import (
+ "crypto/sha256"
+ "fmt"
+)
+
+func checksumHexForBytes(data []byte) string {
+ sum := sha256.Sum256(data)
+ return fmt.Sprintf("%x", sum)
+}
+
+func checksumLineForBytes(filename string, data []byte) []byte {
+ return []byte(fmt.Sprintf("%s %s", checksumHexForBytes(data), filename))
+}
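
`checksumLineForBytes` emits the sha256sum-style `<hex> <filename>` layout the discovery code expects. A standalone sketch of parsing that shape (ProxSave's actual `ParseChecksumData` may accept more variants; the tests' uppercase and whitespace-padded inputs suggest it trims and lowercases):

```go
package main

import (
	"fmt"
	"strings"
)

// parseChecksumLine extracts a SHA-256 digest from a sha256sum-style
// line. Sketch only: the first whitespace-separated field must be 64
// hex characters; it is returned lowercased.
func parseChecksumLine(line string) (string, error) {
	fields := strings.Fields(line)
	if len(fields) == 0 || len(fields[0]) != 64 {
		return "", fmt.Errorf("malformed checksum line: %q", line)
	}
	digest := strings.ToLower(fields[0])
	for _, r := range digest {
		if !strings.ContainsRune("0123456789abcdef", r) {
			return "", fmt.Errorf("non-hex digest in line: %q", line)
		}
	}
	return digest, nil
}

func main() {
	line := strings.Repeat("AB", 32) + "  backup.tar.xz\n"
	fmt.Println(parseChecksumLine(line))
}
```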
diff --git a/internal/orchestrator/decrypt_prepare_common.go b/internal/orchestrator/decrypt_prepare_common.go
new file mode 100644
index 0000000..857832f
--- /dev/null
+++ b/internal/orchestrator/decrypt_prepare_common.go
@@ -0,0 +1,133 @@
+package orchestrator
+
+import (
+ "context"
+ "fmt"
+ "path/filepath"
+ "strings"
+
+ "github.com/tis24dev/proxsave/internal/backup"
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+type archiveDecryptFunc func(ctx context.Context, encryptedPath, outputPath, displayName string) error
+
+func preparePlainBundleCommon(ctx context.Context, cand *decryptCandidate, version string, logger *logging.Logger, decryptArchive archiveDecryptFunc) (bundle *preparedBundle, err error) {
+ if cand == nil || cand.Manifest == nil {
+ return nil, fmt.Errorf("invalid backup candidate")
+ }
+ if logger == nil {
+ logger = logging.GetDefaultLogger()
+ }
+
+ var rcloneCleanup func()
+ if cand.IsRclone && cand.Source == sourceBundle {
+ logger.Debug("Detected rclone backup, downloading...")
+ localPath, cleanup, err := downloadRcloneBackup(ctx, cand.BundlePath, logger)
+ if err != nil {
+ return nil, fmt.Errorf("failed to download rclone backup: %w", err)
+ }
+ rcloneCleanup = cleanup
+ cand.BundlePath = localPath
+ }
+
+ tempRoot := filepath.Join("/tmp", "proxsave")
+ if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil {
+ if rcloneCleanup != nil {
+ rcloneCleanup()
+ }
+ return nil, fmt.Errorf("create temp root: %w", err)
+ }
+
+ workDir, err := restoreFS.MkdirTemp(tempRoot, "proxmox-decrypt-*")
+ if err != nil {
+ if rcloneCleanup != nil {
+ rcloneCleanup()
+ }
+ return nil, fmt.Errorf("create temp dir: %w", err)
+ }
+
+ cleanup := func() {
+ _ = restoreFS.RemoveAll(workDir)
+ if rcloneCleanup != nil {
+ rcloneCleanup()
+ }
+ }
+
+ var staged stagedFiles
+ switch cand.Source {
+ case sourceBundle:
+ logger.Info("Extracting bundle %s", filepath.Base(cand.BundlePath))
+ staged, err = extractBundleToWorkdirWithLogger(cand.BundlePath, workDir, logger)
+ case sourceRaw:
+ logger.Info("Staging raw artifacts for %s", filepath.Base(cand.RawArchivePath))
+ staged, err = copyRawArtifactsToWorkdirWithLogger(ctx, cand, workDir, logger)
+ default:
+ err = fmt.Errorf("unsupported candidate source")
+ }
+ if err != nil {
+ cleanup()
+ return nil, err
+ }
+
+ sourceChecksum, err := verifyStagedArchiveIntegrity(ctx, logger, staged, cand)
+ if err != nil {
+ cleanup()
+ return nil, err
+ }
+
+ manifestCopy := *cand.Manifest
+ currentEncryption := strings.ToLower(manifestCopy.EncryptionMode)
+ logger.Info("Preparing archive %s for decryption (mode: %s)", manifestCopy.ArchivePath, statusFromManifest(&manifestCopy))
+
+ plainArchiveName := strings.TrimSuffix(filepath.Base(staged.ArchivePath), ".age")
+ plainArchivePath := filepath.Join(workDir, plainArchiveName)
+
+ if currentEncryption == "age" {
+ if decryptArchive == nil {
+ cleanup()
+ return nil, fmt.Errorf("decrypt function not available")
+ }
+ displayName := cand.DisplayBase
+ if strings.TrimSpace(displayName) == "" {
+ displayName = filepath.Base(manifestCopy.ArchivePath)
+ }
+ if err := decryptArchive(ctx, staged.ArchivePath, plainArchivePath, displayName); err != nil {
+ cleanup()
+ return nil, err
+ }
+ } else if staged.ArchivePath != plainArchivePath {
+ if err := copyFile(restoreFS, staged.ArchivePath, plainArchivePath); err != nil {
+ cleanup()
+ return nil, fmt.Errorf("copy archive: %w", err)
+ }
+ }
+
+ archiveInfo, err := restoreFS.Stat(plainArchivePath)
+ if err != nil {
+ cleanup()
+ return nil, fmt.Errorf("stat decrypted archive: %w", err)
+ }
+
+ plainChecksum, err := backup.GenerateChecksum(ctx, logger, plainArchivePath)
+ if err != nil {
+ cleanup()
+ return nil, fmt.Errorf("generate checksum: %w", err)
+ }
+
+ manifestCopy.ArchivePath = plainArchivePath
+ manifestCopy.ArchiveSize = archiveInfo.Size()
+ manifestCopy.SHA256 = plainChecksum
+ manifestCopy.EncryptionMode = "none"
+ if version != "" {
+ manifestCopy.ScriptVersion = version
+ }
+
+ return &preparedBundle{
+ ArchivePath: plainArchivePath,
+ Manifest: manifestCopy,
+ Checksum: plainChecksum,
+ SourceChecksum: sourceChecksum,
+ cleanup: cleanup,
+ }, nil
+}
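
With staging, integrity verification, decryption, and manifest rewriting centralized here, the two front-ends later in this diff shrink to one closure each over `archiveDecryptFunc`. Side by side for comparison (taken from the hunks below):

```go
// decrypt_tui.go: prompts routed through the TUI.
return preparePlainBundleCommon(ctx, cand, version, logger,
	func(ctx context.Context, encryptedPath, outputPath, displayName string) error {
		return decryptArchiveWithTUIPrompts(ctx, encryptedPath, outputPath, displayName, configPath, buildSig, logger)
	})

// decrypt_workflow_ui.go: prompts routed through the injected UI.
return preparePlainBundleCommon(ctx, cand, version, logger,
	func(ctx context.Context, encryptedPath, outputPath, displayName string) error {
		return decryptArchiveWithSecretPrompt(ctx, encryptedPath, outputPath, displayName, ui.PromptDecryptSecret)
	})
```

The nil guard on `decryptArchive` inside the common helper covers any future caller that fails to supply a decryptor for an age-encrypted archive.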
diff --git a/internal/orchestrator/decrypt_test.go b/internal/orchestrator/decrypt_test.go
index 84fcee1..dabad3e 100644
--- a/internal/orchestrator/decrypt_test.go
+++ b/internal/orchestrator/decrypt_test.go
@@ -1575,14 +1575,15 @@ func TestPreparePlainBundle_SourceBundleSuccess(t *testing.T) {
dir := t.TempDir()
// Create bundle with required files
+ archiveData := []byte("archive data")
manifestData, _ := json.Marshal(&backup.Manifest{
ArchivePath: filepath.Join(dir, "archive.tar.xz"),
EncryptionMode: "none",
})
bundlePath := createTestBundle(t, []bundleEntry{
- {name: "archive.tar.xz", data: []byte("archive data")},
+ {name: "archive.tar.xz", data: archiveData},
{name: "backup.metadata", data: manifestData},
- {name: "backup.sha256", data: []byte("abc123 archive.tar.xz")},
+ {name: "backup.sha256", data: checksumLineForBytes("archive.tar.xz", archiveData)},
})
cand := &decryptCandidate{
@@ -2819,7 +2820,7 @@ func TestPreparePlainBundle_CopyFileSamePath(t *testing.T) {
t.Fatalf("write metadata: %v", err)
}
checksumPath := archivePath + ".sha256"
- if err := os.WriteFile(checksumPath, []byte("abc123 backup.tar.xz"), 0o644); err != nil {
+ if err := os.WriteFile(checksumPath, checksumLineForBytes("backup.tar.xz", []byte("archive content")), 0o644); err != nil {
t.Fatalf("write checksum: %v", err)
}
@@ -2890,7 +2891,7 @@ func TestPreparePlainBundle_AgeDecryptionWithRclone(t *testing.T) {
tw.Write(manifestData)
// Add checksum
- checksumData := []byte("abc123 backup.tar.xz.age")
+ checksumData := checksumLineForBytes("backup.tar.xz.age", archiveContent)
tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600})
tw.Write(checksumData)
@@ -3017,7 +3018,7 @@ func TestPreparePlainBundle_SourceBundleAdditional(t *testing.T) {
tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(manifestData)), Mode: 0o600})
tw.Write(manifestData)
- checksumData := []byte("abc123 backup.tar.xz")
+ checksumData := checksumLineForBytes("backup.tar.xz", archiveData)
tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksumData)), Mode: 0o600})
tw.Write(checksumData)
@@ -3393,7 +3394,7 @@ func TestExtractBundleToWorkdir_OpenFileErrorOnExtract(t *testing.T) {
}
// Add checksum
- checksum := []byte("checksum backup.tar.xz\n")
+ checksum := checksumLineForBytes("backup.tar.xz", archiveData)
if err := tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640}); err != nil {
t.Fatalf("write checksum header: %v", err)
}
@@ -3608,7 +3609,7 @@ func TestSelectDecryptCandidate_RequireEncryptedAllPlain(t *testing.T) {
tw.Write(metaJSON)
// Add checksum
- checksum := []byte("abc123 backup.tar.xz\n")
+ checksum := checksumLineForBytes("backup.tar.xz", archiveData)
tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
tw.Write(checksum)
tw.Close()
@@ -3720,7 +3721,7 @@ exit 1
tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
tw.Write(metaJSON)
- checksum := []byte("abc123 backup.tar.xz\n")
+ checksum := checksumLineForBytes("backup.tar.xz", archiveData)
tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
tw.Write(checksum)
tw.Close()
@@ -3771,7 +3772,7 @@ func TestPreparePlainBundle_StatErrorAfterExtract(t *testing.T) {
tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
tw.Write(metaJSON)
- checksum := []byte("abc123 backup.tar.xz\n")
+ checksum := checksumLineForBytes("backup.tar.xz", archiveData)
tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
tw.Write(checksum)
tw.Close()
@@ -3869,8 +3870,9 @@ func TestPreparePlainBundle_MkdirTempErrorWithRcloneCleanup(t *testing.T) {
metaJSON, _ := json.Marshal(backup.Manifest{EncryptionMode: "none", ArchivePath: "backup.tar.xz"})
tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
tw.Write(metaJSON)
- tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: 5, Mode: 0o640})
- tw.Write([]byte("hash\n"))
+ checksum := checksumLineForBytes("backup.tar.xz", archiveData)
+ tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
+ tw.Write(checksum)
tw.Close()
bundleFile.Close()
@@ -4066,7 +4068,7 @@ func TestPreparePlainBundle_CopyFileError(t *testing.T) {
tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
tw.Write(metaJSON)
- checksum := []byte("abc123 backup.tar.xz\n")
+ checksum := checksumLineForBytes("backup.tar.xz", archiveData)
tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
tw.Write(checksum)
tw.Close()
@@ -4175,7 +4177,7 @@ func TestPreparePlainBundle_StatErrorOnPlainArchive(t *testing.T) {
tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
tw.Write(metaJSON)
- checksum := []byte("abc123 backup.tar.xz\n")
+ checksum := checksumLineForBytes("backup.tar.xz", archiveData)
tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
tw.Write(checksum)
tw.Close()
@@ -4311,7 +4313,7 @@ func TestPreparePlainBundle_GenerateChecksumErrorPath(t *testing.T) {
tw.WriteHeader(&tar.Header{Name: "backup.metadata", Size: int64(len(metaJSON)), Mode: 0o640})
tw.Write(metaJSON)
- checksum := []byte("abc123 backup.tar.xz\n")
+ checksum := checksumLineForBytes("backup.tar.xz", archiveData)
tw.WriteHeader(&tar.Header{Name: "backup.sha256", Size: int64(len(checksum)), Mode: 0o640})
tw.Write(checksum)
tw.Close()
diff --git a/internal/orchestrator/decrypt_tui.go b/internal/orchestrator/decrypt_tui.go
index 655f7f5..7602d8a 100644
--- a/internal/orchestrator/decrypt_tui.go
+++ b/internal/orchestrator/decrypt_tui.go
@@ -233,109 +233,9 @@ func promptNewPathInput(defaultPath, configPath, buildSig string) (string, error
}
func preparePlainBundleTUI(ctx context.Context, cand *decryptCandidate, version string, logger *logging.Logger, configPath, buildSig string) (*preparedBundle, error) {
- if cand == nil || cand.Manifest == nil {
- return nil, fmt.Errorf("invalid backup candidate")
- }
-
- // If this is an rclone-backed bundle, download it first into the local temp area.
- var rcloneCleanup func()
- if cand.IsRclone && cand.Source == sourceBundle {
- logger.Debug("Detected rclone backup, downloading for TUI workflow...")
- localPath, cleanupFn, err := downloadRcloneBackup(ctx, cand.BundlePath, logger)
- if err != nil {
- return nil, fmt.Errorf("failed to download rclone backup: %w", err)
- }
- rcloneCleanup = cleanupFn
- cand.BundlePath = localPath
- }
-
- tempRoot := filepath.Join("/tmp", "proxsave")
- if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil {
- if rcloneCleanup != nil {
- rcloneCleanup()
- }
- return nil, fmt.Errorf("create temp root: %w", err)
- }
- workDir, err := restoreFS.MkdirTemp(tempRoot, "proxmox-decrypt-*")
- if err != nil {
- if rcloneCleanup != nil {
- rcloneCleanup()
- }
- return nil, fmt.Errorf("create temp dir: %w", err)
- }
- cleanup := func() {
- _ = restoreFS.RemoveAll(workDir)
- if rcloneCleanup != nil {
- rcloneCleanup()
- }
- }
-
- var staged stagedFiles
- switch cand.Source {
- case sourceBundle:
- logger.Debug("Extracting bundle %s", filepath.Base(cand.BundlePath))
- staged, err = extractBundleToWorkdirWithLogger(cand.BundlePath, workDir, logger)
- case sourceRaw:
- logger.Debug("Staging raw artifacts for %s", filepath.Base(cand.RawArchivePath))
- staged, err = copyRawArtifactsToWorkdirWithLogger(ctx, cand, workDir, logger)
- default:
- err = fmt.Errorf("unsupported candidate source")
- }
- if err != nil {
- cleanup()
- return nil, err
- }
-
- manifestCopy := *cand.Manifest
- currentEncryption := strings.ToLower(manifestCopy.EncryptionMode)
-
- logger.Debug("Preparing archive %s for decryption (mode: %s)", filepath.Base(manifestCopy.ArchivePath), statusFromManifest(&manifestCopy))
-
- plainArchiveName := strings.TrimSuffix(filepath.Base(staged.ArchivePath), ".age")
- plainArchivePath := filepath.Join(workDir, plainArchiveName)
-
- if currentEncryption == "age" {
- displayName := cand.DisplayBase
- if displayName == "" {
- displayName = filepath.Base(manifestCopy.ArchivePath)
- }
- if err := decryptArchiveWithTUIPrompts(ctx, staged.ArchivePath, plainArchivePath, displayName, configPath, buildSig, logger); err != nil {
- cleanup()
- return nil, err
- }
- } else if staged.ArchivePath != plainArchivePath {
- if err := copyFile(restoreFS, staged.ArchivePath, plainArchivePath); err != nil {
- cleanup()
- return nil, fmt.Errorf("copy archive: %w", err)
- }
- }
-
- archiveInfo, err := restoreFS.Stat(plainArchivePath)
- if err != nil {
- cleanup()
- return nil, fmt.Errorf("stat decrypted archive: %w", err)
- }
-
- checksum, err := backup.GenerateChecksum(ctx, logger, plainArchivePath)
- if err != nil {
- cleanup()
- return nil, fmt.Errorf("generate checksum: %w", err)
- }
-
- manifestCopy.ArchivePath = plainArchivePath
- manifestCopy.ArchiveSize = archiveInfo.Size()
- manifestCopy.SHA256 = checksum
- manifestCopy.EncryptionMode = "none"
- if version != "" {
- manifestCopy.ScriptVersion = version
- }
-
- return &preparedBundle{
- ArchivePath: plainArchivePath,
- Manifest: manifestCopy,
- Checksum: checksum,
- cleanup: cleanup,
- }, nil
+ return preparePlainBundleCommon(ctx, cand, version, logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error {
+ return decryptArchiveWithTUIPrompts(ctx, encryptedPath, outputPath, displayName, configPath, buildSig, logger)
+ })
}
func decryptArchiveWithTUIPrompts(ctx context.Context, encryptedPath, outputPath, displayName, configPath, buildSig string, logger *logging.Logger) error {
diff --git a/internal/orchestrator/decrypt_tui_test.go b/internal/orchestrator/decrypt_tui_test.go
index 6404b76..f08f9a0 100644
--- a/internal/orchestrator/decrypt_tui_test.go
+++ b/internal/orchestrator/decrypt_tui_test.go
@@ -204,7 +204,7 @@ func TestPreparePlainBundleTUICopiesRawArtifacts(t *testing.T) {
if err := os.WriteFile(rawMetadata, []byte(`{"manifest":true}`), 0o640); err != nil {
t.Fatalf("write metadata: %v", err)
}
- if err := os.WriteFile(rawChecksum, []byte("checksum backup.tar\n"), 0o640); err != nil {
+ if err := os.WriteFile(rawChecksum, checksumLineForBytes("backup.tar", []byte("payload-data")), 0o640); err != nil {
t.Fatalf("write checksum: %v", err)
}
diff --git a/internal/orchestrator/decrypt_workflow_test.go b/internal/orchestrator/decrypt_workflow_test.go
index 3f68fa4..e145b7d 100644
--- a/internal/orchestrator/decrypt_workflow_test.go
+++ b/internal/orchestrator/decrypt_workflow_test.go
@@ -18,7 +18,8 @@ import (
func writeRawBackup(t *testing.T, dir, name string) *backup.Manifest {
t.Helper()
archive := filepath.Join(dir, name)
- if err := os.WriteFile(archive, []byte("data"), 0o640); err != nil {
+ archiveData := []byte("data")
+ if err := os.WriteFile(archive, archiveData, 0o640); err != nil {
t.Fatalf("write archive: %v", err)
}
manifest := &backup.Manifest{
@@ -31,7 +32,7 @@ func writeRawBackup(t *testing.T, dir, name string) *backup.Manifest {
if err := os.WriteFile(manifestPath, data, 0o640); err != nil {
t.Fatalf("write metadata: %v", err)
}
- if err := os.WriteFile(archive+".sha256", []byte("checksum file"), 0o640); err != nil {
+ if err := os.WriteFile(archive+".sha256", checksumLineForBytes(filepath.Base(archive), archiveData), 0o640); err != nil {
t.Fatalf("write checksum: %v", err)
}
return manifest
@@ -70,21 +71,23 @@ func TestRunDecryptWorkflow_BundleNotFound(t *testing.T) {
func TestPreparePlainBundle_AllowsMissingRawChecksumSidecar(t *testing.T) {
dir := t.TempDir()
archive := filepath.Join(dir, "bad.bundle.tar")
- if err := os.WriteFile(archive, []byte("data"), 0o640); err != nil {
+ archiveData := []byte("data")
+ if err := os.WriteFile(archive, archiveData, 0o640); err != nil {
t.Fatalf("write archive: %v", err)
}
manifest := &backup.Manifest{
ArchivePath: archive,
CreatedAt: time.Now(),
Hostname: "host",
+ SHA256: checksumHexForBytes(archiveData),
}
metaPath := archive + ".metadata"
data, _ := json.Marshal(manifest)
if err := os.WriteFile(metaPath, data, 0o640); err != nil {
t.Fatalf("write metadata: %v", err)
}
- // No checksum file: ProxSave should still allow restore/decrypt to proceed
- // (it re-computes checksums on the staged/plain archive anyway).
+ // No checksum sidecar: restore/decrypt should still proceed when the manifest
+ // already carries the expected archive checksum.
cand := &decryptCandidate{
Manifest: manifest,
@@ -100,7 +103,7 @@ func TestPreparePlainBundle_AllowsMissingRawChecksumSidecar(t *testing.T) {
t.Cleanup(func() { restoreFS = osFS{} })
if _, err := preparePlainBundle(context.Background(), reader, cand, "", logging.New(types.LogLevelInfo, false)); err != nil {
- t.Fatalf("expected missing checksum to be tolerated, got error: %v", err)
+ t.Fatalf("expected manifest checksum to cover missing sidecar, got error: %v", err)
}
}
diff --git a/internal/orchestrator/decrypt_workflow_ui.go b/internal/orchestrator/decrypt_workflow_ui.go
index 2ae37d7..7de3f69 100644
--- a/internal/orchestrator/decrypt_workflow_ui.go
+++ b/internal/orchestrator/decrypt_workflow_ui.go
@@ -179,113 +179,9 @@ func preparePlainBundleWithUI(ctx context.Context, cand *decryptCandidate, versi
}) (bundle *preparedBundle, err error) {
done := logging.DebugStart(logger, "prepare plain bundle (ui)", "source=%v rclone=%v", cand.Source, cand.IsRclone)
defer func() { done(err) }()
-
- if cand == nil || cand.Manifest == nil {
- return nil, fmt.Errorf("invalid backup candidate")
- }
-
- var rcloneCleanup func()
- if cand.IsRclone && cand.Source == sourceBundle {
- logger.Debug("Detected rclone backup, downloading...")
- localPath, cleanup, err := downloadRcloneBackup(ctx, cand.BundlePath, logger)
- if err != nil {
- return nil, fmt.Errorf("failed to download rclone backup: %w", err)
- }
- rcloneCleanup = cleanup
- cand.BundlePath = localPath
- }
-
- tempRoot := filepath.Join("/tmp", "proxsave")
- if err := restoreFS.MkdirAll(tempRoot, 0o755); err != nil {
- if rcloneCleanup != nil {
- rcloneCleanup()
- }
- return nil, fmt.Errorf("create temp root: %w", err)
- }
-
- workDir, err := restoreFS.MkdirTemp(tempRoot, "proxmox-decrypt-*")
- if err != nil {
- if rcloneCleanup != nil {
- rcloneCleanup()
- }
- return nil, fmt.Errorf("create temp dir: %w", err)
- }
-
- cleanup := func() {
- _ = restoreFS.RemoveAll(workDir)
- if rcloneCleanup != nil {
- rcloneCleanup()
- }
- }
-
- var staged stagedFiles
- switch cand.Source {
- case sourceBundle:
- logger.Info("Extracting bundle %s", filepath.Base(cand.BundlePath))
- staged, err = extractBundleToWorkdirWithLogger(cand.BundlePath, workDir, logger)
- case sourceRaw:
- logger.Info("Staging raw artifacts for %s", filepath.Base(cand.RawArchivePath))
- staged, err = copyRawArtifactsToWorkdirWithLogger(ctx, cand, workDir, logger)
- default:
- err = fmt.Errorf("unsupported candidate source")
- }
- if err != nil {
- cleanup()
- return nil, err
- }
-
- manifestCopy := *cand.Manifest
- currentEncryption := strings.ToLower(manifestCopy.EncryptionMode)
- logger.Info("Preparing archive %s for decryption (mode: %s)", manifestCopy.ArchivePath, statusFromManifest(&manifestCopy))
-
- plainArchiveName := strings.TrimSuffix(filepath.Base(staged.ArchivePath), ".age")
- plainArchivePath := filepath.Join(workDir, plainArchiveName)
-
- if currentEncryption == "age" {
- displayName := cand.DisplayBase
- if strings.TrimSpace(displayName) == "" {
- displayName = filepath.Base(manifestCopy.ArchivePath)
- }
- if err := decryptArchiveWithSecretPrompt(ctx, staged.ArchivePath, plainArchivePath, displayName, ui.PromptDecryptSecret); err != nil {
- cleanup()
- return nil, err
- }
- } else {
- if staged.ArchivePath != plainArchivePath {
- if err := copyFile(restoreFS, staged.ArchivePath, plainArchivePath); err != nil {
- cleanup()
- return nil, fmt.Errorf("copy archive: %w", err)
- }
- }
- }
-
- archiveInfo, err := restoreFS.Stat(plainArchivePath)
- if err != nil {
- cleanup()
- return nil, fmt.Errorf("stat decrypted archive: %w", err)
- }
-
- checksum, err := backup.GenerateChecksum(ctx, logger, plainArchivePath)
- if err != nil {
- cleanup()
- return nil, fmt.Errorf("generate checksum: %w", err)
- }
-
- manifestCopy.ArchivePath = plainArchivePath
- manifestCopy.ArchiveSize = archiveInfo.Size()
- manifestCopy.SHA256 = checksum
- manifestCopy.EncryptionMode = "none"
- if version != "" {
- manifestCopy.ScriptVersion = version
- }
-
- bundle = &preparedBundle{
- ArchivePath: plainArchivePath,
- Manifest: manifestCopy,
- Checksum: checksum,
- cleanup: cleanup,
- }
- return bundle, nil
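+	// Shared staging/decrypt logic lives in preparePlainBundleCommon; this
+	// wrapper only supplies the UI-driven secret prompt.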
+ return preparePlainBundleCommon(ctx, cand, version, logger, func(ctx context.Context, encryptedPath, outputPath, displayName string) error {
+ return decryptArchiveWithSecretPrompt(ctx, encryptedPath, outputPath, displayName, ui.PromptDecryptSecret)
+ })
}
func runDecryptWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui DecryptWorkflowUI) (err error) {
diff --git a/internal/orchestrator/deps.go b/internal/orchestrator/deps.go
index 648e20b..6530c30 100644
--- a/internal/orchestrator/deps.go
+++ b/internal/orchestrator/deps.go
@@ -16,6 +16,7 @@ import (
// FS abstracts filesystem operations to simplify testing.
type FS interface {
Stat(path string) (os.FileInfo, error)
+ Lstat(path string) (os.FileInfo, error)
ReadFile(path string) ([]byte, error)
Open(path string) (*os.File, error)
OpenFile(path string, flag int, perm fs.FileMode) (*os.File, error)
@@ -69,9 +70,10 @@ type Deps struct {
type osFS struct{}
-func (osFS) Stat(path string) (os.FileInfo, error) { return os.Stat(path) }
-func (osFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(path) }
-func (osFS) Open(path string) (*os.File, error) { return os.Open(path) }
+func (osFS) Stat(path string) (os.FileInfo, error) { return os.Stat(path) }
+func (osFS) Lstat(path string) (os.FileInfo, error) { return os.Lstat(path) }
+func (osFS) ReadFile(path string) ([]byte, error) { return os.ReadFile(path) }
+func (osFS) Open(path string) (*os.File, error) { return os.Open(path) }
func (osFS) OpenFile(path string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(path, flag, perm)
}
diff --git a/internal/orchestrator/deps_test.go b/internal/orchestrator/deps_test.go
index ba4d404..986051f 100644
--- a/internal/orchestrator/deps_test.go
+++ b/internal/orchestrator/deps_test.go
@@ -60,6 +60,16 @@ func (f *FakeFS) Stat(path string) (os.FileInfo, error) {
return os.Stat(f.onDisk(path))
}
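+// Lstat mirrors Stat, honoring the same error-injection maps so tests can
+// simulate lstat failures on specific paths.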
+func (f *FakeFS) Lstat(path string) (os.FileInfo, error) {
+ if err, ok := f.StatErr[filepath.Clean(path)]; ok {
+ return nil, err
+ }
+ if err, ok := f.StatErrors[filepath.Clean(path)]; ok {
+ return nil, err
+ }
+ return os.Lstat(f.onDisk(path))
+}
+
func (f *FakeFS) ReadFile(path string) ([]byte, error) {
return os.ReadFile(f.onDisk(path))
}
diff --git a/internal/orchestrator/network_apply_additional_test.go b/internal/orchestrator/network_apply_additional_test.go
new file mode 100644
index 0000000..b3e2780
--- /dev/null
+++ b/internal/orchestrator/network_apply_additional_test.go
@@ -0,0 +1,1791 @@
+package orchestrator
+
+import (
+ "archive/tar"
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/types"
+)
+
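+// writeFailFS wraps an FS and fails WriteFile for a single configured path,
+// letting tests inject targeted write errors.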
+type writeFailFS struct {
+ FS
+ failPath string
+ err error
+}
+
+func (f writeFailFS) WriteFile(path string, data []byte, perm fs.FileMode) error {
+ if filepath.Clean(path) == filepath.Clean(f.failPath) {
+ return f.err
+ }
+ return f.FS.WriteFile(path, data, perm)
+}
+
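+// newDiscardLogger returns a debug-level logger whose output is discarded,
+// keeping noisy log lines out of test output.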
+func newDiscardLogger() *logging.Logger {
+ logger := logging.New(types.LogLevelDebug, false)
+ logger.SetOutput(io.Discard)
+ return logger
+}
+
+func TestNetworkApplyNotCommittedError_ErrorAndUnwrap(t *testing.T) {
+ var err *NetworkApplyNotCommittedError
+ if got := err.Error(); got != ErrNetworkApplyNotCommitted.Error() {
+ t.Fatalf("Error()=%q want %q", got, ErrNetworkApplyNotCommitted.Error())
+ }
+ if got := err.Unwrap(); !errors.Is(got, ErrNetworkApplyNotCommitted) {
+ t.Fatalf("Unwrap()=%v want %v", got, ErrNetworkApplyNotCommitted)
+ }
+
+ err = &NetworkApplyNotCommittedError{RollbackLog: "/tmp/log"}
+ if got := err.Error(); got != ErrNetworkApplyNotCommitted.Error() {
+ t.Fatalf("Error()=%q want %q", got, ErrNetworkApplyNotCommitted.Error())
+ }
+ if got := err.Unwrap(); got != ErrNetworkApplyNotCommitted {
+ t.Fatalf("Unwrap()=%v want %v", got, ErrNetworkApplyNotCommitted)
+ }
+}
+
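+// remaining must be nil-safe and clamp to zero once the timeout has elapsed.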
+func TestNetworkRollbackHandleRemaining(t *testing.T) {
+ var h *networkRollbackHandle
+ if got := h.remaining(time.Now()); got != 0 {
+ t.Fatalf("nil remaining=%s want 0", got)
+ }
+
+ h = &networkRollbackHandle{
+ armedAt: time.Date(2026, 2, 1, 1, 2, 3, 0, time.UTC),
+ timeout: 10 * time.Second,
+ }
+ if got := h.remaining(h.armedAt); got != 10*time.Second {
+ t.Fatalf("remaining=%s want %s", got, 10*time.Second)
+ }
+ if got := h.remaining(h.armedAt.Add(4 * time.Second)); got != 6*time.Second {
+ t.Fatalf("remaining=%s want %s", got, 6*time.Second)
+ }
+ if got := h.remaining(h.armedAt.Add(20 * time.Second)); got != 0 {
+ t.Fatalf("remaining=%s want 0", got)
+ }
+}
+
+func TestShouldAttemptNetworkApply(t *testing.T) {
+ if shouldAttemptNetworkApply(nil) {
+ t.Fatalf("expected false for nil plan")
+ }
+ if shouldAttemptNetworkApply(&RestorePlan{NormalCategories: []Category{{ID: "storage_pve"}}}) {
+ t.Fatalf("expected false when network category not present")
+ }
+ if !shouldAttemptNetworkApply(&RestorePlan{NormalCategories: []Category{{ID: "network"}}}) {
+ t.Fatalf("expected true when network category present")
+ }
+}
+
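+// The extractIPFromSnapshot tests exercise parsing of snapshot files: only
+// address lines inside the "$ ip -br addr" section should be considered.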
+func TestExtractIPFromSnapshot_EmptyArgsReturnUnknown(t *testing.T) {
+ if got := extractIPFromSnapshot("", "vmbr0"); got != "unknown" {
+ t.Fatalf("got %q want unknown", got)
+ }
+ if got := extractIPFromSnapshot("/snap.txt", ""); got != "unknown" {
+ t.Fatalf("got %q want unknown", got)
+ }
+}
+
+func TestExtractIPFromSnapshot_IgnoresLinesOutsideAddrSection(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ snapshot := strings.Join([]string{
+ "$ ip -br addr",
+ "lo UNKNOWN 127.0.0.1/8",
+ "$ ip route show",
+ "vmbr0 UP 192.0.2.10/24",
+ "",
+ }, "\n")
+ if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil {
+ t.Fatalf("write snapshot: %v", err)
+ }
+ if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "unknown" {
+ t.Fatalf("got %q want unknown", got)
+ }
+}
+
+func TestExtractIPFromSnapshot_SkipsInvalidTokensButReturnsFirstValid(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ snapshot := strings.Join([]string{
+ "$ ip -br addr",
+ "vmbr0 UP not-an-ip 2a01:db8::2/64",
+ "",
+ }, "\n")
+ if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil {
+ t.Fatalf("write snapshot: %v", err)
+ }
+ if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "2a01:db8::2/64" {
+ t.Fatalf("got %q want %q", got, "2a01:db8::2/64")
+ }
+}
+
+func TestExtractIPFromSnapshot_SkipsErrorLinesAndParsesNext(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ snapshot := strings.Join([]string{
+ "$ ip -br addr",
+ "ERROR: ip failed",
+ "vmbr0 UP 192.0.2.55/24",
+ "",
+ }, "\n")
+ if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil {
+ t.Fatalf("write snapshot: %v", err)
+ }
+ if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "192.0.2.55/24" {
+ t.Fatalf("got %q want %q", got, "192.0.2.55/24")
+ }
+}
+
+func TestExtractIPFromSnapshot_ReturnsUnknownWhenNoValidAddressTokens(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ snapshot := strings.Join([]string{
+ "$ ip -br addr",
+ "vmbr0 UP not-an-ip also-bad",
+ "",
+ }, "\n")
+ if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil {
+ t.Fatalf("write snapshot: %v", err)
+ }
+ if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "unknown" {
+ t.Fatalf("got %q want unknown", got)
+ }
+}
+
+func TestExtractIPFromSnapshot_IgnoresCommandsBeforeAddrSection(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ snapshot := strings.Join([]string{
+ "$ ip route show",
+ "default via 192.0.2.1 dev vmbr0",
+ "$ ip -br addr",
+ "vmbr0 UP 192.0.2.99/24",
+ "",
+ }, "\n")
+ if err := fakeFS.WriteFile("/snap.txt", []byte(snapshot), 0o600); err != nil {
+ t.Fatalf("write snapshot: %v", err)
+ }
+ if got := extractIPFromSnapshot("/snap.txt", "vmbr0"); got != "192.0.2.99/24" {
+ t.Fatalf("got %q want %q", got, "192.0.2.99/24")
+ }
+}
+
+func TestBuildNetworkApplyNotCommittedError_HandleNilIfaceEmpty(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ })
+
+ restoreFS = NewFakeFS()
+ restoreCmd = &FakeCommandRunner{}
+
+ logger := newDiscardLogger()
+
+ got := buildNetworkApplyNotCommittedError(context.Background(), logger, "", nil)
+ if got.RollbackArmed {
+ t.Fatalf("RollbackArmed=true want false")
+ }
+ if got.RollbackLog != "" || got.RollbackMarker != "" {
+ t.Fatalf("expected empty rollback paths, got log=%q marker=%q", got.RollbackLog, got.RollbackMarker)
+ }
+ if got.RestoredIP != "unknown" || got.OriginalIP != "unknown" {
+ t.Fatalf("restored=%q original=%q want unknown/unknown", got.RestoredIP, got.OriginalIP)
+ }
+ if !got.RollbackDeadline.IsZero() {
+ t.Fatalf("RollbackDeadline=%s want zero", got.RollbackDeadline)
+ }
+}
+
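+// With the marker file present the error should report rollback armed, the
+// restored and pre-apply IPs, and a deadline of armedAt+timeout.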
+func TestBuildNetworkApplyNotCommittedError_ArmedWithMarkerAndSnapshots(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ fakeCmd := &FakeCommandRunner{
+ Outputs: map[string][]byte{
+ "ip -o addr show dev vmbr0 scope global": []byte(strings.Join([]string{
+ "2: vmbr0 inet 192.0.2.10/24 brd 192.0.2.255 scope global vmbr0",
+ "2: vmbr0 inet6 2001:db8::1/64 scope global",
+ }, "\n")),
+ "ip route show default": []byte("default via 192.0.2.1 dev vmbr0\n"),
+ },
+ }
+ restoreCmd = fakeCmd
+
+ logger := newDiscardLogger()
+
+ handle := &networkRollbackHandle{
+ workDir: "/work",
+ markerPath: "/work/marker",
+ logPath: "/work/log",
+ armedAt: time.Date(2026, 2, 1, 12, 0, 0, 0, time.UTC),
+ timeout: 3 * time.Minute,
+ }
+
+ if err := fakeFS.WriteFile(handle.markerPath, []byte("pending\n"), 0o600); err != nil {
+ t.Fatalf("write marker: %v", err)
+ }
+ beforeSnapshot := strings.Join([]string{
+ "$ ip -br addr",
+ "vmbr0 UP 10.0.0.2/24",
+ "",
+ }, "\n")
+ if err := fakeFS.WriteFile("/work/before.txt", []byte(beforeSnapshot), 0o600); err != nil {
+ t.Fatalf("write before snapshot: %v", err)
+ }
+
+ got := buildNetworkApplyNotCommittedError(context.Background(), logger, "vmbr0", handle)
+ if !got.RollbackArmed {
+ t.Fatalf("RollbackArmed=false want true")
+ }
+ if got.RollbackLog != "/work/log" || got.RollbackMarker != "/work/marker" {
+ t.Fatalf("paths log=%q marker=%q want /work/log /work/marker", got.RollbackLog, got.RollbackMarker)
+ }
+ if got.RestoredIP != "192.0.2.10/24, 2001:db8::1/64" {
+ t.Fatalf("RestoredIP=%q", got.RestoredIP)
+ }
+ if got.OriginalIP != "10.0.0.2/24" {
+ t.Fatalf("OriginalIP=%q want %q", got.OriginalIP, "10.0.0.2/24")
+ }
+ wantDeadline := handle.armedAt.Add(handle.timeout)
+ if !got.RollbackDeadline.Equal(wantDeadline) {
+ t.Fatalf("RollbackDeadline=%s want %s", got.RollbackDeadline, wantDeadline)
+ }
+}
+
+func TestBuildNetworkApplyNotCommittedError_MarkerMissingAndIPQueryFails(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ restoreCmd = &FakeCommandRunner{
+ Errors: map[string]error{
+ "ip -o addr show dev vmbr0 scope global": errors.New("boom"),
+ },
+ }
+
+ logger := newDiscardLogger()
+ handle := &networkRollbackHandle{
+ workDir: "/work",
+ markerPath: "/work/missing",
+ armedAt: time.Date(2026, 2, 1, 12, 0, 0, 0, time.UTC),
+ timeout: 1 * time.Minute,
+ }
+
+ got := buildNetworkApplyNotCommittedError(context.Background(), logger, "vmbr0", handle)
+ if got.RollbackArmed {
+ t.Fatalf("RollbackArmed=true want false when marker missing")
+ }
+ if got.RestoredIP != "unknown" {
+ t.Fatalf("RestoredIP=%q want unknown", got.RestoredIP)
+ }
+ if got.OriginalIP != "unknown" {
+ t.Fatalf("OriginalIP=%q want unknown", got.OriginalIP)
+ }
+}
+
+func TestRollbackAlreadyRunning(t *testing.T) {
+ origCmd := restoreCmd
+ t.Cleanup(func() { restoreCmd = origCmd })
+
+ logger := newDiscardLogger()
+ pathDir := t.TempDir()
+
+ t.Run("skip when handle nil", func(t *testing.T) {
+ t.Setenv("PATH", pathDir)
+ restoreCmd = &FakeCommandRunner{}
+ if rollbackAlreadyRunning(context.Background(), logger, nil) {
+ t.Fatalf("expected false for nil handle")
+ }
+ })
+
+ t.Run("skip when unit name empty", func(t *testing.T) {
+ t.Setenv("PATH", pathDir)
+ restoreCmd = &FakeCommandRunner{}
+ if rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: ""}) {
+ t.Fatalf("expected false when unitName empty")
+ }
+ })
+
+ t.Run("skip when systemctl missing", func(t *testing.T) {
+ t.Setenv("PATH", pathDir)
+ restoreCmd = &FakeCommandRunner{}
+ if rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: "x"}) {
+ t.Fatalf("expected false when systemctl not available")
+ }
+ })
+
+ t.Run("running when active or activating", func(t *testing.T) {
+ writeExecutable(t, pathDir, "systemctl")
+ t.Setenv("PATH", pathDir)
+
+ for _, tc := range []struct {
+ state string
+ want bool
+ }{
+ {state: "active\n", want: true},
+ {state: "activating\n", want: true},
+ {state: "inactive\n", want: false},
+ } {
+ restoreCmd = &FakeCommandRunner{
+ Outputs: map[string][]byte{
+ "systemctl is-active unit.service": []byte(tc.state),
+ },
+ }
+ if got := rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: "unit"}); got != tc.want {
+ t.Fatalf("state=%q got=%v want=%v", tc.state, got, tc.want)
+ }
+ }
+ })
+
+ t.Run("not running when systemctl errors", func(t *testing.T) {
+ writeExecutable(t, pathDir, "systemctl")
+ t.Setenv("PATH", pathDir)
+ restoreCmd = &FakeCommandRunner{
+ Errors: map[string]error{
+ "systemctl is-active unit.service": errors.New("boom"),
+ },
+ }
+ if rollbackAlreadyRunning(context.Background(), logger, &networkRollbackHandle{unitName: "unit"}) {
+ t.Fatalf("expected false when systemctl is-active errors")
+ }
+ })
+}
+
+func TestArmNetworkRollback_ValidationErrors(t *testing.T) {
+ if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "", 10*time.Second, ""); err == nil {
+ t.Fatalf("expected error for empty backup path")
+ }
+ if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 0, ""); err == nil {
+ t.Fatalf("expected error for invalid timeout")
+ }
+}
+
+func TestArmNetworkRollback_CreateRollbackDirFailureReturnsError(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ fakeFS.MkdirAllErr = errors.New("boom")
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 30*time.Second, "/secure"); err == nil || !strings.Contains(err.Error(), "create rollback directory") {
+ t.Fatalf("err=%v want create rollback directory error", err)
+ }
+}
+
+func TestArmNetworkRollback_MarkerWriteFailureReturnsError(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ fakeFS.WriteErr = errors.New("boom")
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 30*time.Second, "/secure"); err == nil || !strings.Contains(err.Error(), "write rollback marker") {
+ t.Fatalf("err=%v want write rollback marker error", err)
+ }
+}
+
+func TestArmNetworkRollback_ScriptWriteFailureReturnsError(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ scriptPath := "/secure/network_rollback_20260201_123456.sh"
+ restoreFS = writeFailFS{FS: fakeFS, failPath: scriptPath, err: errors.New("boom")}
+
+ if _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 30*time.Second, "/secure"); err == nil || !strings.Contains(err.Error(), "write rollback script") {
+ t.Fatalf("err=%v want write rollback script error", err)
+ }
+}
+
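+// With systemd-run on PATH the rollback timer must be armed through it,
+// never through the nohup fallback.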
+func TestArmNetworkRollback_UsesSystemdRunWhenAvailable(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ writeExecutable(t, pathDir, "systemd-run")
+ t.Setenv("PATH", pathDir)
+
+ fakeCmd := &FakeCommandRunner{
+ Outputs: map[string][]byte{},
+ }
+ restoreCmd = fakeCmd
+
+ logger := newDiscardLogger()
+ handle, err := armNetworkRollback(context.Background(), logger, "/backup.tar", 30*time.Second, "")
+ if err != nil {
+ t.Fatalf("armNetworkRollback error: %v", err)
+ }
+ if handle == nil || handle.unitName == "" {
+ t.Fatalf("expected handle with unitName set")
+ }
+ if !strings.Contains(handle.unitName, "20260201_123456") {
+ t.Fatalf("unitName=%q want timestamp", handle.unitName)
+ }
+
+ data, err := fakeFS.ReadFile(handle.scriptPath)
+ if err != nil {
+ t.Fatalf("read script: %v", err)
+ }
+ if !strings.Contains(string(data), "Restart networking after rollback") {
+ t.Fatalf("expected restartNetworking block in script")
+ }
+
+ foundSystemdRun := false
+ for _, call := range fakeCmd.CallsList() {
+ if strings.HasPrefix(call, "systemd-run --unit=") {
+ foundSystemdRun = true
+ }
+ if strings.HasPrefix(call, "sh -c ") {
+ t.Fatalf("unexpected fallback sh -c call: %s", call)
+ }
+ }
+ if !foundSystemdRun {
+ t.Fatalf("expected systemd-run to be called; calls=%v", fakeCmd.CallsList())
+ }
+}
+
+func TestArmNetworkRollback_CustomWorkDirAndSystemdRunOutput(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ writeExecutable(t, pathDir, "systemd-run")
+ t.Setenv("PATH", pathDir)
+
+ expectedSystemdRun := "systemd-run --unit=proxsave-network-rollback-20260201_123456 --on-active=2s /bin/sh /secure/network_rollback_20260201_123456.sh"
+ restoreCmd = &FakeCommandRunner{
+ Outputs: map[string][]byte{
+ expectedSystemdRun: []byte("unit started\n"),
+ },
+ }
+
+ handle, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 2*time.Second, "/secure")
+ if err != nil {
+ t.Fatalf("armNetworkRollback error: %v", err)
+ }
+ if handle == nil || handle.workDir != "/secure" {
+ t.Fatalf("handle=%#v want workDir=/secure", handle)
+ }
+}
+
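+// If systemd-run is present but fails, arming should fall back to the nohup
+// shell timer and leave unitName empty.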
+func TestArmNetworkRollback_SystemdRunFailureFallsBackToNohup(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ writeExecutable(t, pathDir, "systemd-run")
+ t.Setenv("PATH", pathDir)
+
+ fakeCmd := &FakeCommandRunner{
+ Errors: map[string]error{},
+ }
+ restoreCmd = fakeCmd
+
+ logger := newDiscardLogger()
+ expectedSystemdRun := "systemd-run --unit=proxsave-network-rollback-20260201_123456 --on-active=30s /bin/sh /tmp/proxsave/network_rollback_20260201_123456.sh"
+ fakeCmd.Errors[expectedSystemdRun] = errors.New("boom")
+
+ handle, err := armNetworkRollback(context.Background(), logger, "/backup.tar", 30*time.Second, "")
+ if err != nil {
+ t.Fatalf("armNetworkRollback error: %v", err)
+ }
+ if handle == nil {
+ t.Fatalf("expected handle")
+ }
+ if handle.unitName != "" {
+ t.Fatalf("unitName=%q want empty after systemd-run failure", handle.unitName)
+ }
+
+ foundSystemdRun := false
+ foundFallback := false
+ for _, call := range fakeCmd.CallsList() {
+ if strings.HasPrefix(call, "systemd-run ") {
+ foundSystemdRun = true
+ }
+ if strings.HasPrefix(call, "sh -c nohup sh -c 'sleep ") {
+ foundFallback = true
+ }
+ }
+ if !foundSystemdRun || !foundFallback {
+ t.Fatalf("expected both systemd-run and fallback; calls=%v", fakeCmd.CallsList())
+ }
+}
+
+func TestArmNetworkRollback_WithoutSystemdRunUsesNohup(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+		restoreFS = &FakeFS{
+			Root:        fakeFS.Root,
+			MkdirAllErr: errors.New("boom"),
+			StatErr:     make(map[string]error),
+			StatErrors:  make(map[string]error),
+			OpenFileErr: make(map[string]error),
+		}
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ fakeCmd := &FakeCommandRunner{}
+ restoreCmd = fakeCmd
+
+ logger := newDiscardLogger()
+ handle, err := armNetworkRollback(context.Background(), logger, "/backup.tar", 1*time.Second, "")
+ if err != nil {
+ t.Fatalf("armNetworkRollback error: %v", err)
+ }
+ if handle == nil {
+ t.Fatalf("expected handle")
+ }
+ if handle.unitName != "" {
+ t.Fatalf("unitName=%q want empty in nohup mode", handle.unitName)
+ }
+
+ foundFallback := false
+ for _, call := range fakeCmd.CallsList() {
+ if strings.HasPrefix(call, "sh -c nohup sh -c 'sleep ") {
+ foundFallback = true
+ }
+ }
+ if !foundFallback {
+ t.Fatalf("expected fallback sh -c call; calls=%v", fakeCmd.CallsList())
+ }
+}
+
+func TestArmNetworkRollback_SubSecondTimeoutArmsAtLeastOneSecond(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ fakeCmd := &FakeCommandRunner{}
+ restoreCmd = fakeCmd
+
+ handle, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 500*time.Millisecond, "")
+ if err != nil {
+ t.Fatalf("armNetworkRollback error: %v", err)
+ }
+ if handle == nil {
+ t.Fatalf("expected handle")
+ }
+
+ foundSleep1 := false
+ for _, call := range fakeCmd.CallsList() {
+ if strings.Contains(call, "sleep 1;") {
+ foundSleep1 = true
+ }
+ }
+ if !foundSleep1 {
+ t.Fatalf("expected sleep 1 in nohup command; calls=%v", fakeCmd.CallsList())
+ }
+}
+
+func TestArmNetworkRollback_FallbackCommandFailureReturnsError(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ restoreCmd = &FakeCommandRunner{
+ Errors: map[string]error{
+ "sh -c nohup sh -c 'sleep 1; /bin/sh /tmp/proxsave/network_rollback_20260201_123456.sh' >/dev/null 2>&1 &": errors.New("boom"),
+ },
+ }
+
+ _, err := armNetworkRollback(context.Background(), newDiscardLogger(), "/backup.tar", 1*time.Second, "")
+ if err == nil || !strings.Contains(err.Error(), "failed to arm rollback timer") {
+ t.Fatalf("err=%v want failed to arm rollback timer", err)
+ }
+}
+
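+// Disarming should remove the pending marker and stop/reset the systemd
+// units backing the rollback timer.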
+func TestDisarmNetworkRollback_RemovesMarkerAndStopsTimer(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ pathDir := t.TempDir()
+ writeExecutable(t, pathDir, "systemctl")
+ t.Setenv("PATH", pathDir)
+
+ fakeCmd := &FakeCommandRunner{}
+ restoreCmd = fakeCmd
+
+ handle := &networkRollbackHandle{
+ markerPath: "/tmp/marker",
+ unitName: "proxsave-network-rollback-test",
+ }
+ if err := fakeFS.WriteFile(handle.markerPath, []byte("pending\n"), 0o600); err != nil {
+ t.Fatalf("write marker: %v", err)
+ }
+
+ disarmNetworkRollback(context.Background(), newDiscardLogger(), handle)
+ if _, err := fakeFS.Stat(handle.markerPath); !errors.Is(err, os.ErrNotExist) {
+ t.Fatalf("expected marker to be removed, stat err=%v", err)
+ }
+
+ calls := strings.Join(fakeCmd.CallsList(), "\n")
+ if !strings.Contains(calls, "systemctl stop "+handle.unitName+".timer") {
+ t.Fatalf("expected systemctl stop call; calls=%v", fakeCmd.CallsList())
+ }
+ if !strings.Contains(calls, "systemctl reset-failed "+handle.unitName+".service "+handle.unitName+".timer") {
+ t.Fatalf("expected systemctl reset-failed call; calls=%v", fakeCmd.CallsList())
+ }
+}
+
+func TestDisarmNetworkRollback_NilHandleNoop(t *testing.T) {
+ disarmNetworkRollback(context.Background(), newDiscardLogger(), nil)
+}
+
+func TestDisarmNetworkRollback_MarkerRemoveFailureAndStopError(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+
+ pathDir := t.TempDir()
+ writeExecutable(t, pathDir, "systemctl")
+ t.Setenv("PATH", pathDir)
+
+ fakeCmd := &FakeCommandRunner{
+ Errors: map[string]error{
+ "systemctl stop unit.timer": errors.New("boom"),
+ },
+ }
+ restoreCmd = fakeCmd
+
+ handle := &networkRollbackHandle{
+ markerPath: "/tmp/marker",
+ unitName: "unit",
+ }
+ if err := fakeFS.WriteFile(handle.markerPath, []byte("pending\n"), 0o600); err != nil {
+ t.Fatalf("write marker: %v", err)
+ }
+
+ restoreFS = removeFailFS{FS: fakeFS, failPath: handle.markerPath, err: errors.New("remove boom")}
+ disarmNetworkRollback(context.Background(), newDiscardLogger(), handle)
+
+ calls := strings.Join(fakeCmd.CallsList(), "\n")
+ if !strings.Contains(calls, "systemctl stop unit.timer") {
+ t.Fatalf("expected systemctl stop call; calls=%v", fakeCmd.CallsList())
+ }
+ if !strings.Contains(calls, "systemctl reset-failed unit.service unit.timer") {
+ t.Fatalf("expected systemctl reset-failed call; calls=%v", fakeCmd.CallsList())
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_SkippedWhenArchiveMissing(t *testing.T) {
+ origTime := restoreTime
+ t.Cleanup(func() { restoreTime = origTime })
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ reader := bufio.NewReader(strings.NewReader(""))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "")
+ if result == nil {
+ t.Fatalf("expected result")
+ }
+ if !strings.Contains(result.SkippedReason, "backup archive not available") {
+ t.Fatalf("SkippedReason=%q", result.SkippedReason)
+ }
+ if !result.AppliedAt.Equal(nowRestore()) {
+ t.Fatalf("AppliedAt=%s want %s", result.AppliedAt, nowRestore())
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_ReturnsNilOnPlanError(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ if err := fakeFS.WriteFile("/backup.zip", []byte("not a tar"), 0o600); err != nil {
+ t.Fatalf("write archive: %v", err)
+ }
+
+ reader := bufio.NewReader(strings.NewReader(""))
+ if got := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.zip"); got != nil {
+ t.Fatalf("expected nil on plan error, got %#v", got)
+ }
+}
+
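+// Happy path: the MAC recorded in the archived network_inventory.json maps
+// the old eno1 name to the live enp3s0, rewriting /etc/network/interfaces.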
+func TestMaybeRepairNICNamesCLI_AppliesMappingWithoutConflicts(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+ if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil {
+ t.Fatalf("mkdir enp3s0: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil {
+ t.Fatalf("write address: %v", err)
+ }
+
+ invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil {
+ t.Fatalf("write interfaces: %v", err)
+ }
+
+ reader := bufio.NewReader(strings.NewReader(""))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar")
+ if result == nil || !result.Applied() {
+ t.Fatalf("expected applied result, got %#v", result)
+ }
+
+ updated, err := fakeFS.ReadFile("/etc/network/interfaces")
+ if err != nil {
+ t.Fatalf("read updated interfaces: %v", err)
+ }
+ if string(updated) != "auto enp3s0\niface enp3s0 inet manual\n" {
+ t.Fatalf("updated=%q", string(updated))
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_SkipsWhenNamingOverridesDetectedAndUserConfirms(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+ if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil {
+ t.Fatalf("mkdir enp3s0: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil {
+ t.Fatalf("write address: %v", err)
+ }
+
+ invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ if err := fakeFS.MkdirAll("/etc/udev/rules.d", 0o755); err != nil {
+ t.Fatalf("mkdir udev: %v", err)
+ }
+ rule := `SUBSYSTEM=="net", ATTR{address}=="aa:bb:cc:dd:ee:ff", NAME="eth0"`
+ if err := fakeFS.WriteFile("/etc/udev/rules.d/70-persistent-net.rules", []byte(rule+"\n"), 0o644); err != nil {
+ t.Fatalf("write rule: %v", err)
+ }
+
+ reader := bufio.NewReader(strings.NewReader("y\n"))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar")
+ if result == nil {
+ t.Fatalf("expected result")
+ }
+ if !strings.Contains(result.SkippedReason, "persistent NIC naming rules") {
+ t.Fatalf("SkippedReason=%q", result.SkippedReason)
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_OverridesDetectedUserChoosesProceed(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+ if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil {
+ t.Fatalf("mkdir enp3s0: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil {
+ t.Fatalf("write address: %v", err)
+ }
+
+ invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ if err := fakeFS.MkdirAll("/etc/udev/rules.d", 0o755); err != nil {
+ t.Fatalf("mkdir udev: %v", err)
+ }
+ rule := `SUBSYSTEM=="net", ATTR{address}=="aa:bb:cc:dd:ee:ff", NAME="eth0"`
+ if err := fakeFS.WriteFile("/etc/udev/rules.d/70-persistent-net.rules", []byte(rule+"\n"), 0o644); err != nil {
+ t.Fatalf("write rule: %v", err)
+ }
+
+ if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil {
+ t.Fatalf("write interfaces: %v", err)
+ }
+
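+	// "n" declines the offer to skip, so the repair proceeds despite the overrides.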
+ reader := bufio.NewReader(strings.NewReader("n\n"))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar")
+ if result == nil || !result.Applied() {
+ t.Fatalf("expected applied result, got %#v", result)
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_OverridesDetectionErrorStillApplies(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+ if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil {
+ t.Fatalf("mkdir enp3s0: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil {
+ t.Fatalf("write address: %v", err)
+ }
+
+ invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil {
+ t.Fatalf("write interfaces: %v", err)
+ }
+
+ restoreFS = readDirFailFS{FS: fakeFS, failPath: "/etc/udev/rules.d", err: errors.New("boom")}
+
+ reader := bufio.NewReader(strings.NewReader(""))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar")
+ if result == nil || !result.Applied() {
+ t.Fatalf("expected applied result, got %#v", result)
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_ConflictsPromptAndSkip(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+ if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil {
+ t.Fatalf("mkdir enp3s0: %v", err)
+ }
+ if err := os.MkdirAll(filepath.Join(sysDir, "eno1"), 0o755); err != nil {
+ t.Fatalf("mkdir eno1: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil {
+ t.Fatalf("write enp3s0 address: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "eno1", "address"), []byte("11:22:33:44:55:66\n"), 0o644); err != nil {
+ t.Fatalf("write eno1 address: %v", err)
+ }
+
+ invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ reader := bufio.NewReader(strings.NewReader("n\n"))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar")
+ if result == nil {
+ t.Fatalf("expected result")
+ }
+ if result.Applied() {
+ t.Fatalf("expected no changes when conflicts skipped, got %#v", result)
+ }
+ if !strings.Contains(result.SkippedReason, "conflicting NIC mappings") {
+ t.Fatalf("SkippedReason=%q", result.SkippedReason)
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_ConflictsPromptAndApply(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+ if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil {
+ t.Fatalf("mkdir enp3s0: %v", err)
+ }
+ if err := os.MkdirAll(filepath.Join(sysDir, "eno1"), 0o755); err != nil {
+ t.Fatalf("mkdir eno1: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil {
+ t.Fatalf("write enp3s0 address: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "eno1", "address"), []byte("11:22:33:44:55:66\n"), 0o644); err != nil {
+ t.Fatalf("write eno1 address: %v", err)
+ }
+
+ invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil {
+ t.Fatalf("write interfaces: %v", err)
+ }
+
+ reader := bufio.NewReader(strings.NewReader("y\n"))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar")
+ if result == nil || !result.Applied() {
+ t.Fatalf("expected applied result, got %#v", result)
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_OverridesPromptErrorContinues(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+ if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil {
+ t.Fatalf("mkdir enp3s0: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil {
+ t.Fatalf("write address: %v", err)
+ }
+
+ invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ if err := fakeFS.MkdirAll("/etc/udev/rules.d", 0o755); err != nil {
+ t.Fatalf("mkdir udev: %v", err)
+ }
+ rule := `SUBSYSTEM=="net", ATTR{address}=="aa:bb:cc:dd:ee:ff", NAME="eth0"`
+ if err := fakeFS.WriteFile("/etc/udev/rules.d/70-persistent-net.rules", []byte(rule+"\n"), 0o644); err != nil {
+ t.Fatalf("write rule: %v", err)
+ }
+
+ if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil {
+ t.Fatalf("write interfaces: %v", err)
+ }
+
+ reader := bufio.NewReader(strings.NewReader(""))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar")
+ if result == nil || !result.Applied() {
+ t.Fatalf("expected applied result despite prompt error, got %#v", result)
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_ConflictPromptErrorLeavesConflictsExcluded(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+ if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil {
+ t.Fatalf("mkdir enp3s0: %v", err)
+ }
+ if err := os.MkdirAll(filepath.Join(sysDir, "eno1"), 0o755); err != nil {
+ t.Fatalf("mkdir eno1: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil {
+ t.Fatalf("write enp3s0 address: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "eno1", "address"), []byte("11:22:33:44:55:66\n"), 0o644); err != nil {
+ t.Fatalf("write eno1 address: %v", err)
+ }
+
+ invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ reader := bufio.NewReader(strings.NewReader(""))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar")
+ if result == nil {
+ t.Fatalf("expected result")
+ }
+ if result.Applied() {
+ t.Fatalf("expected conflicts excluded due to prompt error, got %#v", result)
+ }
+ if !strings.Contains(result.SkippedReason, "conflicting NIC mappings") {
+ t.Fatalf("SkippedReason=%q", result.SkippedReason)
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_MoreThan32ConflictsTriggersTruncationBranch(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+
+ var ifaceJSON []string
+ for i := 0; i < 33; i++ {
+ oldName := fmt.Sprintf("eno%d", i)
+ newName := fmt.Sprintf("enp%d", i)
+ mac := fmt.Sprintf("aa:bb:cc:dd:ee:%02x", i)
+ ifaceJSON = append(ifaceJSON, fmt.Sprintf(`{"name":"%s","mac":"%s"}`, oldName, mac))
+
+ if err := os.MkdirAll(filepath.Join(sysDir, newName), 0o755); err != nil {
+ t.Fatalf("mkdir %s: %v", newName, err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, newName, "address"), []byte(mac+"\n"), 0o644); err != nil {
+ t.Fatalf("write %s address: %v", newName, err)
+ }
+
+ if err := os.MkdirAll(filepath.Join(sysDir, oldName), 0o755); err != nil {
+ t.Fatalf("mkdir %s: %v", oldName, err)
+ }
+ conflictMAC := fmt.Sprintf("11:22:33:44:55:%02x", i)
+ if err := os.WriteFile(filepath.Join(sysDir, oldName, "address"), []byte(conflictMAC+"\n"), 0o644); err != nil {
+ t.Fatalf("write %s address: %v", oldName, err)
+ }
+ }
+
+ invJSON := []byte(`{"interfaces":[` + strings.Join(ifaceJSON, ",") + `]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ reader := bufio.NewReader(strings.NewReader("n\n"))
+ result := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar")
+ if result == nil {
+ t.Fatalf("expected result")
+ }
+ if result.Applied() {
+ t.Fatalf("expected no changes when conflicts skipped, got %#v", result)
+ }
+}
+
+func TestMaybeRepairNICNamesCLI_ReturnsNilOnApplyError(t *testing.T) {
+ origFS := restoreFS
+ origTime := restoreTime
+ origSysNet := sysClassNetPath
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreTime = origTime
+ sysClassNetPath = origSysNet
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 2, 3, 4, 5, 0, time.UTC)}
+
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ sysDir := t.TempDir()
+ sysClassNetPath = sysDir
+ if err := os.MkdirAll(filepath.Join(sysDir, "enp3s0"), 0o755); err != nil {
+ t.Fatalf("mkdir enp3s0: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(sysDir, "enp3s0", "address"), []byte("aa:bb:cc:dd:ee:ff\n"), 0o644); err != nil {
+ t.Fatalf("write address: %v", err)
+ }
+
+ invJSON := []byte(`{"interfaces":[{"name":"eno1","mac":"aa:bb:cc:dd:ee:ff"}]}`)
+ writeTarToFakeFS(t, fakeFS, "/backup.tar", []tarEntry{
+ {
+ Name: "./var/lib/proxsave-info/commands/system/network_inventory.json",
+ Typeflag: tar.TypeReg,
+ Mode: 0o644,
+ Data: invJSON,
+ },
+ })
+
+ if err := fakeFS.WriteFile("/etc/network/interfaces", []byte("auto eno1\niface eno1 inet manual\n"), 0o644); err != nil {
+ t.Fatalf("write interfaces: %v", err)
+ }
+
+ restoreFS = mkdirAllFailFS{FS: fakeFS, failPath: "/tmp/proxsave", err: errors.New("boom")}
+
+ reader := bufio.NewReader(strings.NewReader(""))
+ if got := maybeRepairNICNamesCLI(context.Background(), reader, newDiscardLogger(), "/backup.tar"); got != nil {
+ t.Fatalf("expected nil on apply error, got %#v", got)
+ }
+}
+
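+// applyNetworkConfig should prefer ifreload when available and otherwise
+// fall back to "systemctl restart networking" or "ifup -a".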
+func TestApplyNetworkConfig_SelectsAvailableCommand(t *testing.T) {
+ origCmd := restoreCmd
+ t.Cleanup(func() { restoreCmd = origCmd })
+
+ logger := newDiscardLogger()
+
+ t.Run("ifreload", func(t *testing.T) {
+ pathDir := t.TempDir()
+ writeExecutable(t, pathDir, "ifreload")
+ writeExecutable(t, pathDir, "systemctl")
+ writeExecutable(t, pathDir, "ifup")
+ t.Setenv("PATH", pathDir)
+
+ fakeCmd := &FakeCommandRunner{}
+ restoreCmd = fakeCmd
+
+ if err := applyNetworkConfig(context.Background(), logger); err != nil {
+ t.Fatalf("applyNetworkConfig error: %v", err)
+ }
+ if calls := fakeCmd.CallsList(); len(calls) != 1 || calls[0] != "ifreload -a" {
+ t.Fatalf("calls=%v want [ifreload -a]", calls)
+ }
+ })
+
+ t.Run("systemctl", func(t *testing.T) {
+ pathDir := t.TempDir()
+ writeExecutable(t, pathDir, "systemctl")
+ t.Setenv("PATH", pathDir)
+
+ fakeCmd := &FakeCommandRunner{}
+ restoreCmd = fakeCmd
+
+ if err := applyNetworkConfig(context.Background(), logger); err != nil {
+ t.Fatalf("applyNetworkConfig error: %v", err)
+ }
+ if calls := fakeCmd.CallsList(); len(calls) != 1 || calls[0] != "systemctl restart networking" {
+ t.Fatalf("calls=%v want [systemctl restart networking]", calls)
+ }
+ })
+
+ t.Run("ifup", func(t *testing.T) {
+ pathDir := t.TempDir()
+ writeExecutable(t, pathDir, "ifup")
+ t.Setenv("PATH", pathDir)
+
+ fakeCmd := &FakeCommandRunner{}
+ restoreCmd = fakeCmd
+
+ if err := applyNetworkConfig(context.Background(), logger); err != nil {
+ t.Fatalf("applyNetworkConfig error: %v", err)
+ }
+ if calls := fakeCmd.CallsList(); len(calls) != 1 || calls[0] != "ifup -a" {
+ t.Fatalf("calls=%v want [ifup -a]", calls)
+ }
+ })
+
+ t.Run("none", func(t *testing.T) {
+ pathDir := t.TempDir()
+ t.Setenv("PATH", pathDir)
+
+ restoreCmd = &FakeCommandRunner{}
+ if err := applyNetworkConfig(context.Background(), logger); err == nil || !strings.Contains(err.Error(), "no supported network reload command") {
+			t.Fatalf("err=%v want %q error", err, "no supported network reload command")
+ }
+ })
+}
+
+func TestParseSSHClientIP(t *testing.T) {
+ t.Run("SSH_CONNECTION", func(t *testing.T) {
+ t.Setenv("SSH_CONNECTION", "203.0.113.9 1234 10.0.0.1 22")
+ t.Setenv("SSH_CLIENT", "")
+ if got := parseSSHClientIP(); got != "203.0.113.9" {
+ t.Fatalf("got %q want %q", got, "203.0.113.9")
+ }
+ })
+
+ t.Run("SSH_CLIENT", func(t *testing.T) {
+ t.Setenv("SSH_CONNECTION", "")
+ t.Setenv("SSH_CLIENT", "203.0.113.8 2222 22")
+ if got := parseSSHClientIP(); got != "203.0.113.8" {
+ t.Fatalf("got %q want %q", got, "203.0.113.8")
+ }
+ })
+
+ t.Run("none", func(t *testing.T) {
+ t.Setenv("SSH_CONNECTION", "")
+ t.Setenv("SSH_CLIENT", "")
+ if got := parseSSHClientIP(); got != "" {
+ t.Fatalf("got %q want empty", got)
+ }
+ })
+}
+
+func TestDetectManagementInterface(t *testing.T) {
+ origCmd := restoreCmd
+ t.Cleanup(func() { restoreCmd = origCmd })
+
+ logger := newDiscardLogger()
+
+ t.Run("ssh route", func(t *testing.T) {
+ t.Setenv("SSH_CONNECTION", "203.0.113.9 1234 10.0.0.1 22")
+ t.Setenv("SSH_CLIENT", "")
+ restoreCmd = &FakeCommandRunner{
+ Outputs: map[string][]byte{
+ "ip route get 203.0.113.9": []byte("203.0.113.9 dev vmbr0 src 192.0.2.2\n"),
+ },
+ }
+ iface, src := detectManagementInterface(context.Background(), logger)
+ if iface != "vmbr0" || src != "ssh" {
+ t.Fatalf("iface=%q src=%q want vmbr0 ssh", iface, src)
+ }
+ })
+
+ t.Run("default route fallback", func(t *testing.T) {
+ t.Setenv("SSH_CONNECTION", "203.0.113.9 1234 10.0.0.1 22")
+ t.Setenv("SSH_CLIENT", "")
+ restoreCmd = &FakeCommandRunner{
+ Errors: map[string]error{
+ "ip route get 203.0.113.9": errors.New("boom"),
+ },
+ Outputs: map[string][]byte{
+ "ip route show default": []byte("default via 192.0.2.1 dev nic1\n"),
+ },
+ }
+ iface, src := detectManagementInterface(context.Background(), logger)
+ if iface != "nic1" || src != "default-route" {
+ t.Fatalf("iface=%q src=%q want nic1 default-route", iface, src)
+ }
+ })
+
+ t.Run("none", func(t *testing.T) {
+ t.Setenv("SSH_CONNECTION", "")
+ t.Setenv("SSH_CLIENT", "")
+ restoreCmd = &FakeCommandRunner{
+ Errors: map[string]error{
+ "ip route show default": errors.New("boom"),
+ },
+ }
+ iface, src := detectManagementInterface(context.Background(), logger)
+ if iface != "" || src != "" {
+ t.Fatalf("iface=%q src=%q want empty", iface, src)
+ }
+ })
+}
+
+func TestRouteInterfaceForIPAndDefaultRouteInterface_ErrorCases(t *testing.T) {
+ origCmd := restoreCmd
+ t.Cleanup(func() { restoreCmd = origCmd })
+
+ restoreCmd = &FakeCommandRunner{
+ Errors: map[string]error{
+ "ip route get 203.0.113.9": errors.New("boom"),
+ "ip route show default": errors.New("boom"),
+ "ip route show default -x": errors.New("boom"),
+ "ip route show default --y": errors.New("boom"),
+ },
+ }
+ if got := routeInterfaceForIP(context.Background(), "203.0.113.9"); got != "" {
+ t.Fatalf("routeInterfaceForIP=%q want empty on error", got)
+ }
+ if got := defaultRouteInterface(context.Background()); got != "" {
+ t.Fatalf("defaultRouteInterface=%q want empty on error", got)
+ }
+}
+
+func TestDefaultRouteInterface_EmptyOutputReturnsEmpty(t *testing.T) {
+ origCmd := restoreCmd
+ t.Cleanup(func() { restoreCmd = origCmd })
+
+ restoreCmd = &FakeCommandRunner{
+ Outputs: map[string][]byte{
+ "ip route show default": []byte(""),
+ },
+ }
+ if got := defaultRouteInterface(context.Background()); got != "" {
+ t.Fatalf("defaultRouteInterface=%q want empty", got)
+ }
+}
+
+func TestDefaultNetworkPortChecks(t *testing.T) {
+ if got := defaultNetworkPortChecks(SystemTypePVE); len(got) != 1 || got[0].Port != 8006 {
+ t.Fatalf("PVE checks=%v", got)
+ }
+ if got := defaultNetworkPortChecks(SystemTypePBS); len(got) != 1 || got[0].Port != 8007 {
+ t.Fatalf("PBS checks=%v", got)
+ }
+ if got := defaultNetworkPortChecks(SystemTypeUnknown); got != nil {
+ t.Fatalf("unknown checks=%v want nil", got)
+ }
+}
+
+func TestPromptNetworkCommitWithCountdown_InputAborted(t *testing.T) {
+ reader := bufio.NewReader(strings.NewReader(""))
+ logger := newDiscardLogger()
+
+ committed, err := promptNetworkCommitWithCountdown(context.Background(), reader, logger, 2*time.Second)
+ if committed {
+ t.Fatalf("expected committed=false")
+ }
+ if err == nil {
+ t.Fatalf("expected error")
+ }
+}
+
+func TestRollbackNetworkFilesNow_ErrorCasesAndScriptFailure(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ restoreTime = origTime
+ })
+
+ logger := newDiscardLogger()
+
+ if _, err := rollbackNetworkFilesNow(context.Background(), logger, "", ""); err == nil {
+ t.Fatalf("expected error for empty backup path")
+ }
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ t.Run("mkdir error", func(t *testing.T) {
+ restoreFS = &FakeFS{Root: fakeFS.Root, MkdirAllErr: errors.New("boom"), StatErr: make(map[string]error), StatErrors: make(map[string]error), OpenFileErr: make(map[string]error)}
+ restoreCmd = &FakeCommandRunner{}
+ if _, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work"); err == nil {
+ t.Fatalf("expected error for mkdir failure")
+ }
+ })
+
+ t.Run("marker write error", func(t *testing.T) {
+ failFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(failFS.Root) })
+ failFS.WriteErr = errors.New("boom")
+ restoreFS = failFS
+ restoreCmd = &FakeCommandRunner{}
+
+ if _, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work"); err == nil || !strings.Contains(err.Error(), "write rollback marker") {
+ t.Fatalf("err=%v want write rollback marker error", err)
+ }
+ })
+
+ t.Run("script write error removes marker", func(t *testing.T) {
+ restoreFS = fakeFS
+ restoreCmd = &FakeCommandRunner{}
+
+ scriptPath := "/work/network_rollback_now_20260201_123456.sh"
+ restoreFS = writeFailFS{FS: fakeFS, failPath: scriptPath, err: errors.New("boom")}
+
+ _, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work")
+ if err == nil || !strings.Contains(err.Error(), "write rollback script") {
+ t.Fatalf("err=%v want write rollback script error", err)
+ }
+
+ markerPath := "/work/network_rollback_now_pending_20260201_123456"
+ if _, statErr := fakeFS.Stat(markerPath); !errors.Is(statErr, os.ErrNotExist) {
+ t.Fatalf("expected marker to be removed on script write failure, stat err=%v", statErr)
+ }
+ })
+
+ t.Run("marker remove error is non-fatal", func(t *testing.T) {
+ restoreCmd = &FakeCommandRunner{
+ Outputs: map[string][]byte{
+ "sh /work/network_rollback_now_20260201_123456.sh": []byte("ok\n"),
+ },
+ }
+
+ markerPath := "/work/network_rollback_now_pending_20260201_123456"
+ restoreFS = removeFailFS{FS: fakeFS, failPath: markerPath, err: errors.New("remove boom")}
+
+ logPath, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if logPath != "/work/network_rollback_now_20260201_123456.log" {
+ t.Fatalf("logPath=%q", logPath)
+ }
+ })
+
+ t.Run("script run error returns log path", func(t *testing.T) {
+ restoreFS = fakeFS
+
+ restoreCmd = &FakeCommandRunner{
+ Errors: map[string]error{
+ "sh /work/network_rollback_now_20260201_123456.sh": errors.New("boom"),
+ },
+ }
+
+ logPath, err := rollbackNetworkFilesNow(context.Background(), logger, "/backup.tar", "/work")
+ if err == nil || !strings.Contains(err.Error(), "rollback script failed") {
+ t.Fatalf("err=%v want rollback script failed", err)
+ }
+ if logPath != "/work/network_rollback_now_20260201_123456.log" {
+ t.Fatalf("logPath=%q", logPath)
+ }
+ })
+}
+
+func TestRollbackNetworkFilesNow_DefaultWorkDirUsesTmpProxsave(t *testing.T) {
+ origFS := restoreFS
+ origCmd := restoreCmd
+ origTime := restoreTime
+ t.Cleanup(func() {
+ restoreFS = origFS
+ restoreCmd = origCmd
+ restoreTime = origTime
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ restoreTime = &FakeTime{Current: time.Date(2026, 2, 1, 12, 34, 56, 0, time.UTC)}
+
+ restoreCmd = &FakeCommandRunner{
+ Outputs: map[string][]byte{
+ "sh /tmp/proxsave/network_rollback_now_20260201_123456.sh": []byte("ok\n"),
+ },
+ }
+
+ logPath, err := rollbackNetworkFilesNow(context.Background(), newDiscardLogger(), "/backup.tar", "")
+ if err != nil {
+ t.Fatalf("rollbackNetworkFilesNow error: %v", err)
+ }
+ if logPath != "/tmp/proxsave/network_rollback_now_20260201_123456.log" {
+ t.Fatalf("logPath=%q", logPath)
+ }
+}
+
+func TestBuildRollbackScript_IncludesRestartBlockWhenEnabled(t *testing.T) {
+ script := buildRollbackScript("/marker", "/backup with spaces.tar", "/tmp/log file.log", true)
+ if !strings.Contains(script, "Restart networking after rollback") {
+ t.Fatalf("expected restartNetworking block")
+ }
+ if !strings.Contains(script, "LOG='/tmp/log file.log'") {
+ t.Fatalf("expected quoted LOG, got:\n%s", script)
+ }
+ if !strings.Contains(script, "BACKUP='/backup with spaces.tar'") {
+ t.Fatalf("expected quoted BACKUP, got:\n%s", script)
+ }
+}
+
+func TestShellQuote(t *testing.T) {
+ if got := shellQuote(""); got != "''" {
+ t.Fatalf("shellQuote empty=%q", got)
+ }
+ if got := shellQuote("simple"); got != "simple" {
+ t.Fatalf("shellQuote simple=%q", got)
+ }
+ want := `'a b '\''c'\'''`
+ if got := shellQuote("a b 'c'"); got != want {
+ t.Fatalf("shellQuote=%q want %q", got, want)
+ }
+}
+
+func TestRunCommandLogged_SuccessAndFailure(t *testing.T) {
+ origCmd := restoreCmd
+ t.Cleanup(func() { restoreCmd = origCmd })
+
+ logger := newDiscardLogger()
+
+ restoreCmd = &FakeCommandRunner{
+ Outputs: map[string][]byte{
+ "echo hi": []byte("hi\n"),
+ },
+ }
+ if err := runCommandLogged(context.Background(), logger, "echo", "hi"); err != nil {
+ t.Fatalf("runCommandLogged error: %v", err)
+ }
+
+ restoreCmd = &FakeCommandRunner{
+ Errors: map[string]error{
+ "false x": errors.New("boom"),
+ },
+ }
+ err := runCommandLogged(context.Background(), logger, "false", "x")
+ if err == nil || !strings.Contains(err.Error(), "false") || !strings.Contains(err.Error(), "failed") {
+ t.Fatalf("err=%v want wrapped failure", err)
+ }
+}
diff --git a/internal/orchestrator/network_staged_apply.go b/internal/orchestrator/network_staged_apply.go
index 3412d15..fa5a912 100644
--- a/internal/orchestrator/network_staged_apply.go
+++ b/internal/orchestrator/network_staged_apply.go
@@ -59,13 +59,20 @@ func applyNetworkFilesFromStage(logger *logging.Logger, stageRoot string) (appli
}
func copyDirOverlay(srcDir, destDir string) ([]string, error) {
- info, err := restoreFS.Stat(srcDir)
+ return copyDirOverlayWithinRoot(srcDir, destDir, destDir)
+}
+
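+// copyDirOverlayWithinRoot copies the staged tree under srcDir into destDir,
+// threading destRoot through recursion so nested symlinks are validated
+// against the original destination root rather than the current subdirectory.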
+func copyDirOverlayWithinRoot(srcDir, destDir, destRoot string) ([]string, error) {
+ info, err := restoreFS.Lstat(srcDir)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, fmt.Errorf("stat %s: %w", srcDir, err)
}
+ if info.Mode()&os.ModeSymlink != 0 {
+ return nil, fmt.Errorf("staged directory must not be a symlink: %s", srcDir)
+ }
if !info.IsDir() {
return nil, nil
}
@@ -91,8 +98,27 @@ func copyDirOverlay(srcDir, destDir string) ([]string, error) {
src := filepath.Join(srcDir, name)
dest := filepath.Join(destDir, name)
- if entry.IsDir() {
- paths, err := copyDirOverlay(src, dest)
+ info, err := restoreFS.Lstat(src)
+ if err != nil {
+ if os.IsNotExist(err) {
+ continue
+ }
+ return applied, fmt.Errorf("stat %s: %w", src, err)
+ }
+
+ if info.Mode()&os.ModeSymlink != 0 {
+ ok, err := copySymlinkOverlayWithinRoot(src, dest, destRoot)
+ if err != nil {
+ return applied, err
+ }
+ if ok {
+ applied = append(applied, dest)
+ }
+ continue
+ }
+
+ if info.IsDir() {
+ paths, err := copyDirOverlayWithinRoot(src, dest, destRoot)
if err != nil {
return applied, err
}
@@ -100,7 +126,7 @@ func copyDirOverlay(srcDir, destDir string) ([]string, error) {
continue
}
- ok, err := copyFileOverlay(src, dest)
+ ok, err := copyFileOverlayWithinRoot(src, dest, destRoot)
if err != nil {
return applied, err
}
@@ -113,16 +139,26 @@ func copyDirOverlay(srcDir, destDir string) ([]string, error) {
}
func copyFileOverlay(src, dest string) (bool, error) {
- info, err := restoreFS.Stat(src)
+ return copyFileOverlayWithinRoot(src, dest, filepath.Dir(dest))
+}
+
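+// copyFileOverlayWithinRoot copies a single staged entry, dispatching symlinks
+// to copySymlinkOverlayWithinRoot and rejecting non-regular special files.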
+func copyFileOverlayWithinRoot(src, dest, destRoot string) (bool, error) {
+ info, err := restoreFS.Lstat(src)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("stat %s: %w", src, err)
}
+ if info.Mode()&os.ModeSymlink != 0 {
+ return copySymlinkOverlayWithinRoot(src, dest, destRoot)
+ }
if info.IsDir() {
return false, nil
}
+ if !info.Mode().IsRegular() {
+ return false, fmt.Errorf("unsupported staged file type %s (mode=%s)", src, info.Mode())
+ }
data, err := restoreFS.ReadFile(src)
if err != nil {
@@ -141,3 +177,73 @@ func copyFileOverlay(src, dest string) (bool, error) {
}
return true, nil
}
+
+func copySymlinkOverlay(src, dest string) (bool, error) {
+ return copySymlinkOverlayWithinRoot(src, dest, filepath.Dir(dest))
+}
+
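+// copySymlinkOverlayWithinRoot recreates the staged symlink src at dest,
+// refusing targets that resolve outside destRoot and replacing any existing
+// non-directory destination entry.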
+func copySymlinkOverlayWithinRoot(src, dest, destRoot string) (bool, error) {
+ info, err := restoreFS.Lstat(src)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, fmt.Errorf("stat %s: %w", src, err)
+ }
+ if info.Mode()&os.ModeSymlink == 0 {
+ return false, fmt.Errorf("source is not a symlink: %s", src)
+ }
+
+ target, err := restoreFS.Readlink(src)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, fmt.Errorf("readlink %s: %w", src, err)
+ }
+
+ validatedTarget, err := validateOverlaySymlinkTargetWithinRoot(destRoot, dest, target)
+ if err != nil {
+ return false, fmt.Errorf("unsafe symlink target %s -> %s: %w", dest, target, err)
+ }
+
+ if err := ensureDirExistsWithInheritedMeta(filepath.Dir(dest)); err != nil {
+ return false, fmt.Errorf("ensure %s: %w", filepath.Dir(dest), err)
+ }
+
+ if existing, err := restoreFS.Lstat(dest); err == nil {
+ if existing.IsDir() {
+ return false, fmt.Errorf("destination exists as directory: %s", dest)
+ }
+ if err := restoreFS.Remove(dest); err != nil && !os.IsNotExist(err) {
+ return false, fmt.Errorf("remove %s: %w", dest, err)
+ }
+ } else if !os.IsNotExist(err) {
+ return false, fmt.Errorf("stat %s: %w", dest, err)
+ }
+
+ if err := restoreFS.Symlink(validatedTarget, dest); err != nil {
+ return false, fmt.Errorf("symlink %s -> %s: %w", dest, validatedTarget, err)
+ }
+ return true, nil
+}
+
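+// validateOverlaySymlinkTargetWithinRoot checks that target, resolved against
+// the directory of dest, stays inside destRoot. Relative targets are returned
+// unchanged; absolute targets are rewritten relative to the symlink's parent
+// directory so the recreated link keeps pointing inside the destination root.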
+func validateOverlaySymlinkTargetWithinRoot(destRoot, dest, target string) (string, error) {
+ destRoot = filepath.Clean(strings.TrimSpace(destRoot))
+ dest = filepath.Clean(strings.TrimSpace(dest))
+
+ resolved, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(dest), target)
+ if err != nil {
+ return "", err
+ }
+
+ if !filepath.IsAbs(target) {
+ return target, nil
+ }
+
+ rewrittenTarget, err := filepath.Rel(filepath.Dir(dest), resolved)
+ if err != nil {
+ return "", fmt.Errorf("rewrite symlink target %s -> %s: %w", dest, resolved, err)
+ }
+ return filepath.Clean(rewrittenTarget), nil
+}
diff --git a/internal/orchestrator/network_staged_apply_test.go b/internal/orchestrator/network_staged_apply_test.go
new file mode 100644
index 0000000..3683af0
--- /dev/null
+++ b/internal/orchestrator/network_staged_apply_test.go
@@ -0,0 +1,134 @@
+package orchestrator
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
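+// preservingSymlinkFS overrides FakeFS.Symlink to create real symlinks in the
+// backing directory, so Lstat and Readlink in these tests observe genuine
+// link modes and targets.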
+type preservingSymlinkFS struct {
+ *FakeFS
+}
+
+func newPreservingSymlinkFS() *preservingSymlinkFS {
+ return &preservingSymlinkFS{FakeFS: NewFakeFS()}
+}
+
+func (f *preservingSymlinkFS) Symlink(oldname, newname string) error {
+ if err := os.MkdirAll(filepath.Dir(f.onDisk(newname)), 0o755); err != nil {
+ return err
+ }
+ return os.Symlink(oldname, f.onDisk(newname))
+}
+
+func TestApplyNetworkFilesFromStage_RewritesSafeAbsoluteSymlinkTargetsWithinDestinationRoot(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := newPreservingSymlinkFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ if err := fakeFS.MkdirAll("/stage/etc/network", 0o755); err != nil {
+ t.Fatalf("create stage dir: %v", err)
+ }
+ if err := fakeFS.Symlink("/etc/network/interfaces.real", "/stage/etc/network/interfaces"); err != nil {
+ t.Fatalf("create staged symlink: %v", err)
+ }
+
+ applied, err := applyNetworkFilesFromStage(newTestLogger(), "/stage")
+ if err != nil {
+ t.Fatalf("applyNetworkFilesFromStage: %v", err)
+ }
+
+ info, err := fakeFS.Lstat("/etc/network/interfaces")
+ if err != nil {
+ t.Fatalf("lstat dest symlink: %v", err)
+ }
+ if info.Mode()&os.ModeSymlink == 0 {
+ t.Fatalf("expected symlink mode, got %s", info.Mode())
+ }
+
+ gotTarget, err := fakeFS.Readlink("/etc/network/interfaces")
+ if err != nil {
+ t.Fatalf("readlink dest symlink: %v", err)
+ }
+ if gotTarget != "interfaces.real" {
+ t.Fatalf("symlink target=%q, want %q", gotTarget, "interfaces.real")
+ }
+
+ found := false
+ for _, path := range applied {
+ if path == "/etc/network/interfaces" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Fatalf("expected applied paths to include /etc/network/interfaces, got %#v", applied)
+ }
+}
+
+func TestApplyNetworkFilesFromStage_RejectsEscapingSymlinkTargets(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := newPreservingSymlinkFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ if err := fakeFS.MkdirAll("/stage/etc/network", 0o755); err != nil {
+ t.Fatalf("create stage dir: %v", err)
+ }
+ if err := fakeFS.Symlink("/stage/etc/network/interfaces.real", "/stage/etc/network/interfaces"); err != nil {
+ t.Fatalf("create staged symlink: %v", err)
+ }
+
+ _, err := applyNetworkFilesFromStage(newTestLogger(), "/stage")
+ if err == nil || !strings.Contains(err.Error(), "unsafe symlink target") {
+ t.Fatalf("expected unsafe symlink target error, got %v", err)
+ }
+}
+
+func TestApplyNetworkFilesFromStage_RejectsRelativeSymlinkTargetsThatEscapeDestinationRoot(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := newPreservingSymlinkFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ if err := fakeFS.MkdirAll("/stage/etc/network/interfaces.d", 0o755); err != nil {
+ t.Fatalf("create stage dir: %v", err)
+ }
+ if err := fakeFS.Symlink("../../shadow", "/stage/etc/network/interfaces.d/uplink"); err != nil {
+ t.Fatalf("create staged symlink: %v", err)
+ }
+
+ _, err := applyNetworkFilesFromStage(newTestLogger(), "/stage")
+ if err == nil || !strings.Contains(err.Error(), "unsafe symlink target") {
+ t.Fatalf("expected unsafe symlink target error, got %v", err)
+ }
+}
+
+func TestApplyNetworkFilesFromStage_RejectsSymlinkStageDirectory(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := newPreservingSymlinkFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ if err := fakeFS.MkdirAll("/stage/etc", 0o755); err != nil {
+ t.Fatalf("create stage parent dir: %v", err)
+ }
+ if err := fakeFS.Symlink("/outside/network", "/stage/etc/network"); err != nil {
+ t.Fatalf("create staged dir symlink: %v", err)
+ }
+
+ _, err := applyNetworkFilesFromStage(newTestLogger(), "/stage")
+ if err == nil || !strings.Contains(err.Error(), "must not be a symlink") {
+ t.Fatalf("expected staged directory symlink error, got %v", err)
+ }
+}
diff --git a/internal/orchestrator/path_security.go b/internal/orchestrator/path_security.go
new file mode 100644
index 0000000..7b2a425
--- /dev/null
+++ b/internal/orchestrator/path_security.go
@@ -0,0 +1,286 @@
+package orchestrator
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+)
+
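+// maxPathSecuritySymlinkHops bounds how many symlinks may be dereferenced
+// while resolving a single path, guarding against symlink loops.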
+const maxPathSecuritySymlinkHops = 40
+
+type pathResolutionErrorKind string
+
+const (
+ pathResolutionErrorSecurity pathResolutionErrorKind = "security"
+ pathResolutionErrorOperational pathResolutionErrorKind = "operational"
+)
+
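+// pathResolutionError classifies resolution failures: security errors mark
+// escape attempts that must abort the operation, while operational errors mark
+// ordinary I/O problems (missing files, permissions) that callers may tolerate.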
+type pathResolutionError struct {
+ kind pathResolutionErrorKind
+ msg string
+ cause error
+}
+
+func (e *pathResolutionError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.cause != nil {
+ return fmt.Sprintf("%s: %v", e.msg, e.cause)
+ }
+ return e.msg
+}
+
+func (e *pathResolutionError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.cause
+}
+
+func newPathResolutionError(kind pathResolutionErrorKind, cause error, format string, args ...any) error {
+ return &pathResolutionError{
+ kind: kind,
+ msg: fmt.Sprintf(format, args...),
+ cause: cause,
+ }
+}
+
+func newPathSecurityError(format string, args ...any) error {
+ return newPathResolutionError(pathResolutionErrorSecurity, nil, format, args...)
+}
+
+func wrapPathOperationalError(cause error, format string, args ...any) error {
+ return newPathResolutionError(pathResolutionErrorOperational, cause, format, args...)
+}
+
+func isPathSecurityError(err error) bool {
+ var target *pathResolutionError
+ return errors.As(err, &target) && target.kind == pathResolutionErrorSecurity
+}
+
+func isPathOperationalError(err error) bool {
+ var target *pathResolutionError
+ return errors.As(err, &target) && target.kind == pathResolutionErrorOperational
+}
+
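+// resolvePathWithinRootFS resolves candidate (absolute, or relative to
+// destRoot) to a canonical path, following symlinks and rejecting any hop that
+// would escape destRoot. Missing trailing components are allowed so paths that
+// do not exist yet can still be validated.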
+func resolvePathWithinRootFS(fsys FS, destRoot, candidate string) (string, error) {
+ lexicalRoot, canonicalRoot, err := prepareRootPathsFS(fsys, destRoot)
+ if err != nil {
+ return "", err
+ }
+
+ candidateAbs, err := normalizeCandidateWithinRoot(lexicalRoot, canonicalRoot, candidate)
+ if err != nil {
+ return "", err
+ }
+
+ return resolvePathWithinPreparedRootFS(fsys, lexicalRoot, canonicalRoot, candidateAbs, true, maxPathSecuritySymlinkHops)
+}
+
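+// resolvePathRelativeToBaseWithinRootFS is like resolvePathWithinRootFS, but
+// relative candidates are resolved against baseDir after baseDir itself has
+// been canonicalized within destRoot.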
+func resolvePathRelativeToBaseWithinRootFS(fsys FS, destRoot, baseDir, candidate string) (string, error) {
+ lexicalRoot, canonicalRoot, err := prepareRootPathsFS(fsys, destRoot)
+ if err != nil {
+ return "", err
+ }
+
+ baseAbs, err := filepath.Abs(filepath.Clean(baseDir))
+ if err != nil {
+ return "", fmt.Errorf("resolve base directory: %w", err)
+ }
+ normalizedBaseDir, err := normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, baseAbs)
+ if err != nil {
+ return "", err
+ }
+
+ // Resolve symlinks already present in the base directory before validating
+ // a relative target against it. The kernel creates the new symlink under the
+ // resolved parent directory, not the lexical parent path.
+ canonicalBaseDir, err := resolvePathWithinPreparedRootFS(
+ fsys,
+ lexicalRoot,
+ canonicalRoot,
+ normalizedBaseDir,
+ true,
+ maxPathSecuritySymlinkHops,
+ )
+ if err != nil {
+ return "", err
+ }
+
+ var candidateAbs string
+ if filepath.IsAbs(candidate) {
+ candidateAbs, err = normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, filepath.Clean(candidate))
+ if err != nil {
+ return "", err
+ }
+ } else {
+ candidateAbs = filepath.Clean(filepath.Join(canonicalBaseDir, candidate))
+ }
+
+ return resolvePathWithinPreparedRootFS(fsys, lexicalRoot, canonicalRoot, candidateAbs, true, maxPathSecuritySymlinkHops)
+}
+
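+// prepareRootPathsFS returns destRoot both as a cleaned lexical path and as a
+// fully symlink-resolved canonical path; both spellings are needed so absolute
+// targets written against either form of the root are accepted.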
+func prepareRootPathsFS(fsys FS, destRoot string) (string, string, error) {
+ fsys = pathSecurityFS(fsys)
+
+ cleanRoot := filepath.Clean(destRoot)
+ absRoot, err := filepath.Abs(cleanRoot)
+ if err != nil {
+ return "", "", fmt.Errorf("resolve destination root: %w", err)
+ }
+
+ canonicalRoot, err := resolvePathFromFilesystemRootFS(fsys, absRoot, true, maxPathSecuritySymlinkHops)
+ if err != nil {
+ return "", "", fmt.Errorf("resolve destination root: %w", err)
+ }
+
+ return absRoot, canonicalRoot, nil
+}
+
+func normalizeCandidateWithinRoot(lexicalRoot, canonicalRoot, candidate string) (string, error) {
+ if filepath.IsAbs(candidate) {
+ return normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, filepath.Clean(candidate))
+ }
+
+ return filepath.Clean(filepath.Join(canonicalRoot, candidate)), nil
+}
+
+func normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, candidateAbs string) (string, error) {
+ rel, ok, err := relativePathWithinRoot(lexicalRoot, candidateAbs)
+ if err != nil {
+ return "", fmt.Errorf("cannot compute relative path: %w", err)
+ }
+ if ok {
+ return filepath.Clean(filepath.Join(canonicalRoot, rel)), nil
+ }
+
+ rel, ok, err = relativePathWithinRoot(canonicalRoot, candidateAbs)
+ if err != nil {
+ return "", fmt.Errorf("cannot compute relative path: %w", err)
+ }
+ if ok {
+ return filepath.Clean(filepath.Join(canonicalRoot, rel)), nil
+ }
+
+ return "", newPathSecurityError("resolved path escapes destination: %s", candidateAbs)
+}
+
+func resolvePathFromFilesystemRootFS(fsys FS, candidateAbs string, allowMissingTail bool, hopsRemaining int) (string, error) {
+ root := string(os.PathSeparator)
+ return resolvePathWithinPreparedRootFS(fsys, root, root, candidateAbs, allowMissingTail, hopsRemaining)
+}
+
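+// resolvePathWithinPreparedRootFS walks candidateAbs component by component
+// below canonicalRoot, re-resolving each symlink it meets (up to hopsRemaining
+// dereferences) and failing with a security error on any escape.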
+func resolvePathWithinPreparedRootFS(fsys FS, lexicalRoot, canonicalRoot, candidateAbs string, allowMissingTail bool, hopsRemaining int) (string, error) {
+ fsys = pathSecurityFS(fsys)
+
+ lexicalRoot = filepath.Clean(lexicalRoot)
+ canonicalRoot = filepath.Clean(canonicalRoot)
+ candidateAbs = filepath.Clean(candidateAbs)
+
+ if !filepath.IsAbs(lexicalRoot) {
+ return "", fmt.Errorf("destination root must be absolute: %s", lexicalRoot)
+ }
+ if !filepath.IsAbs(canonicalRoot) {
+ return "", fmt.Errorf("destination root must be absolute: %s", canonicalRoot)
+ }
+ if !filepath.IsAbs(candidateAbs) {
+ return "", fmt.Errorf("candidate path must be absolute: %s", candidateAbs)
+ }
+
+ rel, ok, err := relativePathWithinRoot(canonicalRoot, candidateAbs)
+ if err != nil {
+ return "", fmt.Errorf("cannot compute relative path: %w", err)
+ }
+ if !ok {
+ return "", newPathSecurityError("resolved path escapes destination: %s", candidateAbs)
+ }
+ if rel == "." {
+ return canonicalRoot, nil
+ }
+
+ current := canonicalRoot
+ parts := splitRelativePath(rel)
+ for idx, part := range parts {
+ next := filepath.Join(current, part)
+ info, err := fsys.Lstat(next)
+ if err != nil {
+ if allowMissingTail && os.IsNotExist(err) {
+ return filepath.Clean(filepath.Join(current, filepath.Join(parts[idx:]...))), nil
+ }
+ return "", wrapPathOperationalError(err, "lstat %s", next)
+ }
+
+ if info.Mode()&os.ModeSymlink != 0 {
+ if hopsRemaining <= 0 {
+ return "", newPathSecurityError("too many symlink resolutions for %s", candidateAbs)
+ }
+ target, err := fsys.Readlink(next)
+ if err != nil {
+ return "", wrapPathOperationalError(err, "readlink %s", next)
+ }
+
+ var resolvedLink string
+ if filepath.IsAbs(target) {
+ resolvedLink, err = normalizeAbsolutePathWithinRoot(lexicalRoot, canonicalRoot, filepath.Clean(target))
+ if err != nil {
+ return "", err
+ }
+ } else {
+ resolvedLink = filepath.Join(current, target)
+ }
+ resolvedLink = filepath.Clean(resolvedLink)
+
+ if _, ok, err := relativePathWithinRoot(canonicalRoot, resolvedLink); err != nil {
+ return "", fmt.Errorf("cannot compute relative path: %w", err)
+ } else if !ok {
+ return "", newPathSecurityError("resolved path escapes destination: %s", resolvedLink)
+ }
+
+ remainder := filepath.Join(parts[idx+1:]...)
+ if remainder != "" {
+ resolvedLink = filepath.Join(resolvedLink, remainder)
+ }
+
+ return resolvePathWithinPreparedRootFS(fsys, lexicalRoot, canonicalRoot, resolvedLink, allowMissingTail, hopsRemaining-1)
+ }
+
+ if !info.IsDir() && idx < len(parts)-1 {
+ return "", newPathResolutionError(pathResolutionErrorOperational, nil, "path component is not a directory: %s", next)
+ }
+
+ current = next
+ }
+
+ return current, nil
+}
+
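+// relativePathWithinRoot reports whether candidate lies within root, e.g.
+// root=/a, candidate=/a/b yields ("b", true) while candidate=/x yields
+// ("../x", false).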
+func relativePathWithinRoot(root, candidate string) (string, bool, error) {
+ rel, err := filepath.Rel(root, candidate)
+ if err != nil {
+ return "", false, err
+ }
+ if rel == "." {
+ return rel, true, nil
+ }
+ if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || filepath.IsAbs(rel) {
+ return rel, false, nil
+ }
+ return rel, true, nil
+}
+
+func splitRelativePath(rel string) []string {
+ if rel == "" || rel == "." {
+ return nil
+ }
+ return strings.Split(rel, string(os.PathSeparator))
+}
+
+func pathSecurityFS(fsys FS) FS {
+ if fsys == nil {
+ return osFS{}
+ }
+ return fsys
+}
diff --git a/internal/orchestrator/path_security_test.go b/internal/orchestrator/path_security_test.go
new file mode 100644
index 0000000..b3685b4
--- /dev/null
+++ b/internal/orchestrator/path_security_test.go
@@ -0,0 +1,184 @@
+package orchestrator
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+func TestResolvePathWithinRootFS_AllowsMissingTailAfterSafeSymlink(t *testing.T) {
+ root := t.TempDir()
+ if err := os.MkdirAll(filepath.Join(root, "safe-target"), 0o755); err != nil {
+ t.Fatalf("mkdir safe-target: %v", err)
+ }
+ if err := os.Symlink("safe-target", filepath.Join(root, "safe-link")); err != nil {
+ t.Fatalf("create safe symlink: %v", err)
+ }
+
+ resolved, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("safe-link", "missing", "file.txt"))
+ if err != nil {
+ t.Fatalf("resolvePathWithinRootFS returned error: %v", err)
+ }
+
+ want := filepath.Join(root, "safe-target", "missing", "file.txt")
+ if resolved != want {
+ t.Fatalf("resolved path = %q, want %q", resolved, want)
+ }
+}
+
+func TestResolvePathWithinRootFS_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) {
+ root := t.TempDir()
+ outside := t.TempDir()
+ if err := os.Symlink(outside, filepath.Join(root, "escape-link")); err != nil {
+ t.Fatalf("create escape symlink: %v", err)
+ }
+
+ _, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("escape-link", "missing", "file.txt"))
+ if err == nil || !strings.Contains(err.Error(), "escapes destination") {
+ t.Fatalf("expected escape rejection, got %v", err)
+ }
+}
+
+func TestResolvePathWithinRootFS_RejectsSymlinkLoop(t *testing.T) {
+ root := t.TempDir()
+ if err := os.Symlink("loop", filepath.Join(root, "loop")); err != nil {
+ t.Fatalf("create loop symlink: %v", err)
+ }
+
+ _, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("loop", "file.txt"))
+ if err == nil || !strings.Contains(err.Error(), "too many symlink resolutions") {
+ t.Fatalf("expected symlink loop rejection, got %v", err)
+ }
+}
+
+func TestResolvePathWithinRootFS_ClassifiesPathComponentNotDirectoryAsOperational(t *testing.T) {
+ root := t.TempDir()
+ blocker := filepath.Join(root, "deep")
+ if err := os.WriteFile(blocker, []byte("block"), 0o644); err != nil {
+ t.Fatalf("write blocker: %v", err)
+ }
+
+ _, err := resolvePathWithinRootFS(osFS{}, root, filepath.Join("deep", "nested", "file.txt"))
+ if err == nil {
+ t.Fatal("expected error for non-directory path component")
+ }
+ if !isPathOperationalError(err) {
+ t.Fatalf("expected operational error classification, got %v", err)
+ }
+ if isPathSecurityError(err) {
+ t.Fatalf("expected non-security error classification, got %v", err)
+ }
+}
+
+func TestResolvePathWithinRootFS_ClassifiesPermissionDeniedAsOperational(t *testing.T) {
+ fsys := NewFakeFS()
+ root := filepath.Join(string(os.PathSeparator), "restore-root")
+ if err := fsys.AddDir(root); err != nil {
+ t.Fatalf("add root dir: %v", err)
+ }
+ if err := fsys.AddDir(filepath.Join(root, "subdir")); err != nil {
+ t.Fatalf("add subdir: %v", err)
+ }
+ fsys.StatErrors[filepath.Clean(filepath.Join(root, "subdir", "file.txt"))] = os.ErrPermission
+
+ _, err := resolvePathWithinRootFS(fsys, root, filepath.Join("subdir", "file.txt"))
+ if err == nil {
+ t.Fatal("expected permission error")
+ }
+ if !isPathOperationalError(err) {
+ t.Fatalf("expected operational error classification, got %v", err)
+ }
+ if isPathSecurityError(err) {
+ t.Fatalf("expected non-security error classification, got %v", err)
+ }
+}
+
+func TestResolvePathWithinRootFS_AllowsAbsoluteSymlinkTargetViaLexicalRoot(t *testing.T) {
+ parent := t.TempDir()
+ realRoot := filepath.Join(parent, "real-root")
+ linkRoot := filepath.Join(parent, "link-root")
+ if err := os.MkdirAll(filepath.Join(realRoot, "etc-real"), 0o755); err != nil {
+ t.Fatalf("mkdir real root: %v", err)
+ }
+ if err := os.Symlink(realRoot, linkRoot); err != nil {
+ t.Fatalf("create root symlink: %v", err)
+ }
+ if err := os.Symlink(filepath.Join(linkRoot, "etc-real"), filepath.Join(realRoot, "etc")); err != nil {
+ t.Fatalf("create nested symlink: %v", err)
+ }
+
+ resolved, err := resolvePathWithinRootFS(osFS{}, linkRoot, filepath.Join("etc", "missing", "file.txt"))
+ if err != nil {
+ t.Fatalf("resolvePathWithinRootFS returned error: %v", err)
+ }
+
+ want := filepath.Join(realRoot, "etc-real", "missing", "file.txt")
+ if resolved != want {
+ t.Fatalf("resolved path = %q, want %q", resolved, want)
+ }
+}
+
+func TestResolvePathRelativeToBaseWithinRootFS_RejectsEscapeViaSymlinkedBaseDir(t *testing.T) {
+ root := t.TempDir()
+ if err := os.Symlink(".", filepath.Join(root, "linkdir")); err != nil {
+ t.Fatalf("create base symlink: %v", err)
+ }
+
+ _, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), "../outside")
+ if err == nil || !strings.Contains(err.Error(), "escapes destination") {
+ t.Fatalf("expected escape rejection, got %v", err)
+ }
+}
+
+func TestResolvePathRelativeToBaseWithinRootFS_AllowsRelativeTargetAfterSafeBaseResolution(t *testing.T) {
+ root := t.TempDir()
+ if err := os.MkdirAll(filepath.Join(root, "subdir"), 0o755); err != nil {
+ t.Fatalf("mkdir subdir: %v", err)
+ }
+ if err := os.Symlink("subdir", filepath.Join(root, "linkdir")); err != nil {
+ t.Fatalf("create base symlink: %v", err)
+ }
+
+ resolved, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), "file.txt")
+ if err != nil {
+ t.Fatalf("resolvePathRelativeToBaseWithinRootFS returned error: %v", err)
+ }
+
+ want := filepath.Join(root, "subdir", "file.txt")
+ if resolved != want {
+ t.Fatalf("resolved path = %q, want %q", resolved, want)
+ }
+}
+
+func TestResolvePathRelativeToBaseWithinRootFS_RejectsBaseDirSymlinkOutsideRoot(t *testing.T) {
+ root := t.TempDir()
+ outside := t.TempDir()
+ if err := os.Symlink(outside, filepath.Join(root, "linkdir")); err != nil {
+ t.Fatalf("create base symlink: %v", err)
+ }
+
+ _, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), "file.txt")
+ if err == nil || !strings.Contains(err.Error(), "escapes destination") {
+ t.Fatalf("expected escape rejection, got %v", err)
+ }
+}
+
+func TestResolvePathRelativeToBaseWithinRootFS_PreservesAbsoluteCandidateBehavior(t *testing.T) {
+ root := t.TempDir()
+ if err := os.MkdirAll(filepath.Join(root, "subdir"), 0o755); err != nil {
+ t.Fatalf("mkdir subdir: %v", err)
+ }
+ if err := os.Symlink("subdir", filepath.Join(root, "linkdir")); err != nil {
+ t.Fatalf("create base symlink: %v", err)
+ }
+
+ candidate := filepath.Join(root, "subdir", "file.txt")
+ resolved, err := resolvePathRelativeToBaseWithinRootFS(osFS{}, root, filepath.Join(root, "linkdir"), candidate)
+ if err != nil {
+ t.Fatalf("resolvePathRelativeToBaseWithinRootFS returned error: %v", err)
+ }
+ if resolved != candidate {
+ t.Fatalf("resolved path = %q, want %q", resolved, candidate)
+ }
+}
diff --git a/internal/orchestrator/pbs_staged_apply.go b/internal/orchestrator/pbs_staged_apply.go
index 265910c..2bd91a1 100644
--- a/internal/orchestrator/pbs_staged_apply.go
+++ b/internal/orchestrator/pbs_staged_apply.go
@@ -359,6 +359,8 @@ type pbsDatastoreInventoryRestoreLite struct {
Name string `json:"name"`
Path string `json:"path"`
Comment string `json:"comment"`
+ Origin string `json:"origin"`
+ CLIName string `json:"cli_name"`
} `json:"datastores"`
}
@@ -387,7 +389,14 @@ func loadPBSDatastoreCfgFromInventory(stageRoot string) (string, string, error)
// Fallback: generate a minimal datastore.cfg from the inventory's datastore list.
var out strings.Builder
for _, ds := range report.Datastores {
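+ // Entries with origin "override" describe extra scan roots (e.g. from
+ // PBS_DATASTORE_PATH), not real datastore definitions; never materialize
+ // them in the generated datastore.cfg.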
+ if strings.TrimSpace(ds.Origin) == "override" {
+ continue
+ }
+
name := strings.TrimSpace(ds.Name)
+ if name == "" {
+ name = strings.TrimSpace(ds.CLIName)
+ }
path := strings.TrimSpace(ds.Path)
if name == "" || path == "" {
continue
diff --git a/internal/orchestrator/pbs_staged_apply_additional_test.go b/internal/orchestrator/pbs_staged_apply_additional_test.go
index 80cb253..5c0d1d2 100644
--- a/internal/orchestrator/pbs_staged_apply_additional_test.go
+++ b/internal/orchestrator/pbs_staged_apply_additional_test.go
@@ -280,6 +280,40 @@ func TestLoadPBSDatastoreCfgFromInventory_FallsBackToDatastoreList(t *testing.T)
}
}
+func TestLoadPBSDatastoreCfgFromInventory_IgnoresOverrideEntries(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ stageRoot := "/stage"
+ inventory := `{"datastores":[{"name":"DS1","path":"/mnt/ds1","comment":"primary","origin":"merged"},{"name":"DS1","path":"/mnt/scan-root","comment":"configured via PBS_DATASTORE_PATH","origin":"override"}]}`
+ if err := fakeFS.WriteFile(stageRoot+"/var/lib/proxsave-info/commands/pbs/pbs_datastore_inventory.json", []byte(inventory), 0o640); err != nil {
+ t.Fatalf("write inventory: %v", err)
+ }
+
+ content, src, err := loadPBSDatastoreCfgFromInventory(stageRoot)
+ if err != nil {
+ t.Fatalf("loadPBSDatastoreCfgFromInventory: %v", err)
+ }
+ if src != "pbs_datastore_inventory.json.datastores" {
+ t.Fatalf("src=%q", src)
+ }
+ if strings.Contains(content, "/mnt/scan-root") {
+ t.Fatalf("override entry should not be present in generated datastore.cfg: %s", content)
+ }
+
+ blocks, err := parsePBSDatastoreCfgBlocks(content)
+ if err != nil {
+ t.Fatalf("parsePBSDatastoreCfgBlocks: %v", err)
+ }
+ if len(blocks) != 1 || blocks[0].Name != "DS1" || blocks[0].Path != "/mnt/ds1" {
+ t.Fatalf("unexpected blocks: %+v", blocks)
+ }
+}
+
func TestLoadPBSDatastoreCfgFromInventory_PropagatesErrors(t *testing.T) {
origFS := restoreFS
t.Cleanup(func() { restoreFS = origFS })
diff --git a/internal/orchestrator/restore.go b/internal/orchestrator/restore.go
index 2f14d20..a390360 100644
--- a/internal/orchestrator/restore.go
+++ b/internal/orchestrator/restore.go
@@ -1448,6 +1448,10 @@ func runRestoreCommandStream(ctx context.Context, name string, stdin io.Reader,
}
func sanitizeRestoreEntryTarget(destRoot, entryName string) (string, string, error) {
+ return sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, entryName)
+}
+
+func sanitizeRestoreEntryTargetWithFS(fsys FS, destRoot, entryName string) (string, string, error) {
cleanDestRoot := filepath.Clean(destRoot)
if cleanDestRoot == "" {
cleanDestRoot = string(os.PathSeparator)
@@ -1490,23 +1494,13 @@ func sanitizeRestoreEntryTarget(destRoot, entryName string) (string, string, err
return "", "", fmt.Errorf("illegal path: %s", entryName)
}
- parentDir := filepath.Dir(absTarget)
- resolvedParentDir, err := filepath.EvalSymlinks(parentDir)
- if err != nil {
- // If the path doesn't exist yet, EvalSymlinks will fail; fallback to the cleaned path.
- resolvedParentDir = filepath.Clean(parentDir)
- }
- absResolvedParentDir, err := filepath.Abs(resolvedParentDir)
- if err != nil {
- return "", "", fmt.Errorf("resolve extraction parent directory: %w", err)
- }
-
- relParent, err := filepath.Rel(absDestRoot, absResolvedParentDir)
- if err != nil {
- return "", "", fmt.Errorf("illegal path: %s: %w", entryName, err)
- }
- if strings.HasPrefix(relParent, ".."+string(os.PathSeparator)) || relParent == ".." || filepath.IsAbs(relParent) {
- return "", "", fmt.Errorf("illegal path: %s", entryName)
+ if _, err := resolvePathWithinRootFS(fsys, absDestRoot, absTarget); err != nil {
+ if isPathSecurityError(err) {
+ return "", "", fmt.Errorf("illegal path: %s: %w", entryName, err)
+ }
+ if !isPathOperationalError(err) {
+ return "", "", fmt.Errorf("resolve extraction target: %w", err)
+ }
}
return absTarget, absDestRoot, nil
@@ -1537,7 +1531,7 @@ func shouldSkipProxmoxSystemRestore(relTarget string) (bool, string) {
// extractTarEntry extracts a single TAR entry, preserving all attributes including atime/ctime
func extractTarEntry(tarReader *tar.Reader, header *tar.Header, destRoot string, logger *logging.Logger) error {
- target, cleanDestRoot, err := sanitizeRestoreEntryTarget(destRoot, header.Name)
+ target, cleanDestRoot, err := sanitizeRestoreEntryTargetWithFS(restoreFS, destRoot, header.Name)
if err != nil {
return err
}
@@ -1647,26 +1641,8 @@ func extractRegularFile(tarReader *tar.Reader, target string, header *tar.Header
func extractSymlink(target string, header *tar.Header, destRoot string, logger *logging.Logger) error {
linkTarget := header.Linkname
- // Pre-validation: ensure the resolved target would stay within destRoot before creating the symlink
- relativeTarget, err := filepath.Rel(destRoot, target)
- if err != nil {
- return fmt.Errorf("determine relative path for symlink %s: %w", target, err)
- }
- if strings.HasPrefix(relativeTarget, ".."+string(os.PathSeparator)) || relativeTarget == ".." {
- return fmt.Errorf("sanitized symlink path escapes root: %s", target)
- }
-
- symlinkArchivePath := path.Clean(filepath.ToSlash(relativeTarget))
- symlinkArchiveDir := path.Dir(symlinkArchivePath)
- if symlinkArchiveDir == "." {
- symlinkArchiveDir = ""
- }
- potentialTarget := linkTarget
- if !filepath.IsAbs(linkTarget) {
- potentialTarget = filepath.FromSlash(path.Join(symlinkArchiveDir, linkTarget))
- }
-
- if _, err := resolveAndCheckPath(destRoot, potentialTarget); err != nil {
+ // Pre-validation: ensure the symlink target resolves within destRoot before creation.
+ if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), linkTarget); err != nil {
return fmt.Errorf("symlink target escapes root before creation: %s -> %s: %w", header.Name, linkTarget, err)
}
@@ -1685,32 +1661,9 @@ func extractSymlink(target string, header *tar.Header, destRoot string, logger *
return fmt.Errorf("read created symlink %s: %w", target, err)
}
- // Resolve the symlink target relative to the symlink's directory
- symlinkDir := filepath.Dir(target)
- resolvedTarget := actualTarget
- if !filepath.IsAbs(actualTarget) {
- resolvedTarget = filepath.Join(symlinkDir, actualTarget)
- }
-
- // Validate the resolved target stays within destRoot using absolute paths
- absDestRoot, err := filepath.Abs(destRoot)
- if err != nil {
- restoreFS.Remove(target)
- return fmt.Errorf("resolve destination root: %w", err)
- }
-
- absResolvedTarget, err := filepath.Abs(resolvedTarget)
- if err != nil {
- restoreFS.Remove(target)
- return fmt.Errorf("resolve symlink target: %w", err)
- }
-
- // Check if resolved target is within destRoot
- rel, err := filepath.Rel(absDestRoot, absResolvedTarget)
- if err != nil || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) || rel == ".." {
+ if _, err := resolvePathRelativeToBaseWithinRootFS(restoreFS, destRoot, filepath.Dir(target), actualTarget); err != nil {
restoreFS.Remove(target)
- return fmt.Errorf("symlink target escapes root after creation: %s -> %s (resolves to %s)",
- header.Name, linkTarget, absResolvedTarget)
+ return fmt.Errorf("symlink target escapes root after creation: %s -> %s: %w", header.Name, actualTarget, err)
}
// Set ownership (on the symlink itself, not the target)
@@ -1733,7 +1686,7 @@ func extractHardlink(target string, header *tar.Header, destRoot string) error {
}
// Validate the hard link target stays within extraction root
- if _, err := resolveAndCheckPath(destRoot, linkName); err != nil {
+ if _, err := resolvePathWithinRootFS(restoreFS, destRoot, linkName); err != nil {
return fmt.Errorf("hardlink target escapes root: %s -> %s: %w", header.Name, linkName, err)
}
diff --git a/internal/orchestrator/restore_decision.go b/internal/orchestrator/restore_decision.go
new file mode 100644
index 0000000..d18d0ec
--- /dev/null
+++ b/internal/orchestrator/restore_decision.go
@@ -0,0 +1,301 @@
+package orchestrator
+
+import (
+ "archive/tar"
+ "bufio"
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "path"
+ "strings"
+
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+// RestoreDecisionSource describes where trusted restore facts were derived from.
+type RestoreDecisionSource string
+
+const (
+ RestoreDecisionSourceUnknown RestoreDecisionSource = "unknown"
+ RestoreDecisionSourceInternalMetadata RestoreDecisionSource = "internal_metadata"
+ RestoreDecisionSourceCategories RestoreDecisionSource = "categories"
+ RestoreDecisionSourceAmbiguous RestoreDecisionSource = "ambiguous"
+)
+
+// RestoreDecisionInfo contains the archive-derived facts used for restore decisions.
+type RestoreDecisionInfo struct {
+ BackupType SystemType
+ ClusterPayload bool
+ BackupHostname string
+ Source RestoreDecisionSource
+}
+
+type restoreDecisionMetadata struct {
+ BackupType SystemType
+ ClusterMode string
+ Hostname string
+}
+
+type restoreArchiveInspection struct {
+ AvailableCategories []Category
+ Decision *RestoreDecisionInfo
+}
+
+const (
+ restoreDecisionMetadataPath = "var/lib/proxsave-info/backup_metadata.txt"
+ restoreDecisionMetadataMaxBytes = 8 * 1024
+)
+
+// AnalyzeRestoreArchive inspects the archive once and derives trusted restore facts
+// from archive contents plus internal backup metadata when present.
+func AnalyzeRestoreArchive(archivePath string, logger *logging.Logger) (categories []Category, decision *RestoreDecisionInfo, err error) {
+ if logger == nil {
+ logger = logging.GetDefaultLogger()
+ }
+
+ done := logging.DebugStart(logger, "analyze restore archive", "archive=%s", archivePath)
+ defer func() { done(err) }()
+ logger.Info("Analyzing backup contents...")
+
+ inspection, err := inspectRestoreArchiveContents(archivePath, logger)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ for _, cat := range inspection.AvailableCategories {
+ logger.Debug("Category available: %s (%s)", cat.ID, cat.Name)
+ }
+ logger.Info("Detected %d available categories", len(inspection.AvailableCategories))
+ if inspection.Decision != nil {
+ logger.Debug(
+ "Restore decision facts: backup_type=%s cluster_payload=%v hostname=%q source=%s",
+ inspection.Decision.BackupType,
+ inspection.Decision.ClusterPayload,
+ inspection.Decision.BackupHostname,
+ inspection.Decision.Source,
+ )
+ }
+
+ return inspection.AvailableCategories, inspection.Decision, nil
+}
+
+func inspectRestoreArchiveContents(archivePath string, logger *logging.Logger) (*restoreArchiveInspection, error) {
+ file, err := restoreFS.Open(archivePath)
+ if err != nil {
+ return nil, fmt.Errorf("open archive: %w", err)
+ }
+ defer file.Close()
+
+ reader, err := createDecompressionReader(context.Background(), file, archivePath)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if closer, ok := reader.(interface{ Close() error }); ok {
+ closer.Close()
+ }
+ }()
+
+ tarReader := tar.NewReader(reader)
+ archivePaths, metadata, metadataErr, err := collectRestoreArchiveFacts(tarReader)
+ if err != nil {
+ return nil, fmt.Errorf("inspect archive: %w", err)
+ }
+ if metadataErr != nil {
+ logger.Warning("Could not parse internal backup metadata: %v", metadataErr)
+ }
+
+ logger.Debug("Found %d entries in archive", len(archivePaths))
+ availableCategories := AnalyzeArchivePaths(archivePaths, GetAllCategories())
+
+ decision := buildRestoreDecisionInfo(metadata, availableCategories, logger)
+ return &restoreArchiveInspection{
+ AvailableCategories: availableCategories,
+ Decision: decision,
+ }, nil
+}
+
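+// collectRestoreArchiveFacts scans the tar stream once, returning every entry
+// name plus the parsed internal metadata. The third return value carries
+// non-fatal metadata read/parse problems; the final error is fatal and means
+// the archive itself could not be read.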
+func collectRestoreArchiveFacts(tarReader *tar.Reader) ([]string, *restoreDecisionMetadata, error, error) {
+ var (
+ archivePaths []string
+ metadata *restoreDecisionMetadata
+ metadataErr error
+ )
+
+ for {
+ header, err := tarReader.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ archivePaths = append(archivePaths, header.Name)
+ if metadata != nil || header.FileInfo().IsDir() {
+ continue
+ }
+ if !isRestoreDecisionMetadataEntry(header.Name) {
+ continue
+ }
+
+ data, readErr := readRestoreDecisionMetadata(tarReader, header)
+ if readErr != nil {
+ metadataErr = readErr
+ continue
+ }
+ parsed, parseErr := parseRestoreDecisionMetadata(data)
+ if parseErr != nil {
+ metadataErr = parseErr
+ continue
+ }
+ metadata = parsed
+ }
+
+ return archivePaths, metadata, metadataErr, nil
+}
+
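+// readRestoreDecisionMetadata reads the metadata entry, rejecting non-regular
+// files and capping the read at restoreDecisionMetadataMaxBytes to avoid
+// unbounded allocations from a hostile archive.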
+func readRestoreDecisionMetadata(tarReader *tar.Reader, header *tar.Header) ([]byte, error) {
+ if header == nil {
+ return nil, fmt.Errorf("restore metadata entry is missing a tar header")
+ }
+ if header.Typeflag != tar.TypeReg && header.Typeflag != tar.TypeRegA {
+ return nil, fmt.Errorf("archive entry %s is not a regular file", header.Name)
+ }
+
+ limited := io.LimitReader(tarReader, restoreDecisionMetadataMaxBytes+1)
+ data, err := io.ReadAll(limited)
+ if err != nil {
+ return nil, err
+ }
+ if int64(len(data)) > restoreDecisionMetadataMaxBytes {
+ size := header.Size
+ if size <= 0 {
+ size = int64(len(data))
+ }
+ return nil, fmt.Errorf("archive entry %s too large (%d bytes)", header.Name, size)
+ }
+ return data, nil
+}
+
+func isRestoreDecisionMetadataEntry(entryName string) bool {
+ return normalizeRestoreEntryPath(entryName) == restoreDecisionMetadataPath
+}
+
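+// normalizeRestoreEntryPath canonicalizes a tar entry name (backslashes to
+// slashes, leading "./" and "/" stripped) for comparison against
+// restoreDecisionMetadataPath.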
+func normalizeRestoreEntryPath(entryName string) string {
+ clean := strings.TrimSpace(strings.ReplaceAll(entryName, "\\", "/"))
+ if clean == "" {
+ return ""
+ }
+
+ clean = path.Clean(clean)
+ clean = strings.TrimPrefix(clean, "./")
+ clean = strings.TrimPrefix(clean, "/")
+ if clean == "." {
+ return ""
+ }
+ return clean
+}
+
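+// parseRestoreDecisionMetadata parses KEY=VALUE lines, ignoring blanks and
+// "#" comments; unknown keys are skipped.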
+func parseRestoreDecisionMetadata(data []byte) (*restoreDecisionMetadata, error) {
+ meta := &restoreDecisionMetadata{}
+ scanner := bufio.NewScanner(bytes.NewReader(data))
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" || strings.HasPrefix(line, "#") {
+ continue
+ }
+
+ parts := strings.SplitN(line, "=", 2)
+ if len(parts) != 2 {
+ continue
+ }
+
+ key := strings.TrimSpace(parts[0])
+ value := strings.TrimSpace(parts[1])
+ switch key {
+ case "BACKUP_TYPE":
+ meta.BackupType = parseSystemTypeString(value)
+ case "PVE_CLUSTER_MODE", "CLUSTER_MODE":
+ meta.ClusterMode = value
+ case "HOSTNAME":
+ meta.Hostname = value
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
+ return meta, nil
+}
+
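+// buildRestoreDecisionInfo merges metadata- and category-derived facts. The
+// cluster flag always follows the archive payload; when the two sources
+// disagree on backup type, the archive-derived type wins.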
+func buildRestoreDecisionInfo(metadata *restoreDecisionMetadata, categories []Category, logger *logging.Logger) *RestoreDecisionInfo {
+ if logger == nil {
+ logger = logging.GetDefaultLogger()
+ }
+
+ info := &RestoreDecisionInfo{
+ BackupType: SystemTypeUnknown,
+ ClusterPayload: hasCategoryID(categories, "pve_cluster"),
+ Source: RestoreDecisionSourceUnknown,
+ }
+
+ if metadata != nil {
+ info.BackupHostname = strings.TrimSpace(metadata.Hostname)
+ }
+
+ categoryType, ambiguousType := detectBackupTypeFromCategories(categories)
+ switch {
+ case ambiguousType:
+ logger.Warning("Archive contains both PVE and PBS-specific payloads; treating backup type as unknown for compatibility checks")
+ info.Source = RestoreDecisionSourceAmbiguous
+ case metadata != nil && metadata.BackupType != SystemTypeUnknown && categoryType == SystemTypeUnknown:
+ info.BackupType = metadata.BackupType
+ info.Source = RestoreDecisionSourceInternalMetadata
+ case metadata != nil && metadata.BackupType != SystemTypeUnknown && categoryType != SystemTypeUnknown && metadata.BackupType != categoryType:
+ logger.Warning("Internal backup metadata and archive payload disagree on backup type; using archive-derived type %s", strings.ToUpper(string(categoryType)))
+ info.BackupType = categoryType
+ info.Source = RestoreDecisionSourceCategories
+ case categoryType != SystemTypeUnknown:
+ info.BackupType = categoryType
+ info.Source = RestoreDecisionSourceCategories
+ case metadata != nil && metadata.BackupType != SystemTypeUnknown:
+ info.BackupType = metadata.BackupType
+ info.Source = RestoreDecisionSourceInternalMetadata
+ }
+
+ if metadata != nil {
+ metadataCluster := strings.EqualFold(strings.TrimSpace(metadata.ClusterMode), "cluster")
+ switch {
+ case metadataCluster && !info.ClusterPayload:
+ logger.Warning("Internal backup metadata reports cluster mode, but no pve_cluster payload was found; guarded cluster restore remains disabled")
+ case !metadataCluster && info.ClusterPayload:
+ logger.Warning("Cluster payload detected in archive despite metadata reporting non-cluster backup; guarded cluster restore remains enabled")
+ }
+ }
+
+ return info
+}
+
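+// detectBackupTypeFromCategories infers the backup type from category
+// payloads; the boolean reports ambiguity when both PVE and PBS payloads are
+// present.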
+func detectBackupTypeFromCategories(categories []Category) (SystemType, bool) {
+ var hasPVE, hasPBS bool
+ for _, cat := range categories {
+ switch cat.Type {
+ case CategoryTypePVE:
+ hasPVE = true
+ case CategoryTypePBS:
+ hasPBS = true
+ }
+ }
+
+ switch {
+ case hasPVE && hasPBS:
+ return SystemTypeUnknown, true
+ case hasPVE:
+ return SystemTypePVE, false
+ case hasPBS:
+ return SystemTypePBS, false
+ default:
+ return SystemTypeUnknown, false
+ }
+}
diff --git a/internal/orchestrator/restore_decision_test.go b/internal/orchestrator/restore_decision_test.go
new file mode 100644
index 0000000..77a4f29
--- /dev/null
+++ b/internal/orchestrator/restore_decision_test.go
@@ -0,0 +1,156 @@
+package orchestrator
+
+import (
+ "archive/tar"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/tis24dev/proxsave/internal/logging"
+)
+
+func TestAnalyzeRestoreArchive_UsesInternalMetadataWhenCategoriesAreCommonOnly(t *testing.T) {
+ origRestoreFS := restoreFS
+ t.Cleanup(func() { restoreFS = origRestoreFS })
+ restoreFS = osFS{}
+
+ archivePath := filepath.Join(t.TempDir(), "backup.tar")
+ if err := writeTarFile(archivePath, map[string]string{
+ "etc/hosts": "127.0.0.1 localhost\n",
+ "var/lib/proxsave-info/backup_metadata.txt": "# ProxSave Metadata\nBACKUP_TYPE=pbs\nHOSTNAME=pbs-node\nPVE_CLUSTER_MODE=cluster\n",
+ }); err != nil {
+ t.Fatalf("writeTarFile: %v", err)
+ }
+
+ logger := logging.New(logging.GetDefaultLogger().GetLevel(), false)
+ categories, decision, err := AnalyzeRestoreArchive(archivePath, logger)
+ if err != nil {
+ t.Fatalf("AnalyzeRestoreArchive() error: %v", err)
+ }
+ if backupType, ambiguous := detectBackupTypeFromCategories(categories); backupType != SystemTypeUnknown || ambiguous {
+ t.Fatalf("detectBackupTypeFromCategories() = (%s, %v); want (%s, false)", backupType, ambiguous, SystemTypeUnknown)
+ }
+ if decision == nil {
+ t.Fatalf("decision info is nil")
+ }
+ if decision.BackupType != SystemTypePBS {
+ t.Fatalf("BackupType=%s; want %s", decision.BackupType, SystemTypePBS)
+ }
+ if decision.Source != RestoreDecisionSourceInternalMetadata {
+ t.Fatalf("Source=%s; want %s", decision.Source, RestoreDecisionSourceInternalMetadata)
+ }
+ if decision.BackupHostname != "pbs-node" {
+ t.Fatalf("BackupHostname=%q; want %q", decision.BackupHostname, "pbs-node")
+ }
+ if decision.ClusterPayload {
+ t.Fatalf("ClusterPayload should stay false without pve_cluster payload")
+ }
+}
+
+func TestAnalyzeRestoreArchive_ClusterPayloadUsesArchiveContents(t *testing.T) {
+ origRestoreFS := restoreFS
+ t.Cleanup(func() { restoreFS = origRestoreFS })
+ restoreFS = osFS{}
+
+ archivePath := filepath.Join(t.TempDir(), "backup.tar")
+ if err := writeTarFile(archivePath, map[string]string{
+ "var/lib/pve-cluster/config.db": "db\n",
+ "var/lib/proxsave-info/backup_metadata.txt": "BACKUP_TYPE=pve\nPVE_CLUSTER_MODE=standalone\nHOSTNAME=node1\n",
+ }); err != nil {
+ t.Fatalf("writeTarFile: %v", err)
+ }
+
+ logger := logging.New(logging.GetDefaultLogger().GetLevel(), false)
+ _, decision, err := AnalyzeRestoreArchive(archivePath, logger)
+ if err != nil {
+ t.Fatalf("AnalyzeRestoreArchive() error: %v", err)
+ }
+ if decision == nil {
+ t.Fatalf("decision info is nil")
+ }
+ if !decision.ClusterPayload {
+ t.Fatalf("ClusterPayload should be true when pve_cluster payload exists")
+ }
+ if decision.BackupType != SystemTypePVE {
+ t.Fatalf("BackupType=%s; want %s", decision.BackupType, SystemTypePVE)
+ }
+}
+
+func TestCollectRestoreArchiveFacts_RejectsOversizedMetadata(t *testing.T) {
+ archivePath := filepath.Join(t.TempDir(), "backup.tar")
+ oversized := "BACKUP_TYPE=pbs\nHOSTNAME=pbs-node\n" + strings.Repeat("A", restoreDecisionMetadataMaxBytes)
+ if err := writeTarFile(archivePath, map[string]string{
+ "var/lib/proxsave-info/backup_metadata.txt": oversized,
+ "var/lib/pve-cluster/config.db": "db\n",
+ }); err != nil {
+ t.Fatalf("writeTarFile: %v", err)
+ }
+
+ file, err := os.Open(archivePath)
+ if err != nil {
+ t.Fatalf("os.Open: %v", err)
+ }
+ defer file.Close()
+
+ archivePaths, metadata, metadataErr, err := collectRestoreArchiveFacts(tar.NewReader(file))
+ if err != nil {
+ t.Fatalf("collectRestoreArchiveFacts() error: %v", err)
+ }
+ if metadata != nil {
+ t.Fatalf("metadata = %#v; want nil for oversized entry", metadata)
+ }
+ if metadataErr == nil {
+ t.Fatalf("metadataErr = nil; want oversize error")
+ }
+ if !strings.Contains(metadataErr.Error(), "too large") {
+ t.Fatalf("metadataErr = %v; want oversize error", metadataErr)
+ }
+
+ foundMeta := false
+ foundCluster := false
+ for _, entryName := range archivePaths {
+ if entryName == restoreDecisionMetadataPath {
+ foundMeta = true
+ }
+ if entryName == "var/lib/pve-cluster/config.db" {
+ foundCluster = true
+ }
+ }
+ if !foundMeta || !foundCluster {
+ t.Fatalf("archivePaths = %#v; want metadata and cluster entries present", archivePaths)
+ }
+}
+
+func TestAnalyzeRestoreArchive_IgnoresOversizedInternalMetadata(t *testing.T) {
+ origRestoreFS := restoreFS
+ t.Cleanup(func() { restoreFS = origRestoreFS })
+ restoreFS = osFS{}
+
+ archivePath := filepath.Join(t.TempDir(), "backup.tar")
+ oversized := "BACKUP_TYPE=pbs\nHOSTNAME=pbs-node\n" + strings.Repeat("A", restoreDecisionMetadataMaxBytes)
+ if err := writeTarFile(archivePath, map[string]string{
+ "etc/hosts": "127.0.0.1 localhost\n",
+ "var/lib/proxsave-info/backup_metadata.txt": oversized,
+ }); err != nil {
+ t.Fatalf("writeTarFile: %v", err)
+ }
+
+ logger := logging.New(logging.GetDefaultLogger().GetLevel(), false)
+ _, decision, err := AnalyzeRestoreArchive(archivePath, logger)
+ if err != nil {
+ t.Fatalf("AnalyzeRestoreArchive() error: %v", err)
+ }
+ if decision == nil {
+ t.Fatalf("decision info is nil")
+ }
+ if decision.BackupType != SystemTypeUnknown {
+ t.Fatalf("BackupType=%s; want %s when metadata is oversized", decision.BackupType, SystemTypeUnknown)
+ }
+ if decision.Source != RestoreDecisionSourceUnknown {
+ t.Fatalf("Source=%s; want %s when metadata is oversized", decision.Source, RestoreDecisionSourceUnknown)
+ }
+ if decision.BackupHostname != "" {
+ t.Fatalf("BackupHostname=%q; want empty string when metadata is oversized", decision.BackupHostname)
+ }
+}
diff --git a/internal/orchestrator/restore_errors_test.go b/internal/orchestrator/restore_errors_test.go
index cad33a8..0faf024 100644
--- a/internal/orchestrator/restore_errors_test.go
+++ b/internal/orchestrator/restore_errors_test.go
@@ -798,10 +798,11 @@ type ErrorInjectingFS struct {
linkErr error
}
-func (f *ErrorInjectingFS) Stat(path string) (os.FileInfo, error) { return f.base.Stat(path) }
-func (f *ErrorInjectingFS) ReadFile(path string) ([]byte, error) { return f.base.ReadFile(path) }
-func (f *ErrorInjectingFS) Open(path string) (*os.File, error) { return f.base.Open(path) }
-func (f *ErrorInjectingFS) Create(name string) (*os.File, error) { return f.base.Create(name) }
+func (f *ErrorInjectingFS) Stat(path string) (os.FileInfo, error) { return f.base.Stat(path) }
+func (f *ErrorInjectingFS) Lstat(path string) (os.FileInfo, error) { return f.base.Lstat(path) }
+func (f *ErrorInjectingFS) ReadFile(path string) ([]byte, error) { return f.base.ReadFile(path) }
+func (f *ErrorInjectingFS) Open(path string) (*os.File, error) { return f.base.Open(path) }
+func (f *ErrorInjectingFS) Create(name string) (*os.File, error) { return f.base.Create(name) }
func (f *ErrorInjectingFS) WriteFile(path string, data []byte, perm os.FileMode) error {
return f.base.WriteFile(path, data, perm)
}
diff --git a/internal/orchestrator/restore_firewall.go b/internal/orchestrator/restore_firewall.go
index eeb5daa..e14a599 100644
--- a/internal/orchestrator/restore_firewall.go
+++ b/internal/orchestrator/restore_firewall.go
@@ -673,12 +673,16 @@ func syncDirExact(srcDir, destDir string) ([]string, error) {
if err != nil {
return fmt.Errorf("readlink %s: %w", src, err)
}
+ validatedTarget, err := validateOverlaySymlinkTargetWithinRoot(destDir, dest, target)
+ if err != nil {
+ return fmt.Errorf("unsafe symlink target %s -> %s: %w", dest, target, err)
+ }
if err := ensureDirExistsWithInheritedMeta(filepath.Dir(dest)); err != nil {
return fmt.Errorf("ensure %s: %w", filepath.Dir(dest), err)
}
_ = restoreFS.Remove(dest)
- if err := restoreFS.Symlink(target, dest); err != nil {
- return fmt.Errorf("symlink %s -> %s: %w", dest, target, err)
+ if err := restoreFS.Symlink(validatedTarget, dest); err != nil {
+ return fmt.Errorf("symlink %s -> %s: %w", dest, validatedTarget, err)
}
applied = append(applied, dest)
continue
diff --git a/internal/orchestrator/restore_firewall_additional_test.go b/internal/orchestrator/restore_firewall_additional_test.go
index 525c269..cce281e 100644
--- a/internal/orchestrator/restore_firewall_additional_test.go
+++ b/internal/orchestrator/restore_firewall_additional_test.go
@@ -372,14 +372,14 @@ func TestSyncDirExact_CopiesSymlinks(t *testing.T) {
origFS := restoreFS
t.Cleanup(func() { restoreFS = origFS })
- fakeFS := NewFakeFS()
+ fakeFS := newPreservingSymlinkFS()
t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
restoreFS = fakeFS
- if err := fakeFS.AddFile("/stage/target", []byte("x")); err != nil {
- t.Fatalf("add target: %v", err)
+ if err := fakeFS.AddDir("/stage"); err != nil {
+ t.Fatalf("add stage dir: %v", err)
}
- if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil {
+ if err := fakeFS.Symlink("/dest/target", "/stage/link"); err != nil {
t.Fatalf("add symlink: %v", err)
}
@@ -392,8 +392,8 @@ func TestSyncDirExact_CopiesSymlinks(t *testing.T) {
if err != nil {
t.Fatalf("read dest symlink: %v", err)
}
- if strings.TrimSpace(destTarget) == "" {
- t.Fatalf("expected non-empty symlink target")
+ if destTarget != "target" {
+ t.Fatalf("dest symlink target=%q want %q", destTarget, "target")
}
found := false
for _, p := range applied {
@@ -407,6 +407,46 @@ func TestSyncDirExact_CopiesSymlinks(t *testing.T) {
}
}
+func TestSyncDirExact_RejectsEscapingSymlinkTargets(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := newPreservingSymlinkFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ if err := fakeFS.AddDir("/stage"); err != nil {
+ t.Fatalf("add stage dir: %v", err)
+ }
+ if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil {
+ t.Fatalf("add symlink: %v", err)
+ }
+
+ if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "unsafe symlink target") {
+ t.Fatalf("expected unsafe symlink target error, got %v", err)
+ }
+}
+
+func TestSyncDirExact_RejectsRelativeSymlinkTargetsThatEscapeDestinationRoot(t *testing.T) {
+ origFS := restoreFS
+ t.Cleanup(func() { restoreFS = origFS })
+
+ fakeFS := newPreservingSymlinkFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+
+ if err := fakeFS.AddDir("/stage/sub"); err != nil {
+ t.Fatalf("add stage dir: %v", err)
+ }
+ if err := fakeFS.Symlink("../../outside", "/stage/sub/link"); err != nil {
+ t.Fatalf("add symlink: %v", err)
+ }
+
+ if _, err := syncDirExact("/stage", "/dest"); err == nil || !strings.Contains(err.Error(), "unsafe symlink target") {
+ t.Fatalf("expected unsafe symlink target error, got %v", err)
+ }
+}
+
func TestSelectStageHostFirewall_ErrorsOnReadDirFailure(t *testing.T) {
origFS := restoreFS
t.Cleanup(func() { restoreFS = origFS })
@@ -1975,13 +2015,13 @@ func TestSyncDirExact_AdditionalEdgeCases(t *testing.T) {
})
t.Run("symlink parent ensure error bubbles", func(t *testing.T) {
- fakeFS := NewFakeFS()
+ fakeFS := newPreservingSymlinkFS()
t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
if err := fakeFS.AddDir("/stage"); err != nil {
t.Fatalf("add stage dir: %v", err)
}
- if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil {
+ if err := fakeFS.Symlink("target", "/stage/link"); err != nil {
t.Fatalf("add stage symlink: %v", err)
}
if err := fakeFS.AddDir("/dest"); err != nil {
@@ -1995,13 +2035,13 @@ func TestSyncDirExact_AdditionalEdgeCases(t *testing.T) {
})
t.Run("symlink creation error bubbles", func(t *testing.T) {
- fakeFS := NewFakeFS()
+ fakeFS := newPreservingSymlinkFS()
t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
if err := fakeFS.AddDir("/stage"); err != nil {
t.Fatalf("add stage dir: %v", err)
}
- if err := fakeFS.Symlink("/stage/target", "/stage/link"); err != nil {
+ if err := fakeFS.Symlink("target", "/stage/link"); err != nil {
t.Fatalf("add stage symlink: %v", err)
}
diff --git a/internal/orchestrator/restore_plan.go b/internal/orchestrator/restore_plan.go
index 6d54716..6eba614 100644
--- a/internal/orchestrator/restore_plan.go
+++ b/internal/orchestrator/restore_plan.go
@@ -1,11 +1,5 @@
package orchestrator
-import (
- "strings"
-
- "github.com/tis24dev/proxsave/internal/backup"
-)
-
// RestorePlan contains a pure, side-effect-free description of a restore run.
type RestorePlan struct {
Mode RestoreMode
@@ -22,7 +16,7 @@ type RestorePlan struct {
// PlanRestore computes the restore plan without performing any I/O or prompts.
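+// clusterBackup reports whether the backup payload actually contains cluster
+// data; callers now derive it from archive analysis rather than the manifest.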
func PlanRestore(
- manifest *backup.Manifest,
+ clusterBackup bool,
selectedCategories []Category,
systemType SystemType,
mode RestoreMode,
@@ -35,7 +29,7 @@ func PlanRestore(
NormalCategories: normal,
StagedCategories: staged,
ExportCategories: export,
- ClusterBackup: manifest != nil && strings.EqualFold(strings.TrimSpace(manifest.ClusterMode), "cluster"),
+ ClusterBackup: clusterBackup,
}
plan.NeedsClusterRestore = systemType == SystemTypePVE && hasCategoryID(normal, "pve_cluster")
diff --git a/internal/orchestrator/restore_plan_test.go b/internal/orchestrator/restore_plan_test.go
index c38b562..f27ba99 100644
--- a/internal/orchestrator/restore_plan_test.go
+++ b/internal/orchestrator/restore_plan_test.go
@@ -2,16 +2,13 @@ package orchestrator
import (
"testing"
-
- "github.com/tis24dev/proxsave/internal/backup"
)
func TestPlanRestoreClusterSafeToggle(t *testing.T) {
clusterCat := Category{ID: "pve_cluster", Type: CategoryTypePVE}
storageCat := Category{ID: "storage_pve", Type: CategoryTypePVE}
- manifest := &backup.Manifest{ClusterMode: "cluster"}
- plan := PlanRestore(manifest, []Category{clusterCat, storageCat}, SystemTypePVE, RestoreModeCustom)
+ plan := PlanRestore(true, []Category{clusterCat, storageCat}, SystemTypePVE, RestoreModeCustom)
if !plan.NeedsClusterRestore {
t.Fatalf("expected NeedsClusterRestore true")
@@ -50,7 +47,7 @@ func TestPlanRestorePBSCategories(t *testing.T) {
pbsCat := Category{ID: "pbs_config", Type: CategoryTypePBS, ExportOnly: true}
normalCat := Category{ID: "network", Type: CategoryTypeCommon}
- plan := PlanRestore(nil, []Category{pbsCat, normalCat}, SystemTypePBS, RestoreModeCustom)
+ plan := PlanRestore(false, []Category{pbsCat, normalCat}, SystemTypePBS, RestoreModeCustom)
if len(plan.ExportCategories) != 1 || !hasCategoryID(plan.ExportCategories, "pbs_config") {
t.Fatalf("expected pbs_config to be exported, got %+v", plan.ExportCategories)
}
@@ -66,7 +63,7 @@ func TestPlanRestoreKeepsExportCategoriesFromFullSelection(t *testing.T) {
exportCat := Category{ID: "pve_config_export", ExportOnly: true}
normalCat := Category{ID: "network"}
- plan := PlanRestore(nil, []Category{normalCat, exportCat}, SystemTypePVE, RestoreModeFull)
+ plan := PlanRestore(false, []Category{normalCat, exportCat}, SystemTypePVE, RestoreModeFull)
if len(plan.StagedCategories) != 1 || plan.StagedCategories[0].ID != "network" {
t.Fatalf("expected staged categories to keep network, got %+v", plan.StagedCategories)
}
diff --git a/internal/orchestrator/restore_test.go b/internal/orchestrator/restore_test.go
index 80a1f10..d65e449 100644
--- a/internal/orchestrator/restore_test.go
+++ b/internal/orchestrator/restore_test.go
@@ -334,6 +334,155 @@ func TestExtractSymlink_SecurityValidation(t *testing.T) {
}
}
+func TestExtractTarEntry_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ orig := restoreFS
+ restoreFS = osFS{}
+ t.Cleanup(func() { restoreFS = orig })
+
+ destRoot := t.TempDir()
+ outside := t.TempDir()
+ if err := os.Symlink(outside, filepath.Join(destRoot, "escape-link")); err != nil {
+ t.Fatalf("create escape symlink: %v", err)
+ }
+
+ header := &tar.Header{
+ Name: "escape-link/missing",
+ Typeflag: tar.TypeDir,
+ Mode: 0o755,
+ }
+
+ var tr *tar.Reader
+ err := extractTarEntry(tr, header, destRoot, logger)
+ if err == nil || !strings.Contains(err.Error(), "illegal path") {
+ t.Fatalf("expected illegal path error, got %v", err)
+ }
+
+ if _, err := os.Stat(filepath.Join(outside, "missing")); !os.IsNotExist(err) {
+ t.Fatalf("outside path should not be created, got err=%v", err)
+ }
+}
+
+func TestSanitizeRestoreEntryTargetWithFS_AllowsOperationalResolverErrorsWithinRoot(t *testing.T) {
+ fsys := NewFakeFS()
+ destRoot := filepath.Join(string(os.PathSeparator), "restore-root")
+ if err := fsys.AddDir(destRoot); err != nil {
+ t.Fatalf("add root dir: %v", err)
+ }
+ if err := fsys.AddDir(filepath.Join(destRoot, "subdir")); err != nil {
+ t.Fatalf("add subdir: %v", err)
+ }
+ fsys.StatErrors[filepath.Clean(filepath.Join(destRoot, "subdir", "file.txt"))] = os.ErrPermission
+
+ target, cleanRoot, err := sanitizeRestoreEntryTargetWithFS(fsys, destRoot, filepath.Join("subdir", "file.txt"))
+ if err != nil {
+ t.Fatalf("sanitizeRestoreEntryTargetWithFS returned error: %v", err)
+ }
+
+ wantTarget := filepath.Join(destRoot, "subdir", "file.txt")
+ if target != wantTarget {
+ t.Fatalf("target = %q, want %q", target, wantTarget)
+ }
+ if cleanRoot != filepath.Clean(destRoot) {
+ t.Fatalf("cleanRoot = %q, want %q", cleanRoot, filepath.Clean(destRoot))
+ }
+}
+
+func TestExtractSymlink_RejectsBrokenIntermediateSymlinkEscape(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ orig := restoreFS
+ restoreFS = osFS{}
+ t.Cleanup(func() { restoreFS = orig })
+
+ destRoot := t.TempDir()
+ outside := t.TempDir()
+ if err := os.Symlink(outside, filepath.Join(destRoot, "escape-link")); err != nil {
+ t.Fatalf("create escape symlink: %v", err)
+ }
+
+ header := &tar.Header{
+ Name: "link",
+ Typeflag: tar.TypeSymlink,
+ Linkname: "escape-link/missing/file.txt",
+ }
+ target := filepath.Join(destRoot, header.Name)
+
+ err := extractSymlink(target, header, destRoot, logger)
+ if err == nil || !strings.Contains(err.Error(), "escapes root") {
+ t.Fatalf("expected escapes root error, got %v", err)
+ }
+
+ if _, err := os.Lstat(target); !os.IsNotExist(err) {
+ t.Fatalf("symlink should not be created, got err=%v", err)
+ }
+}
+
+func TestExtractSymlink_RejectsEscapeWhenParentPathIsSymlink(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ orig := restoreFS
+ restoreFS = osFS{}
+ t.Cleanup(func() { restoreFS = orig })
+
+ destRoot := t.TempDir()
+ if err := os.Symlink(".", filepath.Join(destRoot, "linkdir")); err != nil {
+ t.Fatalf("create parent symlink: %v", err)
+ }
+
+ header := &tar.Header{
+ Name: "linkdir/escape",
+ Typeflag: tar.TypeSymlink,
+ Linkname: "../outside",
+ }
+ target := filepath.Join(destRoot, header.Name)
+
+ err := extractSymlink(target, header, destRoot, logger)
+ if err == nil || !strings.Contains(err.Error(), "escapes root") {
+ t.Fatalf("expected escapes root error, got %v", err)
+ }
+
+ if _, err := os.Lstat(filepath.Join(destRoot, "escape")); !os.IsNotExist(err) {
+ t.Fatalf("escaping symlink should not be created, got err=%v", err)
+ }
+}
+
+func TestExtractSymlink_AllowsSafeTargetWhenParentPathIsSymlink(t *testing.T) {
+ logger := logging.New(types.LogLevelDebug, false)
+ orig := restoreFS
+ restoreFS = osFS{}
+ t.Cleanup(func() { restoreFS = orig })
+
+ destRoot := t.TempDir()
+ if err := os.MkdirAll(filepath.Join(destRoot, "subdir"), 0o755); err != nil {
+ t.Fatalf("mkdir subdir: %v", err)
+ }
+ if err := os.Symlink("subdir", filepath.Join(destRoot, "linkdir")); err != nil {
+ t.Fatalf("create parent symlink: %v", err)
+ }
+
+ header := &tar.Header{
+ Name: "linkdir/ok",
+ Typeflag: tar.TypeSymlink,
+ Linkname: "file.txt",
+ }
+ target := filepath.Join(destRoot, header.Name)
+
+ if err := extractSymlink(target, header, destRoot, logger); err != nil {
+ t.Fatalf("safe symlink should succeed: %v", err)
+ }
+
+ linkTarget, err := os.Readlink(target)
+ if err != nil {
+ t.Fatalf("safe symlink should exist: %v", err)
+ }
+ if linkTarget != "file.txt" {
+ t.Fatalf("symlink target = %q, want %q", linkTarget, "file.txt")
+ }
+
+ if _, err := os.Lstat(filepath.Join(destRoot, "subdir", "ok")); err != nil {
+ t.Fatalf("expected symlink at resolved parent path: %v", err)
+ }
+}
+
func TestExtractTarEntry_DoesNotFollowExistingSymlinkTargetPath(t *testing.T) {
logger := logging.New(types.LogLevelDebug, false)
orig := restoreFS
diff --git a/internal/orchestrator/restore_workflow_decision_test.go b/internal/orchestrator/restore_workflow_decision_test.go
new file mode 100644
index 0000000..9f07c80
--- /dev/null
+++ b/internal/orchestrator/restore_workflow_decision_test.go
@@ -0,0 +1,315 @@
+package orchestrator
+
+import (
+ "context"
+ "errors"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/tis24dev/proxsave/internal/backup"
+ "github.com/tis24dev/proxsave/internal/config"
+ "github.com/tis24dev/proxsave/internal/logging"
+ "github.com/tis24dev/proxsave/internal/types"
+)
+
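+// stubPreparedRestoreBundle returns a prepareRestoreBundleFunc replacement that
+// bypasses candidate selection and hands back a fixed archive/manifest pair.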
+func stubPreparedRestoreBundle(archivePath string, manifest *backup.Manifest) func(context.Context, *config.Config, *logging.Logger, string, RestoreWorkflowUI) (*decryptCandidate, *preparedBundle, error) {
+ return func(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui RestoreWorkflowUI) (*decryptCandidate, *preparedBundle, error) {
+ return &decryptCandidate{
+ DisplayBase: "test",
+ Manifest: manifest,
+ }, &preparedBundle{
+ ArchivePath: archivePath,
+ Manifest: backup.Manifest{ArchivePath: archivePath},
+ cleanup: func() {},
+ }, nil
+ }
+}
+
+func TestRunRestoreWorkflow_ClusterPromptUsesArchivePayloadNotManifest(t *testing.T) {
+ origRestoreFS := restoreFS
+ origRestoreCmd := restoreCmd
+ origRestoreSystem := restoreSystem
+ origCompatFS := compatFS
+ origPrepare := prepareRestoreBundleFunc
+ origSafetyFS := safetyFS
+ t.Cleanup(func() {
+ restoreFS = origRestoreFS
+ restoreCmd = origRestoreCmd
+ restoreSystem = origRestoreSystem
+ compatFS = origCompatFS
+ prepareRestoreBundleFunc = origPrepare
+ safetyFS = origSafetyFS
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ compatFS = fakeFS
+ safetyFS = fakeFS
+ restoreCmd = runOnlyRunner{}
+ restoreSystem = fakeSystemDetector{systemType: SystemTypePVE}
+
+ if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil {
+ t.Fatalf("fakeFS.AddFile: %v", err)
+ }
+
+ tmpTar := filepath.Join(t.TempDir(), "bundle.tar")
+ if err := writeTarFile(tmpTar, map[string]string{
+ "etc/hosts": "127.0.0.1 localhost\n",
+ "var/lib/pve-cluster/config.db": "db\n",
+ }); err != nil {
+ t.Fatalf("writeTarFile: %v", err)
+ }
+ tarBytes, err := os.ReadFile(tmpTar)
+ if err != nil {
+ t.Fatalf("ReadFile tar: %v", err)
+ }
+ if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil {
+ t.Fatalf("fakeFS.WriteFile: %v", err)
+ }
+
+ prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{
+ CreatedAt: time.Unix(1700000000, 0),
+ ClusterMode: "standalone",
+ ProxmoxType: "pve",
+ ScriptVersion: "vtest",
+ })
+
+ logger := logging.New(types.LogLevelError, false)
+ cfg := &config.Config{BaseDir: "/base"}
+ ui := &fakeRestoreWorkflowUI{
+ mode: RestoreModeCustom,
+ categories: []Category{
+ mustCategoryByID(t, "pve_cluster"),
+ },
+ confirmRestore: true,
+ clusterMode: ClusterRestoreSafe,
+ }
+
+ if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil {
+ t.Fatalf("runRestoreWorkflowWithUI error: %v", err)
+ }
+ if ui.clusterRestoreModeCalls != 1 {
+ t.Fatalf("clusterRestoreModeCalls=%d; want 1", ui.clusterRestoreModeCalls)
+ }
+}
+
+func TestRunRestoreWorkflow_CompatibilityUsesArchivePayloadNotManifest(t *testing.T) {
+ origRestoreFS := restoreFS
+ origRestoreCmd := restoreCmd
+ origRestoreSystem := restoreSystem
+ origCompatFS := compatFS
+ origPrepare := prepareRestoreBundleFunc
+ origSafetyFS := safetyFS
+ t.Cleanup(func() {
+ restoreFS = origRestoreFS
+ restoreCmd = origRestoreCmd
+ restoreSystem = origRestoreSystem
+ compatFS = origCompatFS
+ prepareRestoreBundleFunc = origPrepare
+ safetyFS = origSafetyFS
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ compatFS = fakeFS
+ safetyFS = fakeFS
+ restoreCmd = runOnlyRunner{}
+ restoreSystem = fakeSystemDetector{systemType: SystemTypePVE}
+
+ if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil {
+ t.Fatalf("fakeFS.AddFile: %v", err)
+ }
+
+ tmpTar := filepath.Join(t.TempDir(), "bundle.tar")
+ if err := writeTarFile(tmpTar, map[string]string{
+ "etc/ssh/sshd_config": "Port 22\n",
+ "etc/pve/jobs.cfg": "jobs\n",
+ }); err != nil {
+ t.Fatalf("writeTarFile: %v", err)
+ }
+ tarBytes, err := os.ReadFile(tmpTar)
+ if err != nil {
+ t.Fatalf("ReadFile tar: %v", err)
+ }
+ if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil {
+ t.Fatalf("fakeFS.WriteFile: %v", err)
+ }
+
+ prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{
+ CreatedAt: time.Unix(1700000000, 0),
+ ClusterMode: "standalone",
+ ProxmoxType: "pbs",
+ ScriptVersion: "vtest",
+ })
+
+ logger := logging.New(types.LogLevelError, false)
+ cfg := &config.Config{BaseDir: "/base"}
+ ui := &fakeRestoreWorkflowUI{
+ mode: RestoreModeCustom,
+ categories: []Category{
+ mustCategoryByID(t, "ssh"),
+ },
+ confirmRestore: true,
+ confirmCompatible: false,
+ }
+
+ if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil {
+ t.Fatalf("runRestoreWorkflowWithUI error: %v", err)
+ }
+ if ui.confirmCompatibilityCalls != 0 {
+ t.Fatalf("confirmCompatibilityCalls=%d; want 0", ui.confirmCompatibilityCalls)
+ }
+}
+
+func TestRunRestoreWorkflow_CompatibilityWarnsOnArchiveMismatchDespiteManifest(t *testing.T) {
+ origRestoreFS := restoreFS
+ origRestoreCmd := restoreCmd
+ origRestoreSystem := restoreSystem
+ origCompatFS := compatFS
+ origPrepare := prepareRestoreBundleFunc
+ origAnalyze := analyzeRestoreArchiveFunc
+ origSafetyFS := safetyFS
+ t.Cleanup(func() {
+ restoreFS = origRestoreFS
+ restoreCmd = origRestoreCmd
+ restoreSystem = origRestoreSystem
+ compatFS = origCompatFS
+ prepareRestoreBundleFunc = origPrepare
+ analyzeRestoreArchiveFunc = origAnalyze
+ safetyFS = origSafetyFS
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ compatFS = fakeFS
+ safetyFS = fakeFS
+ restoreCmd = runOnlyRunner{}
+ restoreSystem = fakeSystemDetector{systemType: SystemTypePVE}
+
+ if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil {
+ t.Fatalf("fakeFS.AddFile: %v", err)
+ }
+
+ tmpTar := filepath.Join(t.TempDir(), "bundle.tar")
+ if err := writeTarFile(tmpTar, map[string]string{
+ "etc/ssh/sshd_config": "Port 22\n",
+ "etc/proxmox-backup/sync.cfg": "sync\n",
+ }); err != nil {
+ t.Fatalf("writeTarFile: %v", err)
+ }
+ tarBytes, err := os.ReadFile(tmpTar)
+ if err != nil {
+ t.Fatalf("ReadFile tar: %v", err)
+ }
+ if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil {
+ t.Fatalf("fakeFS.WriteFile: %v", err)
+ }
+
+ prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{
+ CreatedAt: time.Unix(1700000000, 0),
+ ClusterMode: "standalone",
+ ProxmoxType: "pve",
+ ScriptVersion: "vtest",
+ })
+
+ logger := logging.New(types.LogLevelError, false)
+ cfg := &config.Config{BaseDir: "/base"}
+ ui := &fakeRestoreWorkflowUI{
+ mode: RestoreModeCustom,
+ categories: []Category{
+ mustCategoryByID(t, "ssh"),
+ },
+ confirmRestore: true,
+ confirmCompatible: true,
+ }
+
+ if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil {
+ t.Fatalf("runRestoreWorkflowWithUI error: %v", err)
+ }
+ if ui.confirmCompatibilityCalls != 1 {
+ t.Fatalf("confirmCompatibilityCalls=%d; want 1", ui.confirmCompatibilityCalls)
+ }
+ if ui.lastCompatibilityWarning == nil {
+ t.Fatalf("expected compatibility warning to be passed to UI")
+ }
+}
+
+func TestRunRestoreWorkflow_CompatibilityWarningStillRunsBeforeFullFallbackOnAnalysisError(t *testing.T) {
+ origRestoreFS := restoreFS
+ origRestoreCmd := restoreCmd
+ origRestoreSystem := restoreSystem
+ origCompatFS := compatFS
+ origPrepare := prepareRestoreBundleFunc
+ origAnalyze := analyzeRestoreArchiveFunc
+ origSafetyFS := safetyFS
+ t.Cleanup(func() {
+ restoreFS = origRestoreFS
+ restoreCmd = origRestoreCmd
+ restoreSystem = origRestoreSystem
+ compatFS = origCompatFS
+ prepareRestoreBundleFunc = origPrepare
+ analyzeRestoreArchiveFunc = origAnalyze
+ safetyFS = origSafetyFS
+ })
+
+ fakeFS := NewFakeFS()
+ t.Cleanup(func() { _ = os.RemoveAll(fakeFS.Root) })
+ restoreFS = fakeFS
+ compatFS = fakeFS
+ safetyFS = fakeFS
+ restoreCmd = runOnlyRunner{}
+ restoreSystem = fakeSystemDetector{systemType: SystemTypePVE}
+
+ if err := fakeFS.AddFile("/usr/bin/qm", []byte("x")); err != nil {
+ t.Fatalf("fakeFS.AddFile: %v", err)
+ }
+
+ tmpTar := filepath.Join(t.TempDir(), "bundle.tar")
+ if err := writeTarFile(tmpTar, map[string]string{
+ "etc/hosts": "127.0.0.1 localhost\n",
+ }); err != nil {
+ t.Fatalf("writeTarFile: %v", err)
+ }
+ tarBytes, err := os.ReadFile(tmpTar)
+ if err != nil {
+ t.Fatalf("ReadFile tar: %v", err)
+ }
+ if err := fakeFS.WriteFile("/bundle.tar", tarBytes, 0o640); err != nil {
+ t.Fatalf("fakeFS.WriteFile: %v", err)
+ }
+
+ prepareRestoreBundleFunc = stubPreparedRestoreBundle("/bundle.tar", &backup.Manifest{
+ CreatedAt: time.Unix(1700000000, 0),
+ ClusterMode: "standalone",
+ ProxmoxType: "pbs",
+ ScriptVersion: "vtest",
+ })
+ analyzeRestoreArchiveFunc = func(archivePath string, logger *logging.Logger) ([]Category, *RestoreDecisionInfo, error) {
+ return nil, nil, errors.New("boom")
+ }
+
+ logger := logging.New(types.LogLevelError, false)
+ cfg := &config.Config{BaseDir: "/base"}
+ ui := &fakeRestoreWorkflowUI{
+ confirmRestore: true,
+ confirmCompatible: true,
+ }
+
+ if err := runRestoreWorkflowWithUI(context.Background(), cfg, logger, "vtest", ui); err != nil {
+ t.Fatalf("runRestoreWorkflowWithUI error: %v", err)
+ }
+ if ui.confirmCompatibilityCalls != 1 {
+ t.Fatalf("confirmCompatibilityCalls=%d; want 1", ui.confirmCompatibilityCalls)
+ }
+ if ui.lastCompatibilityWarning == nil {
+ t.Fatalf("expected compatibility warning before fallback")
+ }
+ if _, err := fakeFS.ReadFile("/etc/hosts"); err != nil {
+ t.Fatalf("expected full restore fallback to extract /etc/hosts: %v", err)
+ }
+}
diff --git a/internal/orchestrator/restore_workflow_more_test.go b/internal/orchestrator/restore_workflow_more_test.go
index 41fa569..e60eb56 100644
--- a/internal/orchestrator/restore_workflow_more_test.go
+++ b/internal/orchestrator/restore_workflow_more_test.go
@@ -321,7 +321,8 @@ func TestRunRestoreWorkflow_IncompatibilityAndSafetyBackupFailureCanContinue(t *
tmpTar := filepath.Join(t.TempDir(), "bundle.tar")
if err := writeTarFile(tmpTar, map[string]string{
- "etc/hosts": "127.0.0.1 localhost\n",
+ "etc/hosts": "127.0.0.1 localhost\n",
+ "etc/proxmox-backup/sync.cfg": "sync\n",
}); err != nil {
t.Fatalf("writeTarFile: %v", err)
}
diff --git a/internal/orchestrator/restore_workflow_ui.go b/internal/orchestrator/restore_workflow_ui.go
index 9ce2057..48e5ab5 100644
--- a/internal/orchestrator/restore_workflow_ui.go
+++ b/internal/orchestrator/restore_workflow_ui.go
@@ -11,12 +11,26 @@ import (
"strings"
"time"
+ "github.com/tis24dev/proxsave/internal/backup"
"github.com/tis24dev/proxsave/internal/config"
"github.com/tis24dev/proxsave/internal/input"
"github.com/tis24dev/proxsave/internal/logging"
)
var prepareRestoreBundleFunc = prepareRestoreBundleWithUI
+var analyzeRestoreArchiveFunc = AnalyzeRestoreArchive
+
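+// fallbackRestoreDecisionInfoFromManifest builds a best-effort
+// RestoreDecisionInfo from manifest metadata when the archive payload
+// cannot be analyzed.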
+func fallbackRestoreDecisionInfoFromManifest(manifest *backup.Manifest) *RestoreDecisionInfo {
+ info := &RestoreDecisionInfo{Source: RestoreDecisionSourceUnknown}
+ if manifest == nil {
+ return info
+ }
+
+ info.BackupType = DetectBackupType(manifest)
+ info.ClusterPayload = strings.EqualFold(strings.TrimSpace(manifest.ClusterMode), "cluster")
+ info.BackupHostname = strings.TrimSpace(manifest.Hostname)
+ return info
+}
func prepareRestoreBundleWithUI(ctx context.Context, cfg *config.Config, logger *logging.Logger, version string, ui RestoreWorkflowUI) (*decryptCandidate, *preparedBundle, error) {
candidate, err := selectBackupCandidateWithUI(ctx, ui, cfg, logger, false)
@@ -76,7 +90,19 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l
systemType := restoreSystem.DetectCurrentSystem()
logger.Info("Detected system type: %s", GetSystemTypeString(systemType))
- if warn := ValidateCompatibility(candidate.Manifest); warn != nil {
+ availableCategories, decisionInfo, err := analyzeRestoreArchiveFunc(prepared.ArchivePath, logger)
+ fallbackToFullRestore := false
+ if err != nil {
+ logger.Warning("Could not analyze categories: %v", err)
+ availableCategories = nil
+ decisionInfo = fallbackRestoreDecisionInfoFromManifest(candidate.Manifest)
+ fallbackToFullRestore = true
+ }
+ if decisionInfo == nil {
+ decisionInfo = &RestoreDecisionInfo{}
+ }
+
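+ // Compatibility is judged from the archive payload when analysis succeeds;
+ // otherwise decisionInfo carries the manifest-derived fallback values.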
+ if warn := ValidateCompatibility(systemType, decisionInfo.BackupType); warn != nil {
logger.Warning("Compatibility check: %v", warn)
proceed, perr := ui.ConfirmCompatibility(ctx, warn)
if perr != nil {
@@ -86,11 +112,7 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l
return ErrRestoreAborted
}
}
-
- logger.Info("Analyzing backup contents...")
- availableCategories, err := AnalyzeBackupCategories(prepared.ArchivePath, logger)
- if err != nil {
- logger.Warning("Could not analyze categories: %v", err)
+ if fallbackToFullRestore {
logger.Info("Falling back to full restore mode")
return runFullRestoreWithUI(ctx, ui, candidate, prepared, destRoot, logger, cfg.DryRun)
}
@@ -127,7 +149,7 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l
}
}
- plan := PlanRestore(candidate.Manifest, selectedCategories, systemType, mode)
+ plan := PlanRestore(decisionInfo.ClusterPayload, selectedCategories, systemType, mode)
if plan.SystemType == SystemTypePBS &&
(plan.HasCategoryID("pbs_host") ||
@@ -145,9 +167,8 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l
logger.Info("PBS restore behavior: %s", behavior.DisplayName())
}
- clusterBackup := strings.EqualFold(strings.TrimSpace(candidate.Manifest.ClusterMode), "cluster")
- if plan.NeedsClusterRestore && clusterBackup {
- logger.Info("Backup marked as cluster node; enabling guarded restore options for pve_cluster")
+ if plan.NeedsClusterRestore && plan.ClusterBackup {
+ logger.Info("Cluster payload detected in backup; enabling guarded restore options for pve_cluster")
choice, promptErr := ui.SelectClusterRestoreMode(ctx)
if promptErr != nil {
return promptErr
@@ -168,8 +189,8 @@ func runRestoreWorkflowWithUI(ctx context.Context, cfg *config.Config, logger *l
if plan.HasCategoryID("pve_access_control") || plan.HasCategoryID("pbs_access_control") {
currentHost, hostErr := os.Hostname()
- if hostErr == nil && strings.TrimSpace(candidate.Manifest.Hostname) != "" && strings.TrimSpace(currentHost) != "" {
- backupHost := strings.TrimSpace(candidate.Manifest.Hostname)
+ if hostErr == nil && strings.TrimSpace(decisionInfo.BackupHostname) != "" && strings.TrimSpace(currentHost) != "" {
+ backupHost := strings.TrimSpace(decisionInfo.BackupHostname)
if !strings.EqualFold(strings.TrimSpace(currentHost), backupHost) {
logger.Warning("Access control/TFA: backup hostname=%s current hostname=%s; WebAuthn users may require re-enrollment if the UI origin (FQDN/port) changes", backupHost, currentHost)
}
diff --git a/internal/orchestrator/restore_workflow_ui_helpers_test.go b/internal/orchestrator/restore_workflow_ui_helpers_test.go
index f2a03724..ca1a8fe 100644
--- a/internal/orchestrator/restore_workflow_ui_helpers_test.go
+++ b/internal/orchestrator/restore_workflow_ui_helpers_test.go
@@ -35,6 +35,10 @@ type fakeRestoreWorkflowUI struct {
confirmActionErr error
repairNICNamesErr error
networkCommitErr error
+
+ confirmCompatibilityCalls int
+ clusterRestoreModeCalls int
+ lastCompatibilityWarning error
}
func (f *fakeRestoreWorkflowUI) RunTask(ctx context.Context, title, initialMessage string, run func(ctx context.Context, report ProgressReporter) error) error {
@@ -84,10 +88,13 @@ func (f *fakeRestoreWorkflowUI) ConfirmRestore(ctx context.Context) (bool, error
}
func (f *fakeRestoreWorkflowUI) ConfirmCompatibility(ctx context.Context, warning error) (bool, error) {
+ f.confirmCompatibilityCalls++
+ f.lastCompatibilityWarning = warning
return f.confirmCompatible, f.confirmCompatibleErr
}
func (f *fakeRestoreWorkflowUI) SelectClusterRestoreMode(ctx context.Context) (ClusterRestoreMode, error) {
+ f.clusterRestoreModeCalls++
return f.clusterMode, f.clusterModeErr
}
diff --git a/internal/orchestrator/selective.go b/internal/orchestrator/selective.go
index 4b05ea2..5b0b8af 100644
--- a/internal/orchestrator/selective.go
+++ b/internal/orchestrator/selective.go
@@ -27,40 +27,8 @@ type SelectiveRestoreConfig struct {
// AnalyzeBackupCategories detects which categories are available in the backup
func AnalyzeBackupCategories(archivePath string, logger *logging.Logger) (categories []Category, err error) {
- done := logging.DebugStart(logger, "analyze backup categories", "archive=%s", archivePath)
- defer func() { done(err) }()
- logger.Info("Analyzing backup categories...")
-
- // Open the archive and read all entry names
- file, err := restoreFS.Open(archivePath)
- if err != nil {
- return nil, fmt.Errorf("open archive: %w", err)
- }
- defer file.Close()
-
- // Create appropriate reader based on compression
- reader, err := createDecompressionReader(context.Background(), file, archivePath)
- if err != nil {
- return nil, err
- }
- defer func() {
- if closer, ok := reader.(interface{ Close() error }); ok {
- closer.Close()
- }
- }()
-
- tarReader := tar.NewReader(reader)
-
- archivePaths := collectArchivePaths(tarReader)
- logger.Debug("Found %d entries in archive", len(archivePaths))
-
- availableCategories := AnalyzeArchivePaths(archivePaths, GetAllCategories())
- for _, cat := range availableCategories {
- logger.Debug("Category available: %s (%s)", cat.ID, cat.Name)
- }
-
- logger.Info("Detected %d available categories", len(availableCategories))
- return availableCategories, nil
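+ // Thin wrapper kept for callers that only need the category list.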
+ availableCategories, _, err := AnalyzeRestoreArchive(archivePath, logger)
+ return availableCategories, err
}
// AnalyzeArchivePaths determines available categories from the provided archive entries.
diff --git a/internal/pbs/namespaces.go b/internal/pbs/namespaces.go
index 345ac56..d168bef 100644
--- a/internal/pbs/namespaces.go
+++ b/internal/pbs/namespaces.go
@@ -46,6 +46,12 @@ func ListNamespaces(ctx context.Context, datastoreName, datastorePath string, io
return namespaces, true, nil
}
+// DiscoverNamespacesFromFilesystem skips the PBS CLI and infers namespaces
+// directly from the datastore filesystem layout.
+func DiscoverNamespacesFromFilesystem(ctx context.Context, datastorePath string, ioTimeout time.Duration) ([]Namespace, error) {
+ return discoverNamespacesFromFilesystem(ctx, datastorePath, ioTimeout)
+}
+
func listNamespacesViaCLI(ctx context.Context, datastore string) ([]Namespace, error) {
if err := ctx.Err(); err != nil {
return nil, err
diff --git a/internal/pbs/namespaces_test.go b/internal/pbs/namespaces_test.go
index ac358b9..f151cae 100644
--- a/internal/pbs/namespaces_test.go
+++ b/internal/pbs/namespaces_test.go
@@ -130,6 +130,21 @@ func TestDiscoverNamespacesFromFilesystem_Errors(t *testing.T) {
}
}
+func TestDiscoverNamespacesFromFilesystemExportedHelper(t *testing.T) {
+ tmpDir := t.TempDir()
+ mustMkdirAll(t, filepath.Join(tmpDir, "prod", "vm"))
+
+ namespaces, err := DiscoverNamespacesFromFilesystem(context.Background(), tmpDir, 0)
+ if err != nil {
+ t.Fatalf("DiscoverNamespacesFromFilesystem failed: %v", err)
+ }
+
+ got := namespacesToMap(namespaces)
+ if _, ok := got["prod"]; !ok {
+ t.Fatalf("expected exported helper to discover namespace, got %+v", namespaces)
+ }
+}
+
func TestListNamespaces_CLISuccess(t *testing.T) {
setExecCommandStub(t, "cli-success")
diff --git a/internal/safefs/safefs.go b/internal/safefs/safefs.go
index 36b001a..4639e34 100644
--- a/internal/safefs/safefs.go
+++ b/internal/safefs/safefs.go
@@ -14,6 +14,7 @@ var (
osStat = os.Stat
osReadDir = os.ReadDir
syscallStatfs = syscall.Statfs
+ fsOpLimiter = newOperationLimiter(32)
)
// ErrTimeout is a sentinel error used to classify filesystem operations that did not
@@ -56,100 +57,110 @@ func effectiveTimeout(ctx context.Context, timeout time.Duration) time.Duration
return timeout
}
-func Stat(ctx context.Context, path string, timeout time.Duration) (fs.FileInfo, error) {
- if err := ctx.Err(); err != nil {
- return nil, err
+// operationLimiter bounds the number of in-flight filesystem goroutines,
+// including those whose callers have already returned due to timeout or cancellation.
+type operationLimiter struct {
+ slots chan struct{}
+}
+
+func newOperationLimiter(capacity int) *operationLimiter {
+ if capacity < 1 {
+ capacity = 1
}
- timeout = effectiveTimeout(ctx, timeout)
- if timeout <= 0 {
- return osStat(path)
+ return &operationLimiter{
+ slots: make(chan struct{}, capacity),
}
+}
- type result struct {
- info fs.FileInfo
- err error
+func (l *operationLimiter) acquire(ctx context.Context, timer <-chan time.Time) error {
+ select {
+ case l.slots <- struct{}{}:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer:
+ return ErrTimeout
}
- ch := make(chan result, 1)
- go func() {
- info, err := osStat(path)
- ch <- result{info: info, err: err}
- }()
-
- timer := time.NewTimer(timeout)
- defer timer.Stop()
+}
+func (l *operationLimiter) release() {
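+ // Non-blocking drain: the default case tolerates a release without a matching acquire.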
select {
- case r := <-ch:
- return r.info, r.err
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-timer.C:
- return nil, &TimeoutError{Op: "stat", Path: path, Timeout: timeout}
+ case <-l.slots:
+ default:
}
}
-func ReadDir(ctx context.Context, path string, timeout time.Duration) ([]os.DirEntry, error) {
+func (l *operationLimiter) inflight() int {
+ return len(l.slots)
+}
+
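+// runLimited runs the operation under the shared in-flight limiter with the
+// given timeout, returning timeoutErr when either slot acquisition or the
+// operation itself exceeds the deadline.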
+func runLimited[T any](ctx context.Context, timeout time.Duration, timeoutErr *TimeoutError, run func() (T, error)) (T, error) {
+ var zero T
if err := ctx.Err(); err != nil {
- return nil, err
+ return zero, err
}
timeout = effectiveTimeout(ctx, timeout)
if timeout <= 0 {
- return osReadDir(path)
- }
-
- type result struct {
- entries []os.DirEntry
- err error
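+ // A non-positive effective timeout means either "no timeout" or an
+ // already-expired context deadline; re-check ctx.Err to distinguish.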
+ if err := ctx.Err(); err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ return zero, timeoutErr
+ }
+ return zero, err
+ }
+ return run()
}
- ch := make(chan result, 1)
- go func() {
- entries, err := osReadDir(path)
- ch <- result{entries: entries, err: err}
- }()
timer := time.NewTimer(timeout)
defer timer.Stop()
- select {
- case r := <-ch:
- return r.entries, r.err
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-timer.C:
- return nil, &TimeoutError{Op: "readdir", Path: path, Timeout: timeout}
- }
-}
-
-func Statfs(ctx context.Context, path string, timeout time.Duration) (syscall.Statfs_t, error) {
- if err := ctx.Err(); err != nil {
- return syscall.Statfs_t{}, err
- }
- timeout = effectiveTimeout(ctx, timeout)
- if timeout <= 0 {
- var stat syscall.Statfs_t
- return stat, syscallStatfs(path, &stat)
+ limiter := fsOpLimiter
+ if err := limiter.acquire(ctx, timer.C); err != nil {
+ if errors.Is(err, ErrTimeout) {
+ return zero, timeoutErr
+ }
+ return zero, err
}
type result struct {
- stat syscall.Statfs_t
- err error
+ value T
+ err error
}
ch := make(chan result, 1)
go func() {
- var stat syscall.Statfs_t
- err := syscallStatfs(path, &stat)
- ch <- result{stat: stat, err: err}
+ defer limiter.release()
+ value, err := run()
+ ch <- result{value: value, err: err}
}()
- timer := time.NewTimer(timeout)
- defer timer.Stop()
-
select {
case r := <-ch:
- return r.stat, r.err
+ return r.value, r.err
case <-ctx.Done():
- return syscall.Statfs_t{}, ctx.Err()
+ return zero, ctx.Err()
case <-timer.C:
- return syscall.Statfs_t{}, &TimeoutError{Op: "statfs", Path: path, Timeout: timeout}
+ return zero, timeoutErr
}
}
+
+func Stat(ctx context.Context, path string, timeout time.Duration) (fs.FileInfo, error) {
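+ // Capture the current hook so the goroutine spawned by runLimited keeps
+ // using it even if osStat is reassigned concurrently (e.g., by tests).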
+ stat := osStat
+ return runLimited(ctx, timeout, &TimeoutError{Op: "stat", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() (fs.FileInfo, error) {
+ return stat(path)
+ })
+}
+
+func ReadDir(ctx context.Context, path string, timeout time.Duration) ([]os.DirEntry, error) {
+ readDir := osReadDir
+ return runLimited(ctx, timeout, &TimeoutError{Op: "readdir", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() ([]os.DirEntry, error) {
+ return readDir(path)
+ })
+}
+
+func Statfs(ctx context.Context, path string, timeout time.Duration) (syscall.Statfs_t, error) {
+ statfs := syscallStatfs
+ return runLimited(ctx, timeout, &TimeoutError{Op: "statfs", Path: path, Timeout: effectiveTimeout(ctx, timeout)}, func() (syscall.Statfs_t, error) {
+ var stat syscall.Statfs_t
+ err := statfs(path, &stat)
+ return stat, err
+ })
+}
diff --git a/internal/safefs/safefs_test.go b/internal/safefs/safefs_test.go
index 30646ae..70745b5 100644
--- a/internal/safefs/safefs_test.go
+++ b/internal/safefs/safefs_test.go
@@ -4,17 +4,69 @@ import (
"context"
"errors"
"os"
+ "sync/atomic"
"syscall"
"testing"
"time"
)
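+// stagedDeadlineContext returns nil from Err on the first call and
+// context.DeadlineExceeded afterwards, exercising runLimited's
+// expired-deadline path.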
+type stagedDeadlineContext struct {
+ deadline time.Time
+ done <-chan struct{}
+ errCalls int
+}
+
+func (c *stagedDeadlineContext) Deadline() (time.Time, bool) {
+ return c.deadline, true
+}
+
+func (c *stagedDeadlineContext) Done() <-chan struct{} {
+ return c.done
+}
+
+func (c *stagedDeadlineContext) Err() error {
+ c.errCalls++
+ if c.errCalls == 1 {
+ return nil
+ }
+ return context.DeadlineExceeded
+}
+
+func (c *stagedDeadlineContext) Value(any) any {
+ return nil
+}
+
+func waitForSignal(t *testing.T, ch <-chan struct{}, name string) {
+ t.Helper()
+ select {
+ case <-ch:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatalf("timeout waiting for %s", name)
+ }
+}
+
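+// registerBlockedOpCleanup restores the stubbed hook and (cleanups run LIFO)
+// first unblocks the in-flight op and waits for it to finish.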
+func registerBlockedOpCleanup(t *testing.T, name string, unblock chan struct{}, finished <-chan struct{}, restore func()) {
+ t.Helper()
+
+ t.Cleanup(restore)
+ t.Cleanup(func() {
+ close(unblock)
+ waitForSignal(t, finished, name)
+ })
+}
+
func TestStat_ReturnsTimeoutError(t *testing.T) {
prev := osStat
- defer func() { osStat = prev }()
+ unblock := make(chan struct{})
+ finished := make(chan struct{})
+ registerBlockedOpCleanup(t, "stat completion", unblock, finished, func() {
+ osStat = prev
+ })
osStat = func(string) (os.FileInfo, error) {
- select {}
+ <-unblock
+ close(finished)
+ return nil, os.ErrNotExist
}
start := time.Now()
@@ -29,10 +81,16 @@ func TestStat_ReturnsTimeoutError(t *testing.T) {
func TestReadDir_ReturnsTimeoutError(t *testing.T) {
prev := osReadDir
- defer func() { osReadDir = prev }()
+ unblock := make(chan struct{})
+ finished := make(chan struct{})
+ registerBlockedOpCleanup(t, "readdir completion", unblock, finished, func() {
+ osReadDir = prev
+ })
osReadDir = func(string) ([]os.DirEntry, error) {
- select {}
+ <-unblock
+ close(finished)
+ return nil, nil
}
start := time.Now()
@@ -47,10 +105,16 @@ func TestReadDir_ReturnsTimeoutError(t *testing.T) {
func TestStatfs_ReturnsTimeoutError(t *testing.T) {
prev := syscallStatfs
- defer func() { syscallStatfs = prev }()
+ unblock := make(chan struct{})
+ finished := make(chan struct{})
+ registerBlockedOpCleanup(t, "statfs completion", unblock, finished, func() {
+ syscallStatfs = prev
+ })
syscallStatfs = func(string, *syscall.Statfs_t) error {
- select {}
+ <-unblock
+ close(finished)
+ return nil
}
start := time.Now()
@@ -72,3 +136,65 @@ func TestStat_PropagatesContextCancellation(t *testing.T) {
t.Fatalf("Stat err = %v; want context.Canceled", err)
}
}
+
+func TestRunLimited_ReturnsTimeoutErrorWhenDeadlineExpiresBeforeNoTimeoutPath(t *testing.T) {
+ done := make(chan struct{})
+ close(done)
+ ctx := &stagedDeadlineContext{
+ deadline: time.Now().Add(-time.Millisecond),
+ done: done,
+ }
+
+ called := false
+ _, err := runLimited(ctx, 50*time.Millisecond, &TimeoutError{Op: "stat", Path: "/does/not/matter"}, func() (int, error) {
+ called = true
+ return 1, nil
+ })
+
+ if called {
+ t.Fatal("run called; want timeout before execution")
+ }
+ if err == nil || !errors.Is(err, ErrTimeout) {
+ t.Fatalf("runLimited err = %v; want timeout", err)
+ }
+}
+
+func TestStat_DoesNotSpawnPastLimiterCapacity(t *testing.T) {
+ prevStat := osStat
+ prevLimiter := fsOpLimiter
+ unblock := make(chan struct{})
+ finished := make(chan struct{})
+ registerBlockedOpCleanup(t, "limited stat completion", unblock, finished, func() {
+ osStat = prevStat
+ fsOpLimiter = prevLimiter
+ })
+
+ fsOpLimiter = newOperationLimiter(1)
+
+ var calls atomic.Int32
+ osStat = func(string) (os.FileInfo, error) {
+ calls.Add(1)
+ <-unblock
+ close(finished)
+ return nil, os.ErrNotExist
+ }
+
+ _, err := Stat(context.Background(), "/first", 25*time.Millisecond)
+ if err == nil || !errors.Is(err, ErrTimeout) {
+ t.Fatalf("first Stat err = %v; want timeout", err)
+ }
+ if got := calls.Load(); got != 1 {
+ t.Fatalf("calls after first timeout = %d; want 1", got)
+ }
+ if got := fsOpLimiter.inflight(); got != 1 {
+ t.Fatalf("inflight after first timeout = %d; want 1", got)
+ }
+
+ _, err = Stat(context.Background(), "/second", 25*time.Millisecond)
+ if err == nil || !errors.Is(err, ErrTimeout) {
+ t.Fatalf("second Stat err = %v; want timeout", err)
+ }
+ if got := calls.Load(); got != 1 {
+ t.Fatalf("calls after limiter saturation = %d; want 1", got)
+ }
+}