diff --git a/internal/utils/archive/archive.go b/internal/utils/archive/archive.go new file mode 100644 index 00000000..37a07a57 --- /dev/null +++ b/internal/utils/archive/archive.go @@ -0,0 +1,423 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// Package archive provides on-disk tar extraction and deterministic archive +// creation. +// +// [Extract] materializes entries into a destination directory via [os.Root], +// which confines every file, directory, and symlink operation to that +// directory — any entry path that would escape destDir is rejected by the +// runtime. Symlink *targets* (the contents of a symlink, not its own path) +// are not policed by [os.Root]; this package additionally rejects any symlink +// whose target is non-local via [filepath.IsLocal]. +// +// Archive creation is designed for reproducible builds: file ordering is +// lexicographic, timestamps are pinned to Unix epoch, and owner/group metadata +// is zeroed out. This matches the +// `tar --sort=name --mtime=@0 --owner=0 --group=0` convention used by source +// modification scripts. +package archive + +import ( + "archive/tar" + "compress/gzip" + "errors" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "strings" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/microsoft/azure-linux-dev-tools/internal/utils/defers" + "github.com/microsoft/azure-linux-dev-tools/internal/utils/fileperms" + "github.com/ulikunitz/xz" +) + +// Compression identifies the compression format of an archive. +type Compression int + +const ( + // CompressionNone indicates an uncompressed .tar archive. + CompressionNone Compression = iota + // CompressionGzip indicates gzip compression (.tar.gz or .tgz). + CompressionGzip + // CompressionXZ indicates xz compression (.tar.xz). + CompressionXZ + // CompressionZstd indicates zstandard compression (.tar.zst). + CompressionZstd +) + +// maxEntryBytes caps the decompressed size of any single regular-file entry +// extracted by [Extract]. This prevents a decompression-bomb archive from +// filling the destination filesystem. 10 GiB is well above any reasonable +// source archive entry but small enough to refuse pathological inputs. +// +// Declared as var rather than const so internal tests can override it +// without having to construct a >10 GiB fixture. +// +//nolint:gochecknoglobals // see comment above +var maxEntryBytes int64 = 10 << 30 + +// DetectCompression determines the compression type from the archive filename. +func DetectCompression(filename string) (Compression, error) { + lower := strings.ToLower(filename) + + switch { + case strings.HasSuffix(lower, ".tar.gz") || strings.HasSuffix(lower, ".tgz"): + return CompressionGzip, nil + case strings.HasSuffix(lower, ".tar.xz"): + return CompressionXZ, nil + case strings.HasSuffix(lower, ".tar.zst"): + return CompressionZstd, nil + case strings.HasSuffix(lower, ".tar"): + return CompressionNone, nil + default: + return CompressionNone, fmt.Errorf("unsupported archive format %#q", filename) + } +} + +// Extract reads a tar archive, decompresses it, and extracts all entries into +// destDir. Supported entry types are regular files, directories, and symlinks; +// other entry types are skipped. Entry paths are confined to destDir via +// [os.Root]: any path that would escape destDir is rejected by the runtime. +// Symlink targets are validated separately by this package — see the package +// doc for details. +func Extract(archivePath, destDir string, comp Compression) (err error) { + if err := os.MkdirAll(destDir, fileperms.PublicDir); err != nil { + return fmt.Errorf("creating destination %#q:\n%w", destDir, err) + } + + root, err := os.OpenRoot(destDir) + if err != nil { + return fmt.Errorf("opening destination root %#q:\n%w", destDir, err) + } + defer defers.HandleDeferError(root.Close, &err) + + file, err := os.Open(archivePath) + if err != nil { + return fmt.Errorf("opening archive %#q:\n%w", archivePath, err) + } + defer defers.HandleDeferError(file.Close, &err) + + decompressed, closer, err := newDecompressor(file, comp) + if err != nil { + return err + } + + if closer != nil { + defer defers.HandleDeferError(closer.Close, &err) + } + + tarReader := tar.NewReader(decompressed) + + for { + header, readErr := tarReader.Next() + if errors.Is(readErr, io.EOF) { + return nil + } + + if readErr != nil { + return fmt.Errorf("reading tar entry from %#q:\n%w", archivePath, readErr) + } + + if err := extractEntry(root, header, tarReader); err != nil { + return fmt.Errorf("extracting %#q from %#q:\n%w", header.Name, archivePath, err) + } + } +} + +// newDecompressor wraps reader in the chosen decompressor. For +// [CompressionNone] the returned closer is nil; otherwise it is the +// decompressor itself. +func newDecompressor(reader io.Reader, comp Compression) (io.Reader, io.Closer, error) { + switch comp { + case CompressionNone: + return reader, nil, nil + case CompressionGzip: + gzReader, err := gzip.NewReader(reader) + if err != nil { + return nil, nil, fmt.Errorf("creating gzip reader:\n%w", err) + } + + return gzReader, gzReader, nil + case CompressionXZ: + xzReader, err := xz.NewReader(reader) + if err != nil { + return nil, nil, fmt.Errorf("creating xz reader:\n%w", err) + } + + return xzReader, nil, nil + case CompressionZstd: + zstdReader, err := zstd.NewReader(reader) + if err != nil { + return nil, nil, fmt.Errorf("creating zstd reader:\n%w", err) + } + + return zstdReader, readerCloser{zstdReader.Close}, nil + default: + return nil, nil, fmt.Errorf("unsupported compression type %d", comp) + } +} + +// readerCloser adapts a no-error Close (such as [zstd.Decoder.Close]) to +// [io.Closer]. +type readerCloser struct { + close func() +} + +func (r readerCloser) Close() error { + r.close() + + return nil +} + +// CreateDeterministicArchive creates a new tar archive from the contents of sourceDir +// and writes it to archivePath on the OS filesystem, replacing any existing file. +// +// The output is deterministic: +// - File ordering is lexicographic (via [filepath.WalkDir]). +// - Timestamps are pinned to Unix epoch (1970-01-01 00:00:00 UTC). +// - Owner/group IDs and names are zeroed out. +// - Gzip output uses best compression with no OS or filename metadata. +// +// Symlink targets are recorded verbatim, including absolute or +// directory-escaping targets. [Extract] rejects non-local symlink targets +// (via [filepath.IsLocal]), so an archive produced here is not guaranteed to +// be extractable by [Extract] when the source tree contains such symlinks. +func CreateDeterministicArchive(archivePath, sourceDir string, comp Compression) (err error) { + file, err := os.Create(archivePath) + if err != nil { + return fmt.Errorf("opening archive for writing %#q:\n%w", archivePath, err) + } + defer defers.HandleDeferError(file.Close, &err) + + compressedWriter, compressedCloser, err := newCompressor(file, comp) + if err != nil { + return err + } + + if compressedCloser != nil { + defer defers.HandleDeferError(compressedCloser.Close, &err) + } + + tarWriter := tar.NewWriter(compressedWriter) + defer defers.HandleDeferError(tarWriter.Close, &err) + + epoch := deterministicEpoch() + + walkErr := filepath.WalkDir(sourceDir, func(path string, entry os.DirEntry, dirErr error) error { + if dirErr != nil { + return dirErr + } + + rel, relErr := filepath.Rel(sourceDir, path) + if relErr != nil { + return fmt.Errorf("computing relative path for %#q:\n%w", path, relErr) + } + + if rel == "." { + return nil + } + + return writeEntryDeterministic(tarWriter, path, rel, entry, epoch) + }) + if walkErr != nil { + return fmt.Errorf("walking directory for repacking:\n%w", walkErr) + } + + return nil +} + +func deterministicEpoch() time.Time { + return time.Unix(0, 0).UTC() +} + +func extractEntry(root *os.Root, header *tar.Header, tarReader io.Reader) error { + name := header.Name + + if header.Typeflag == tar.TypeDir { + if err := root.MkdirAll(name, fileperms.PublicDir); err != nil { + return fmt.Errorf("creating directory %#q:\n%w", name, err) + } + + return nil + } + + if err := root.MkdirAll(filepath.Dir(name), fileperms.PublicDir); err != nil { + return fmt.Errorf("creating parent for %#q:\n%w", name, err) + } + + switch header.Typeflag { + case tar.TypeSymlink: + // os.Root validates that the link's own path stays inside the root, but + // it stores the target verbatim. Reject non-local targets so that tools + // which later walk the extracted tree without os.Root cannot be led + // outside destDir. + if !filepath.IsLocal(header.Linkname) { + return fmt.Errorf("tar symlink %#q has non-local target %#q", name, header.Linkname) + } + + if err := root.Symlink(header.Linkname, name); err != nil { + return fmt.Errorf("creating symlink %#q -> %#q:\n%w", name, header.Linkname, err) + } + + return nil + case tar.TypeReg: + return extractRegularFile(root, header, tarReader) + default: + slog.Debug("Skipping unsupported tar entry type", "name", name, "typeflag", header.Typeflag) + + return nil + } +} + +func extractRegularFile(root *os.Root, header *tar.Header, src io.Reader) (err error) { + name := header.Name + + // gosec G115: header.Mode is the tar permission bits; mask to ModePerm + // (the bottom 9 bits) before passing to OpenFile. + mode := os.FileMode(header.Mode) & os.ModePerm //nolint:gosec + + // Reject up front if the declared header size is over the limit so we never + // open a destination file for a known-bad entry. + if header.Size > maxEntryBytes { + return fmt.Errorf("tar entry %#q declares size %d bytes, exceeds max of %d", name, header.Size, maxEntryBytes) + } + + outFile, err := root.OpenFile(name, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) + if err != nil { + return fmt.Errorf("creating file %#q:\n%w", name, err) + } + defer defers.HandleDeferError(outFile.Close, &err) + + // Remove any partially written output on error so callers don't see a + // truncated file masquerading as a successful extraction. + defer func() { + if err != nil { + _ = root.Remove(name) + } + }() + + // Copy exactly header.Size bytes. CopyN returns a non-nil error (including + // io.EOF on a truncated archive) whenever it cannot satisfy the full count, + // so any error here means the entry is short or unreadable. + if _, copyErr := io.CopyN(outFile, src, header.Size); copyErr != nil { + return fmt.Errorf("writing file %#q:\n%w", name, copyErr) + } + + return nil +} + +func writeEntryDeterministic( + tarWriter *tar.Writer, path, rel string, entry os.DirEntry, epoch time.Time, +) (err error) { + info, err := entry.Info() + if err != nil { + return fmt.Errorf("stat %#q:\n%w", path, err) + } + + var linkTarget string + if info.Mode()&os.ModeSymlink != 0 { + linkTarget, err = os.Readlink(path) + if err != nil { + return fmt.Errorf("reading symlink %#q:\n%w", path, err) + } + } + + header, err := tar.FileInfoHeader(info, linkTarget) + if err != nil { + return fmt.Errorf("creating tar header for %#q:\n%w", path, err) + } + + header.Name = filepath.ToSlash(rel) + if info.IsDir() { + // tar convention: directory entry names end with '/'. tar.FileInfoHeader + // applies this, but our rel-based override above drops it. + header.Name += "/" + } + + header.Format = tar.FormatGNU + header.ModTime = epoch + header.AccessTime = time.Time{} + header.ChangeTime = time.Time{} + header.Uid = 0 + header.Gid = 0 + header.Uname = "" + header.Gname = "" + + if err := tarWriter.WriteHeader(header); err != nil { + return fmt.Errorf("writing tar header for %#q:\n%w", path, err) + } + + if !info.Mode().IsRegular() { + return nil + } + + sourceFile, openErr := os.Open(path) + if openErr != nil { + return fmt.Errorf("opening %#q for repack:\n%w", path, openErr) + } + defer defers.HandleDeferError(sourceFile.Close, &err) + + // Copy exactly header.Size bytes so the archive entry matches the header. + // Using io.Copy would write whatever the file contains at read time, which + // may exceed header.Size if the file grew between stat and open, causing + // tar.Writer to error mid-stream with a partially written entry. + if _, copyErr := io.CopyN(tarWriter, sourceFile, header.Size); copyErr != nil { + return fmt.Errorf("writing %#q to archive:\n%w", path, copyErr) + } + + return nil +} + +// newCompressor wraps writer in the chosen compression. For [CompressionNone] +// the returned closer is nil (no wrapper to flush or close); otherwise it is +// the compressor itself, which must be closed before the underlying writer to +// flush trailing bytes. +func newCompressor(writer io.Writer, comp Compression) (io.Writer, io.Closer, error) { + switch comp { + case CompressionNone: + return writer, nil, nil + case CompressionGzip: + gzWriter, gzErr := gzip.NewWriterLevel(writer, gzip.BestCompression) + if gzErr != nil { + return nil, nil, fmt.Errorf("creating gzip writer:\n%w", gzErr) + } + + // Pin every gzip header field that the writer would otherwise populate + // from the environment or input file, so two runs over identical inputs + // produce byte-identical output. ModTime matches the tar header epoch + // for consistency; OS is "unknown" (RFC 1952 §2.3.1) so output is + // independent of the host OS. + const gzipOSUnknown byte = 0xff + + gzWriter.Header = gzip.Header{ + Name: "", + Comment: "", + Extra: nil, + ModTime: deterministicEpoch(), + OS: gzipOSUnknown, + } + + return gzWriter, gzWriter, nil + case CompressionXZ: + xzWriter, err := xz.NewWriter(writer) + if err != nil { + return nil, nil, fmt.Errorf("creating xz writer:\n%w", err) + } + + return xzWriter, xzWriter, nil + case CompressionZstd: + zstdWriter, err := zstd.NewWriter(writer) + if err != nil { + return nil, nil, fmt.Errorf("creating zstd writer:\n%w", err) + } + + return zstdWriter, zstdWriter, nil + default: + return nil, nil, fmt.Errorf("unsupported compression type %d for writing", comp) + } +} diff --git a/internal/utils/archive/archive_internal_test.go b/internal/utils/archive/archive_internal_test.go new file mode 100644 index 00000000..d11d0faf --- /dev/null +++ b/internal/utils/archive/archive_internal_test.go @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package archive + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtract_OversizedEntryRejected(t *testing.T) { + previous := maxEntryBytes + maxEntryBytes = 5 + + t.Cleanup(func() { maxEntryBytes = previous }) + + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "big.tar.gz") + extractDir := filepath.Join(tmpDir, "out") + require.NoError(t, os.MkdirAll(extractDir, 0o755)) + + var buf bytes.Buffer + + gzWriter := gzip.NewWriter(&buf) + tarWriter := tar.NewWriter(gzWriter) + + const content = "hello world" // 11 bytes > 5 + + require.NoError(t, tarWriter.WriteHeader(&tar.Header{ + Name: "pkg/huge.bin", + Typeflag: tar.TypeReg, + Mode: 0o644, + Size: int64(len(content)), + })) + _, err := tarWriter.Write([]byte(content)) + require.NoError(t, err) + require.NoError(t, tarWriter.Close()) + require.NoError(t, gzWriter.Close()) + require.NoError(t, os.WriteFile(archivePath, buf.Bytes(), 0o600)) + + err = Extract(archivePath, extractDir, CompressionGzip) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeds max of 5") + + _, statErr := os.Stat(filepath.Join(extractDir, "pkg", "huge.bin")) + assert.True(t, os.IsNotExist(statErr), "oversized entry must not leave a partial file (stat err=%v)", statErr) +} diff --git a/internal/utils/archive/archive_test.go b/internal/utils/archive/archive_test.go new file mode 100644 index 00000000..b41273c2 --- /dev/null +++ b/internal/utils/archive/archive_test.go @@ -0,0 +1,340 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package archive_test + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "errors" + "io" + "os" + "path/filepath" + "testing" + + "github.com/microsoft/azure-linux-dev-tools/internal/utils/archive" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDetectCompression(t *testing.T) { + tests := []struct { + filename string + expected archive.Compression + wantErr bool + }{ + {"pkg-1.0.tar.gz", archive.CompressionGzip, false}, + {"pkg-1.0.tgz", archive.CompressionGzip, false}, + {"pkg-1.0.tar.bz2", archive.CompressionNone, true}, + {"pkg-1.0.tar.xz", archive.CompressionXZ, false}, + {"pkg-1.0.tar.zst", archive.CompressionZstd, false}, + {"pkg-1.0.tar", archive.CompressionNone, false}, + {"pkg-1.0.zip", archive.CompressionNone, true}, + {"PKG-1.0.TAR.GZ", archive.CompressionGzip, false}, + } + + for _, testCase := range tests { + t.Run(testCase.filename, func(t *testing.T) { + comp, err := archive.DetectCompression(testCase.filename) + + if testCase.wantErr { + require.Error(t, err) + + return + } + + require.NoError(t, err) + assert.Equal(t, testCase.expected, comp) + }) + } +} + +func TestExtractAndRepack(t *testing.T) { + tmpDir := t.TempDir() + + archivePath := filepath.Join(tmpDir, "test.tar.gz") + extractDir := filepath.Join(tmpDir, "extracted") + repackDir := filepath.Join(tmpDir, "repacked") + + require.NoError(t, os.MkdirAll(extractDir, 0o755)) + require.NoError(t, os.MkdirAll(repackDir, 0o755)) + + createTestTarGz(t, archivePath, []testTarEntry{ + {name: "pkg-1.0/hello.txt", typeflag: tar.TypeReg, content: "hello world"}, + {name: "pkg-1.0/config.cfg", typeflag: tar.TypeReg, content: "key=value"}, + }) + + err := archive.Extract(archivePath, extractDir, archive.CompressionGzip) + require.NoError(t, err) + + content, readErr := os.ReadFile(filepath.Join(extractDir, "pkg-1.0", "hello.txt")) + require.NoError(t, readErr) + assert.Equal(t, "hello world", string(content)) + + repackPath := filepath.Join(tmpDir, "repacked.tar.gz") + + err = archive.CreateDeterministicArchive(repackPath, extractDir, archive.CompressionGzip) + require.NoError(t, err) + + err = archive.Extract(repackPath, repackDir, archive.CompressionGzip) + require.NoError(t, err) + + content, readErr = os.ReadFile(filepath.Join(repackDir, "pkg-1.0", "hello.txt")) + require.NoError(t, readErr) + assert.Equal(t, "hello world", string(content)) + + // Repack twice and verify byte-for-byte identical output. + repackPath2 := filepath.Join(tmpDir, "repacked2.tar.gz") + + err = archive.CreateDeterministicArchive(repackPath2, extractDir, archive.CompressionGzip) + require.NoError(t, err) + + data1, err := os.ReadFile(repackPath) + require.NoError(t, err) + data2, err := os.ReadFile(repackPath2) + require.NoError(t, err) + assert.Equal(t, data1, data2, "deterministic repack should produce identical output") +} + +func createTestTarGz(t *testing.T, path string, entries []testTarEntry) { + t.Helper() + + var buf bytes.Buffer + + gzWriter := gzip.NewWriter(&buf) + tarWriter := tar.NewWriter(gzWriter) + + for _, entry := range entries { + header := &tar.Header{ + Name: entry.name, + Typeflag: entry.typeflag, + } + + switch entry.typeflag { + case tar.TypeDir: + header.Mode = 0o755 + case tar.TypeReg: + header.Mode = 0o644 + header.Size = int64(len(entry.content)) + case tar.TypeSymlink: + header.Linkname = entry.linkname + } + + require.NoError(t, tarWriter.WriteHeader(header)) + + if entry.typeflag == tar.TypeReg && len(entry.content) > 0 { + _, writeErr := tarWriter.Write([]byte(entry.content)) + require.NoError(t, writeErr) + } + } + + require.NoError(t, tarWriter.Close()) + require.NoError(t, gzWriter.Close()) + require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o600)) +} + +type testTarEntry struct { + name string + typeflag byte + content string + linkname string +} + +func TestRoundTrip_AllCompressions(t *testing.T) { + tests := []struct { + name string + ext string + comp archive.Compression + }{ + {"none", ".tar", archive.CompressionNone}, + {"gzip", ".tar.gz", archive.CompressionGzip}, + {"xz", ".tar.xz", archive.CompressionXZ}, + {"zstd", ".tar.zst", archive.CompressionZstd}, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + tmpDir := t.TempDir() + sourceDir := filepath.Join(tmpDir, "src") + extractDir := filepath.Join(tmpDir, "out") + require.NoError(t, os.MkdirAll(filepath.Join(sourceDir, "sub"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(sourceDir, "a.txt"), []byte("alpha"), 0o600)) + require.NoError(t, os.WriteFile(filepath.Join(sourceDir, "sub", "b.txt"), []byte("beta"), 0o600)) + + archivePath := filepath.Join(tmpDir, "archive"+testCase.ext) + + require.NoError(t, archive.CreateDeterministicArchive(archivePath, sourceDir, testCase.comp)) + require.NoError(t, archive.Extract(archivePath, extractDir, testCase.comp)) + + got, err := os.ReadFile(filepath.Join(extractDir, "sub", "b.txt")) + require.NoError(t, err) + assert.Equal(t, "beta", string(got)) + + // Create a second archive over identical input and verify + // byte-for-byte determinism across all compression modes. + archivePath2 := filepath.Join(tmpDir, "archive2"+testCase.ext) + require.NoError(t, archive.CreateDeterministicArchive(archivePath2, sourceDir, testCase.comp)) + + data1, err := os.ReadFile(archivePath) + require.NoError(t, err) + data2, err := os.ReadFile(archivePath2) + require.NoError(t, err) + assert.Equal(t, data1, data2, "deterministic archive should produce identical output") + }) + } +} + +func TestUnsupportedCompression(t *testing.T) { + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "archive.bin") + require.NoError(t, os.WriteFile(archivePath, []byte("dummy"), 0o600)) + + bogus := archive.Compression(99) + + err := archive.Extract(archivePath, tmpDir, bogus) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported compression type") + + err = archive.CreateDeterministicArchive(filepath.Join(tmpDir, "out.bin"), tmpDir, bogus) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported compression type") +} + +func TestCreateDeterministicArchive_PreservesSymlinks(t *testing.T) { + tmpDir := t.TempDir() + sourceDir := filepath.Join(tmpDir, "src") + externalDir := filepath.Join(tmpDir, "external") + + require.NoError(t, os.MkdirAll(sourceDir, 0o755)) + require.NoError(t, os.MkdirAll(externalDir, 0o755)) + + // A regular file inside the source tree, plus an external file whose + // contents must NOT end up embedded in the archive. + require.NoError(t, os.WriteFile(filepath.Join(sourceDir, "real.txt"), []byte("inside"), 0o600)) + + const externalContent = "must-not-be-archived" + + externalPath := filepath.Join(externalDir, "secret.txt") + require.NoError(t, os.WriteFile(externalPath, []byte(externalContent), 0o600)) + + // Symlink staying inside the source tree (relative target). + require.NoError(t, os.Symlink("real.txt", filepath.Join(sourceDir, "internal-link"))) + // Symlink pointing outside the source tree (absolute target). + require.NoError(t, os.Symlink(externalPath, filepath.Join(sourceDir, "external-link"))) + + archivePath := filepath.Join(tmpDir, "archive.tar") + require.NoError(t, archive.CreateDeterministicArchive(archivePath, sourceDir, archive.CompressionNone)) + + archiveBytes, err := os.ReadFile(archivePath) + require.NoError(t, err) + + type entryInfo struct { + header *tar.Header + content string + } + + entriesByName := map[string]entryInfo{} + + reader := tar.NewReader(bytes.NewReader(archiveBytes)) + + for { + header, readErr := reader.Next() + if errors.Is(readErr, io.EOF) { + break + } + + require.NoError(t, readErr) + + var content string + + if header.Typeflag == tar.TypeReg { + body, bodyErr := io.ReadAll(reader) + require.NoError(t, bodyErr) + + content = string(body) + } + + entriesByName[header.Name] = entryInfo{header: header, content: content} + } + + for name, entry := range entriesByName { + if entry.header.Typeflag == tar.TypeReg { + assert.NotContains(t, entry.content, externalContent, + "regular file entry %#q must not contain external content", name) + } + } + + internalEntry, found := entriesByName["internal-link"] + require.True(t, found, "internal symlink entry missing from archive") + assert.Equal(t, byte(tar.TypeSymlink), internalEntry.header.Typeflag) + assert.Equal(t, "real.txt", internalEntry.header.Linkname) + assert.Zero(t, internalEntry.header.Size, "symlink entries must not carry payload bytes") + + externalEntry, found := entriesByName["external-link"] + require.True(t, found, "external symlink entry missing from archive") + assert.Equal(t, byte(tar.TypeSymlink), externalEntry.header.Typeflag) + assert.Equal(t, externalPath, externalEntry.header.Linkname, + "external symlink target must be recorded verbatim, not dereferenced") + assert.Zero(t, externalEntry.header.Size) +} + +func TestExtract_SymlinkSafety(t *testing.T) { + tests := []struct { + name string + entries []testTarEntry + wantErr bool + }{ + { + name: "absolute symlink target rejected", + entries: []testTarEntry{ + {name: "evil", typeflag: tar.TypeSymlink, linkname: "/etc/passwd"}, + }, + wantErr: true, + }, + { + name: "relative symlink escaping destDir rejected", + entries: []testTarEntry{ + {name: "evil", typeflag: tar.TypeSymlink, linkname: "../../etc"}, + }, + wantErr: true, + }, + { + name: "entry name escaping destDir rejected", + entries: []testTarEntry{ + {name: "../escape.txt", typeflag: tar.TypeReg, content: "nope"}, + }, + wantErr: true, + }, + { + name: "valid internal symlink allowed", + entries: []testTarEntry{ + {name: "pkg/", typeflag: tar.TypeDir}, + {name: "pkg/real.txt", typeflag: tar.TypeReg, content: "hello"}, + {name: "pkg/link.txt", typeflag: tar.TypeSymlink, linkname: "real.txt"}, + }, + wantErr: false, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.tar.gz") + extractDir := filepath.Join(tmpDir, "extracted") + require.NoError(t, os.MkdirAll(extractDir, 0o755)) + + createTestTarGz(t, archivePath, testCase.entries) + + err := archive.Extract(archivePath, extractDir, archive.CompressionGzip) + + if testCase.wantErr { + require.Error(t, err) + + return + } + + require.NoError(t, err) + }) + } +}