diff --git a/cmd/root_test.go b/cmd/root_test.go index be520c62..5afce7f3 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -2205,6 +2205,36 @@ func TestStdin(t *testing.T) { }), ) + // Try a file that's more subtly outside of the project root. + os.Stdin = test.TempFile(t, "", "stdin", &contents) + + treefmt(t, + withArgs("--stdin", "foo/../../test.nix"), + withError(func(as *require.Assertions, err error) { + as.ErrorContains(err, "path foo/../../test.nix not inside the tree root "+tempDir) + }), + withStderr(func(out []byte) { + as.Contains(string(out), "Error: failed to create walker: path foo/../../test.nix not inside the tree root") + }), + ) + + // Try a file with a funky name, but is *not* outside of the project root. + os.Stdin = test.TempFile(t, "", "stdin", &contents) + treefmt(t, + withArgs("--stdin", "..lotsadots.nix"), + withNoError(t), + withStats(t, map[stats.Type]int{ + stats.Traversed: 1, + stats.Matched: 1, + stats.Formatted: 1, + stats.Changed: 1, + }), + withStdout(func(out []byte) { + as.Equal(`{ ...}: "hello" +`, string(out)) + }), + ) + // try some markdown instead contents = ` | col1 | col2 | @@ -2254,6 +2284,30 @@ help: as.Equal(`# print this message help: just --list --list-submodules --unsorted +`, string(out)) + }), + ) + + // Try from a subdirectory, using .. to get to a parent directory that is + // still inside the project root. + // + // Note: we called `test.ChangeWorkDir` at the start of the test, which + // will restore the working directory during test cleanup. + t.Chdir("go") + os.Stdin = test.TempFile(t, "", "stdin", &contents) + treefmt(t, + withArgs("--stdin", "../foo/justfile"), + withNoError(t), + withStats(t, map[stats.Type]int{ + stats.Traversed: 1, + stats.Matched: 1, + stats.Formatted: 1, + stats.Changed: 1, + }), + withStdout(func(out []byte) { + as.Equal(`# print this message +help: + just --list --list-submodules --unsorted `, string(out)) }), ) diff --git a/walk/walk.go b/walk/walk.go index 00c6aa0a..11a953dc 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -276,21 +276,6 @@ func NewCompositeReader( readers := make([]Reader, len(paths)) - // check we have received 1 path for the stdin walk type - if walkType == Stdin { - if len(paths) != 1 { - return nil, errors.New("stdin walk requires exactly one path") - } - - path := paths[0] - - if strings.HasPrefix(path, "..") { - return nil, fmt.Errorf("path %s not inside the tree root %s", path, root) - } - - return NewStdinReader(root, path, statz), nil - } - // create a reader for each provided path for idx, path := range paths { var ( @@ -308,13 +293,30 @@ func NewCompositeReader( return nil, fmt.Errorf("error computing relative path from %s to %s: %w", root, resolvedPath, err) } - if strings.HasPrefix(relativePath, "..") { + isInsideTreeRoot, err := isDescendant(path, root) + if err != nil { + return nil, fmt.Errorf("error checking if %s is inside the tree root %s", path, root) + } + + if !isInsideTreeRoot { return nil, fmt.Errorf("path %s not inside the tree root %s (relative path: %s)", path, root, relativePath) } + if walkType == Stdin { + if len(paths) != 1 { + return nil, errors.New("stdin walk requires exactly one path") + } + + return NewStdinReader(root, relativePath, statz), nil + } + // check the path exists info, err = os.Lstat(resolvedPath) if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("path %s not found: %w", resolvedPath, err) + } + return nil, fmt.Errorf("failed to stat %s: %w", resolvedPath, err) } @@ -336,7 +338,8 @@ func NewCompositeReader( }, nil } -// Resolve a path to an absolute path, resolving any symlinks along the way. +// Resolve a path to an absolute path. +// Furthermore, if the path exists, any symlinks in its components are resolved. func resolvePath(path string) (string, error) { log.Debugf("Resolving path '%s'", path) @@ -347,8 +350,28 @@ func resolvePath(path string) (string, error) { resolvedPath, err := filepath.EvalSymlinks(absolutePath) if err != nil { - return "", fmt.Errorf("path %s not found: %w", absolutePath, err) + if os.IsNotExist(err) { + log.Debugf("Path '%s' does not exist, treating it as resolved", absolutePath) + + return absolutePath, nil + } + + return "", fmt.Errorf("error evaluating symlinks of %s: %w", absolutePath, err) } return resolvedPath, nil } + +func isDescendant(potentialChild string, potentialAncestor string) (bool, error) { + resolvedChild, err := resolvePath(potentialChild) + if err != nil { + return false, fmt.Errorf("error resolving %s: %w", potentialChild, err) + } + + relPath, err := filepath.Rel(potentialAncestor, resolvedChild) + if err != nil { + return false, fmt.Errorf("failed to compute relative path from %s to %s: %w", potentialAncestor, resolvedChild, err) + } + + return relPath != ".." && !strings.HasPrefix(relPath, ".."+string(os.PathSeparator)), nil +}