diff --git a/packagehandlers/commonpackagehandler.go b/packagehandlers/commonpackagehandler.go index 9a874d521..844d9faaf 100644 --- a/packagehandlers/commonpackagehandler.go +++ b/packagehandlers/commonpackagehandler.go @@ -3,6 +3,7 @@ package packagehandlers import ( "fmt" "io/fs" + "os" "os/exec" "path/filepath" "regexp" @@ -107,6 +108,11 @@ func (cph *CommonPackageHandler) GetAllDescriptorFilesFullPaths(descriptorFilesS regexpPatternsCompilers = append(regexpPatternsCompilers, regexp.MustCompile(patternToExclude)) } + wd, err := os.Getwd() + if err != nil { + return + } + err = filepath.WalkDir(".", func(path string, d fs.DirEntry, innerErr error) error { if innerErr != nil { return fmt.Errorf("an error has occurred when attempting to access or traverse the file system: %w", innerErr) @@ -129,7 +135,12 @@ func (cph *CommonPackageHandler) GetAllDescriptorFilesFullPaths(descriptorFilesS if innerErr != nil { return fmt.Errorf("couldn't retrieve file's absolute path for './%s': %w", path, innerErr) } - descriptorFilesFullPaths = append(descriptorFilesFullPaths, absFilePath) + resolvedPath, validateErr := utils.ValidateFileWithinDir(absFilePath, wd) + if validateErr != nil { + log.Warn(fmt.Sprintf("skipping descriptor file '%s' as it resolves outside the project directory: %s", path, validateErr.Error())) + return nil + } + descriptorFilesFullPaths = append(descriptorFilesFullPaths, resolvedPath) } } return nil @@ -153,3 +164,4 @@ func GetVulnerabilityRegexCompiler(impactedName, impactedVersion, dependencyLine regexpCompleteFormat := fmt.Sprintf(strings.ToLower(dependencyLineFormat), regexpFitImpactedName, regexpFitImpactedVersion) return regexp.MustCompile(regexpCompleteFormat) } + diff --git a/packagehandlers/pythonpackagehandler.go b/packagehandlers/pythonpackagehandler.go index 3ff6569ac..92e5b6e87 100644 --- a/packagehandlers/pythonpackagehandler.go +++ b/packagehandlers/pythonpackagehandler.go @@ -77,7 +77,15 @@ func (py *PythonPackageHandler) handlePip(vulnDetails *utils.VulnerabilityDetail if fixedFile == "" { return fmt.Errorf("impacted package %s not found, fix failed", vulnDetails.ImpactedDependencyName) } - if err = os.WriteFile(py.pipRequirementsFile, []byte(fixedFile), 0600); err != nil { + wd, err := os.Getwd() + if err != nil { + return err + } + resolvedWritePath, err := utils.ValidateFileWithinDir(filepath.Join(wd, py.pipRequirementsFile), wd) + if err != nil { + return fmt.Errorf("wrong requirements file input '%s': %s", py.pipRequirementsFile, err.Error()) + } + if err = os.WriteFile(resolvedWritePath, []byte(fixedFile), 0600); err != nil { err = fmt.Errorf("an error occured while writing the fixed version of %s to the requirements file:\n%s", vulnDetails.SuggestedFixedVersion, err.Error()) } return @@ -112,10 +120,11 @@ func (py *PythonPackageHandler) tryReadRequirementFile(file string) (string, err return "", err } fullPath := filepath.Join(wd, file) - if !strings.HasPrefix(filepath.Clean(fullPath), wd) { - return "", errors.New("wrong requirements file input: " + fullPath) + resolvedPath, err := utils.ValidateFileWithinDir(fullPath, wd) + if err != nil { + return "", fmt.Errorf("wrong requirements file input '%s': %s", file, err.Error()) } - data, err := os.ReadFile(filepath.Clean(file)) + data, err := os.ReadFile(resolvedPath) if err != nil { return "", errors.New("an error occurred while attempting to read the requirements file:\n" + err.Error()) } diff --git a/utils/pathutils.go b/utils/pathutils.go new file mode 100644 index 000000000..628208afb --- /dev/null +++ b/utils/pathutils.go @@ -0,0 +1,38 @@ +package utils + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// ValidateFileWithinDir receives a file path and an allowed directory, resolves any symlinks in both, +// and verifies the real (on-disk) target still resides under the allowed directory. +// Returns the resolved absolute path on success, or an error if the file escapes the allowed boundary. +func ValidateFileWithinDir(filePath, allowedDir string) (string, error) { + absPath, err := filepath.Abs(filePath) + if err != nil { + return "", fmt.Errorf("couldn't get absolute path for '%s': %s", filePath, err.Error()) + } + + realPath, err := filepath.EvalSymlinks(absPath) + if err != nil { + return "", fmt.Errorf("couldn't resolve symlinks for '%s': %s", filePath, err.Error()) + } + + realAllowedDir, err := filepath.EvalSymlinks(allowedDir) + if err != nil { + return "", fmt.Errorf("couldn't resolve symlinks for allowed directory '%s': %s", allowedDir, err.Error()) + } + + cleanAllowedDir := filepath.Clean(realAllowedDir) + realPath = filepath.Clean(realPath) + + // The resolved path must either equal the allowed directory itself, or sit underneath it. + if realPath != cleanAllowedDir && !strings.HasPrefix(realPath, cleanAllowedDir+string(os.PathSeparator)) { + return "", fmt.Errorf("file '%s' resolves to '%s' which is outside the allowed directory '%s'", filePath, realPath, allowedDir) + } + + return realPath, nil +} diff --git a/utils/pathutils_test.go b/utils/pathutils_test.go new file mode 100644 index 000000000..7c6b1e345 --- /dev/null +++ b/utils/pathutils_test.go @@ -0,0 +1,125 @@ +package utils + +import ( + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// evalTempDir creates a temporary directory and resolves any symlinks in its path. +// Some platforms use symlinked temp directories (e.g., /var -> /private/var on macOS), +// which would cause path comparisons to fail without resolving first. +func evalTempDir(t *testing.T) string { + t.Helper() + dir := t.TempDir() + resolved, err := filepath.EvalSymlinks(dir) + require.NoError(t, err) + return resolved +} + +func TestValidateFileWithinDir(t *testing.T) { + tmpDir := evalTempDir(t) + + regularFile := filepath.Join(tmpDir, "requirements.txt") + require.NoError(t, os.WriteFile(regularFile, []byte("requests==2.28.0"), 0600)) + + resolvedPath, err := ValidateFileWithinDir(regularFile, tmpDir) + assert.NoError(t, err) + assert.Equal(t, regularFile, resolvedPath) +} + +func TestValidateFileWithinDir_Subdirectory(t *testing.T) { + tmpDir := evalTempDir(t) + + subDir := filepath.Join(tmpDir, "subdir") + require.NoError(t, os.Mkdir(subDir, 0700)) + nestedFile := filepath.Join(subDir, "setup.py") + require.NoError(t, os.WriteFile(nestedFile, []byte("setup()"), 0600)) + + resolvedPath, err := ValidateFileWithinDir(nestedFile, tmpDir) + assert.NoError(t, err) + assert.Equal(t, nestedFile, resolvedPath) +} + +func TestValidateFileWithinDir_SymlinkEscape(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Symlink tests are not supported on Windows") + } + + tmpDir := evalTempDir(t) + outsideDir := evalTempDir(t) + + outsideFile := filepath.Join(outsideDir, "secret.txt") + require.NoError(t, os.WriteFile(outsideFile, []byte("sensitive data"), 0600)) + + // Create a symlink inside the allowed dir that points outside + symlinkPath := filepath.Join(tmpDir, "requirements.txt") + require.NoError(t, os.Symlink(outsideFile, symlinkPath)) + + _, err := ValidateFileWithinDir(symlinkPath, tmpDir) + assert.Error(t, err) + assert.Contains(t, err.Error(), "outside the allowed directory") +} + +func TestValidateFileWithinDir_SymlinkWithinDir(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Symlink tests are not supported on Windows") + } + + tmpDir := evalTempDir(t) + + realFile := filepath.Join(tmpDir, "real-requirements.txt") + require.NoError(t, os.WriteFile(realFile, []byte("requests==2.28.0"), 0600)) + + // A symlink that resolves to a file still inside the allowed dir should be accepted + symlinkPath := filepath.Join(tmpDir, "requirements.txt") + require.NoError(t, os.Symlink(realFile, symlinkPath)) + + resolvedPath, err := ValidateFileWithinDir(symlinkPath, tmpDir) + assert.NoError(t, err) + assert.Equal(t, realFile, resolvedPath) +} + +func TestValidateFileWithinDir_PathTraversal(t *testing.T) { + tmpDir := evalTempDir(t) + outsideDir := evalTempDir(t) + + outsideFile := filepath.Join(outsideDir, "secret.txt") + require.NoError(t, os.WriteFile(outsideFile, []byte("sensitive data"), 0600)) + + _, err := ValidateFileWithinDir(outsideFile, tmpDir) + assert.Error(t, err) + assert.Contains(t, err.Error(), "outside the allowed directory") +} + +func TestValidateFileWithinDir_NonexistentFile(t *testing.T) { + tmpDir := evalTempDir(t) + + _, err := ValidateFileWithinDir(filepath.Join(tmpDir, "nonexistent.txt"), tmpDir) + assert.Error(t, err) + assert.Contains(t, err.Error(), "couldn't resolve symlinks") +} + +func TestValidateFileWithinDir_DirSymlinkEscape(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Symlink tests are not supported on Windows") + } + + tmpDir := evalTempDir(t) + outsideDir := evalTempDir(t) + + outsideFile := filepath.Join(outsideDir, "secret.txt") + require.NoError(t, os.WriteFile(outsideFile, []byte("sensitive data"), 0600)) + + // A symlinked directory pointing outside the workspace should be caught + symlinkDir := filepath.Join(tmpDir, "linked-dir") + require.NoError(t, os.Symlink(outsideDir, symlinkDir)) + + _, err := ValidateFileWithinDir(filepath.Join(symlinkDir, "secret.txt"), tmpDir) + assert.Error(t, err) + assert.Contains(t, err.Error(), "outside the allowed directory") +}