From a8336782f3d4941a7e0858e89a51cf06e55a59ab Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Wed, 15 Apr 2026 10:49:57 +0200 Subject: [PATCH 01/18] Document custom walker configuration Add documentation for selecting a custom walker with the global walk option and defining its command and options under [walker.]. Document that custom walker commands run from the tree root, emit one relative path per line, and do not need to handle path arguments. --- cmd/init/init.toml | 9 +++++- docs/site/getting-started/configure.md | 40 +++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/cmd/init/init.toml b/cmd/init/init.toml index 375686b3..d7079872 100644 --- a/cmd/init/init.toml +++ b/cmd/init/init.toml @@ -48,9 +48,16 @@ # verbose = 2 # The method used to traverse the files within the tree root -# Currently, we support 'auto', 'git', 'jujutsu', or 'filesystem' +# Built-in values are 'auto', 'git', 'jujutsu', and 'filesystem' +# You can also set this to the name of a configured custom walker # Env $TREEFMT_WALK # walk = "filesystem" +# walk = "myWalker" + +# Custom walkers are configured with [walker.] +# [walker.myWalker] +# command = "command-to-run" +# options = [] [formatter.mylanguage] # Command to execute diff --git a/docs/site/getting-started/configure.md b/docs/site/getting-started/configure.md index 77f55adb..f4e91fa3 100644 --- a/docs/site/getting-started/configure.md +++ b/docs/site/getting-started/configure.md @@ -396,7 +396,8 @@ Set the verbosity level of logs: ### `walk` The method used to traverse the files within the tree root. -Currently, we support 'auto', 'git', 'jujutsu' or 'filesystem' +Built-in values are `auto`, `git`, `jujutsu`, and `filesystem`. +You can also set this to the name of a configured [custom walker](#walker-options). === "Flag" @@ -416,6 +417,14 @@ Currently, we support 'auto', 'git', 'jujutsu' or 'filesystem' walk = "filesystem" ``` + ```toml + walk = "myWalker" + + [walker.myWalker] + command = "command-to-run" + options = [] + ``` + ### `working-dir` Run as if `treefmt` was started in the specified working directory instead of the current working directory. @@ -433,6 +442,35 @@ Run as if `treefmt` was started in the specified working directory instead of th TREEFMT_WORKING_DIR=/tmp/foo treefmt ``` +## Walker Options + +Custom walkers are configured using a [table](https://toml.io/en/v1.0.0#table) entry in `treefmt.toml` of the form +`[walker.]`. +To use a custom walker, set the global [`walk`](#walk) option to the same name: + +```toml +walk = "myWalker" + +[walker.myWalker] +command = "command-to-run" +options = [] +``` + +### `command` + +The command to invoke when walking the tree. +`treefmt` runs the command from the tree root. +The command must write one path per line to `stdout`. +Each path must be relative to the tree root. + +When you pass directory paths to `treefmt`, the walker command still runs for the tree root. +`treefmt` filters the command output to the requested directories. +The walker command doesn't need to implement path argument handling. + +### `options` + +An optional list of args to be passed to `command`. + ## Formatter Options Formatters are configured using a [table](https://toml.io/en/v1.0.0#table) entry in `treefmt.toml` of the form From d99461c3ae2398054018874339e05fcd28c691f3 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Wed, 15 Apr 2026 11:07:34 +0200 Subject: [PATCH 02/18] Add custom walker configuration Parse [walker.] tables into the configuration model and validate walker names, commands, and walk values. Accept documented camelCase walker names even when Viper normalizes table keys. --- config/config.go | 53 ++++++++++++++++++++ config/config_test.go | 111 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+) diff --git a/config/config.go b/config/config.go index 4c97de11..83aa09da 100644 --- a/config/config.go +++ b/config/config.go @@ -45,6 +45,7 @@ type Config struct { Stdin bool `mapstructure:"stdin" toml:"-"` // not allowed in config FormatterConfigs map[string]*Formatter `mapstructure:"formatter" toml:"formatter,omitempty"` + WalkerConfigs map[string]*Walker `mapstructure:"walker" toml:"walker,omitempty"` Global struct { // Deprecated: Use Excludes @@ -68,6 +69,13 @@ type Formatter struct { NoPositionalArgSupport *bool `mapstructure:"no-positional-arg-support" toml:"no-positional-arg-support"` } +type Walker struct { + // Command is the command to invoke when walking the tree. + Command string `mapstructure:"command" toml:"command"` + // Options are an optional list of args to be passed to Command. + Options []string `mapstructure:"options,omitempty" toml:"options,omitempty"` +} + // SetFlags appends our flags to the provided flag set. // We have a flag matching most entries in Config, taking care to ensure the name matches the field name defined in the // mapstructure tag. @@ -200,6 +208,10 @@ func FromViper(v *viper.Viper) (*Config, error) { return nil, fmt.Errorf("failed to unmarshal config: %w", err) } + if cfg.Walk == "" { + cfg.Walk = walk.Auto.String() + } + // resolve the working directory to an absolute path cfg.WorkingDirectory, err = filepath.Abs(cfg.WorkingDirectory) if err != nil { @@ -233,6 +245,47 @@ func FromViper(v *viper.Viper) (*Config, error) { } } + for name, walkerCfg := range cfg.WalkerConfigs { + if !nameRegex.MatchString(name) { + return nil, fmt.Errorf( + "walker name %q is invalid, must be of the form %s", + name, nameRegex.String(), + ) + } + + if _, err := walk.TypeString(name); err == nil { + return nil, fmt.Errorf("walker name %q is reserved for a built-in walk type", name) + } + + if walkerCfg.Command == "" { + return nil, fmt.Errorf("walker %v has no command", name) + } + } + + if _, err := walk.TypeString(cfg.Walk); err != nil { + if !nameRegex.MatchString(cfg.Walk) { + return nil, fmt.Errorf( + "walk value %q is invalid, must be a built-in walk type or a walker name of the form %s", + cfg.Walk, nameRegex.String(), + ) + } + + if _, ok := cfg.WalkerConfigs[cfg.Walk]; !ok { + for name, walkerCfg := range cfg.WalkerConfigs { + if strings.EqualFold(name, cfg.Walk) { + cfg.WalkerConfigs[cfg.Walk] = walkerCfg + ok = true + + break + } + } + + if !ok { + return nil, fmt.Errorf("walker %v not found in config", cfg.Walk) + } + } + } + // filter formatters based on provided names if len(cfg.Formatters) > 0 { filtered := make(map[string]*Formatter) diff --git a/config/config_test.go b/config/config_test.go index 218b87e7..f019e0f8 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -591,6 +591,117 @@ func TestWalk(t *testing.T) { checkValue("auto") } +func TestWalkers(t *testing.T) { + t.Run("configured walker", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "myWalker", + WalkerConfigs: map[string]*config.Walker{ + "myWalker": { + Command: "command-to-run", + Options: []string{ + "--foo", + "bar", + }, + }, + }, + } + + v, _ := newViper(t) + + readValue(t, v, cfg, func(cfg *config.Config) { + walker, ok := cfg.WalkerConfigs["myWalker"] + as.True(ok, "walker not found") + as.Equal("myWalker", cfg.Walk) + as.Equal("command-to-run", walker.Command) + as.Equal([]string{"--foo", "bar"}, walker.Options) + }) + }) + + t.Run("missing walker", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "myWalker", + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walker myWalker not found in config") + }) + }) + + t.Run("empty command", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "myWalker", + WalkerConfigs: map[string]*config.Walker{ + "myWalker": {}, + }, + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walker mywalker has no command") + }) + }) + + t.Run("invalid walker name", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "myWalker", + WalkerConfigs: map[string]*config.Walker{ + "my/walker": { + Command: "command-to-run", + }, + }, + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walker name \"my/walker\" is invalid") + }) + }) + + t.Run("reserved walker name", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + WalkerConfigs: map[string]*config.Walker{ + "git": { + Command: "command-to-run", + }, + }, + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walker name \"git\" is reserved for a built-in walk type") + }) + }) + + t.Run("invalid walk value", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "my.walker", + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walk value \"my.walker\" is invalid") + }) + }) +} + func TestWorkingDirectory(t *testing.T) { as := require.New(t) From 3ecacdc2955acd626c75cfd4b5e7dc01637dffdd Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Wed, 15 Apr 2026 11:10:58 +0200 Subject: [PATCH 03/18] Add custom walker reader Introduce a command-backed walker reader that runs from the tree root, reads one path per stdout line, filters configured subpaths inside treefmt, and converts accepted paths into walk.File values. Route custom walk names through the format command without changing built-in walkers. --- cmd/format/format.go | 26 +++++- walk/custom.go | 214 +++++++++++++++++++++++++++++++++++++++++++ walk/custom_test.go | 119 ++++++++++++++++++++++++ walk/walk.go | 79 ++++++++++------ 4 files changed, 408 insertions(+), 30 deletions(-) create mode 100644 walk/custom.go create mode 100644 walk/custom_test.go diff --git a/cmd/format/format.go b/cmd/format/format.go index 2efef587..1e9957ae 100644 --- a/cmd/format/format.go +++ b/cmd/format/format.go @@ -107,13 +107,13 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) cancel() }() - // parse the walk type - walkType, err := walk.TypeString(cfg.Walk) + // parse the walk selector + walkSelector, err := newWalkSelector(cfg) if err != nil { return fmt.Errorf("invalid walk type: %w", err) } - if walkType == walk.Stdin && len(paths) != 1 { + if walkSelector.Custom == nil && walkSelector.Type == walk.Stdin && len(paths) != 1 { // check we have only received one path arg which we use for the file extension / matching to formatters return errors.New("exactly one path should be specified when using the --stdin flag") } @@ -125,7 +125,7 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) } // create a new walker for traversing the paths - walker, err := walk.NewCompositeReader(walkType, cfg.TreeRoot, paths, db, statz) + walker, err := walk.NewCompositeReader(walkSelector, cfg.TreeRoot, paths, db, statz) if err != nil { return fmt.Errorf("failed to create walker: %w", err) } @@ -205,3 +205,21 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) return nil } + +func newWalkSelector(cfg *config.Config) (walk.Selector, error) { + walkType, err := walk.TypeString(cfg.Walk) + if err == nil { + return walk.BuiltinSelector(walkType), nil + } + + walkerCfg, ok := cfg.WalkerConfigs[cfg.Walk] + if !ok { + return walk.Selector{}, fmt.Errorf("walker %v not found in config", cfg.Walk) + } + + return walk.CustomSelector(walk.CustomConfig{ + Name: cfg.Walk, + Command: walkerCfg.Command, + Options: walkerCfg.Options, + }), nil +} diff --git a/walk/custom.go b/walk/custom.go new file mode 100644 index 00000000..a9898d5e --- /dev/null +++ b/walk/custom.go @@ -0,0 +1,214 @@ +package walk + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/charmbracelet/log" + "github.com/numtide/treefmt/v2/stats" + "golang.org/x/sync/errgroup" + "mvdan.cc/sh/v3/expand" + "mvdan.cc/sh/v3/interp" +) + +type CustomReader struct { + root string + path string + cfg CustomConfig + + log *log.Logger + stats *stats.Stats + + eg *errgroup.Group + scanner *bufio.Scanner +} + +func (c *CustomReader) Read(ctx context.Context, files []*File) (n int, err error) { + // ensure we record how many files we traversed + defer func() { + c.stats.Add(stats.Traversed, n) + }() + +LOOP: + for n < len(files) { + select { + // exit early if the context was cancelled + case <-ctx.Done(): + return n, ctx.Err() //nolint:wrapcheck + + default: + if !c.scanner.Scan() { + if err := c.scanner.Err(); err != nil { + return n, fmt.Errorf("failed to read custom walker output: %w", err) + } + + err = io.EOF + + break LOOP + } + + file, ok, err := c.file(c.scanner.Text()) + if err != nil { + return n, err + } else if !ok { + continue + } + + files[n] = file + n++ + } + } + + return n, err +} + +func (c *CustomReader) file(line string) (*File, bool, error) { + line = strings.TrimSuffix(line, "\r") + if line == "" { + return nil, false, nil + } + + relPath, err := c.relativePath(line) + if err != nil { + return nil, false, err + } + + if !containsPath(c.path, relPath) { + return nil, false, nil + } + + path := filepath.Join(c.root, relPath) + + c.log.Debugf("processing file: %s", path) + + info, err := os.Lstat(path) + switch { + case os.IsNotExist(err): + c.log.Warnf( + "Path %s was emitted by custom walker %s but appears to have been removed from the filesystem", + path, c.cfg.Name, + ) + + return nil, false, nil + case err != nil: + return nil, false, fmt.Errorf("failed to stat %s: %w", path, err) + case info.IsDir() || info.Mode()&os.ModeSymlink == os.ModeSymlink: + return nil, false, nil + } + + return &File{ + Path: path, + RelPath: relPath, + Info: info, + }, true, nil +} + +func (c *CustomReader) relativePath(entry string) (string, error) { + entry = filepath.Clean(entry) + if filepath.IsAbs(entry) { + relPath, err := filepath.Rel(c.root, entry) + if err != nil { + return "", fmt.Errorf("failed to determine a relative path for %s: %w", entry, err) + } + + entry = relPath + } + + if entry == ".." || strings.HasPrefix(entry, ".."+string(os.PathSeparator)) || filepath.IsAbs(entry) { + return "", fmt.Errorf("custom walker %s emitted path %s outside the tree root %s", c.cfg.Name, entry, c.root) + } + + return entry, nil +} + +func (c *CustomReader) Close() error { + err := c.eg.Wait() + if err != nil { + return fmt.Errorf("failed to wait for custom walker %s command to complete: %w", c.cfg.Name, err) + } + + return nil +} + +func containsPath(root string, path string) bool { + if root == "" || path == root { + return true + } + + relPath, err := filepath.Rel(root, path) + if err != nil { + return false + } + + return relPath != ".." && + !strings.HasPrefix(relPath, ".."+string(os.PathSeparator)) && + !filepath.IsAbs(relPath) +} + +func NewCustomReader( + root string, + path string, + statz *stats.Stats, + cfg CustomConfig, +) (*CustomReader, error) { + env := expand.ListEnviron(os.Environ()...) + + executable, err := interp.LookPathDir(root, env, cfg.Command) + if err != nil { + return nil, fmt.Errorf("failed to find custom walker %s command %q: %w", cfg.Name, cfg.Command, err) + } + + eg := &errgroup.Group{} + + cmd := exec.CommandContext(context.Background(), executable, cfg.Options...) //nolint:gosec + cmd.Dir = root + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe for custom walker %s: %w", cfg.Name, err) + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stderr pipe for custom walker %s: %w", cfg.Name, err) + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start custom walker %s: %w", cfg.Name, err) + } + + eg.Go(func() error { + return cmd.Wait() //nolint:wrapcheck + }) + + eg.Go(func() error { + l := log.WithPrefix("walk | custom | " + cfg.Name + " | stderr") + + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + l.Debugf("%s", scanner.Text()) + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to read custom walker %s stderr: %w", cfg.Name, err) + } + + return nil + }) + + return &CustomReader{ + root: root, + path: path, + cfg: cfg, + log: log.WithPrefix("walk | custom | " + cfg.Name), + stats: statz, + eg: eg, + scanner: bufio.NewScanner(stdout), + }, nil +} diff --git a/walk/custom_test.go b/walk/custom_test.go new file mode 100644 index 00000000..89d17099 --- /dev/null +++ b/walk/custom_test.go @@ -0,0 +1,119 @@ +package walk_test + +import ( + "context" + "io" + "os" + "path/filepath" + "testing" + "time" + + "github.com/numtide/treefmt/v2/stats" + "github.com/numtide/treefmt/v2/test" + "github.com/numtide/treefmt/v2/walk" + "github.com/stretchr/testify/require" +) + +func TestCustomReader(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + walkerPath := filepath.Join(tempDir, "walker") + as.NoError(os.WriteFile(walkerPath, []byte(`#!/usr/bin/env sh +for path in "$@"; do + printf '%s\n' "$path" +done +`), 0o700)) + + as.NoError(os.Symlink(filepath.Join(tempDir, "go", "main.go"), filepath.Join(tempDir, "link-to-main"))) + + statz := stats.New() + reader, err := walk.NewCustomReader(tempDir, "", &statz, walk.CustomConfig{ + Name: "myWalker", + Command: "./walker", + Options: []string{ + "go/main.go", + "go/go.mod", + "go", + "missing.txt", + "link-to-main", + }, + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := reader.Read(ctx, files) + cancel() + + as.Equal(2, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + as.Equal("go/go.mod", files[1].RelPath) + as.Equal(2, statz.Value(stats.Traversed)) + as.NoError(reader.Close()) +} + +func TestCustomReaderSubpath(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + walkerPath := filepath.Join(tempDir, "walker") + as.NoError(os.WriteFile(walkerPath, []byte(`#!/usr/bin/env sh +printf '%s\n' go/main.go haskell/Foo.hs go/go.mod +`), 0o700)) + + statz := stats.New() + reader, err := walk.NewCustomReader(tempDir, "go", &statz, walk.CustomConfig{ + Name: "myWalker", + Command: "./walker", + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := reader.Read(ctx, files) + cancel() + + as.Equal(2, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + as.Equal("go/go.mod", files[1].RelPath) + as.Equal(2, statz.Value(stats.Traversed)) + as.NoError(reader.Close()) +} + +func TestCustomReaderCommandFailure(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + walkerPath := filepath.Join(tempDir, "walker") + as.NoError(os.WriteFile(walkerPath, []byte(`#!/usr/bin/env sh +printf '%s\n' go/main.go +exit 7 +`), 0o700)) + + statz := stats.New() + reader, err := walk.NewCustomReader(tempDir, "", &statz, walk.CustomConfig{ + Name: "myWalker", + Command: "./walker", + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := reader.Read(ctx, files) + cancel() + + as.Equal(1, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + + err = reader.Close() + as.Error(err) + as.ErrorContains(err, "failed to wait for custom walker myWalker command to complete") + as.ErrorContains(err, "exit status 7") +} diff --git a/walk/walk.go b/walk/walk.go index 00c6aa0a..57812fcb 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -30,6 +30,29 @@ const ( const BatchSize = 1024 +type CustomConfig struct { + Name string + Command string + Options []string +} + +type Selector struct { + Type Type + Custom *CustomConfig +} + +func BuiltinSelector(walkType Type) Selector { + return Selector{ + Type: walkType, + } +} + +func CustomSelector(cfg CustomConfig) Selector { + return Selector{ + Custom: &cfg, + } +} + type ReleaseFunc func(ctx context.Context) error // File represents a file object with its path, relative path, file info, and potential cache entry. @@ -199,7 +222,7 @@ func (c *CompositeReader) Close() error { //nolint:ireturn func NewReader( - walkType Type, + selector Selector, root string, path string, db *bolt.DB, @@ -210,29 +233,33 @@ func NewReader( reader Reader ) - switch walkType { - case Auto: - // for now, we keep it simple and try git first, jujutsu second, and filesystem last - reader, err = NewReader(Git, root, path, db, statz) - if err != nil { - reader, err = NewReader(Jujutsu, root, path, db, statz) + if selector.Custom != nil { + reader, err = NewCustomReader(root, path, statz, *selector.Custom) + } else { + switch selector.Type { + case Auto: + // for now, we keep it simple and try git first, jujutsu second, and filesystem last + reader, err = NewReader(BuiltinSelector(Git), root, path, db, statz) if err != nil { - reader, err = NewReader(Filesystem, root, path, db, statz) + reader, err = NewReader(BuiltinSelector(Jujutsu), root, path, db, statz) + if err != nil { + reader, err = NewReader(BuiltinSelector(Filesystem), root, path, db, statz) + } } - } - return reader, err - case Stdin: - return nil, errors.New("stdin walk type is not supported") - case Filesystem: - reader = NewFilesystemReader(root, path, statz, BatchSize) - case Git: - reader, err = NewGitReader(root, path, statz) - case Jujutsu: - reader, err = NewJujutsuReader(root, path, statz) - - default: - return nil, fmt.Errorf("unknown walk type: %v", walkType) + return reader, err + case Stdin: + return nil, errors.New("stdin walk type is not supported") + case Filesystem: + reader = NewFilesystemReader(root, path, statz, BatchSize) + case Git: + reader, err = NewGitReader(root, path, statz) + case Jujutsu: + reader, err = NewJujutsuReader(root, path, statz) + + default: + return nil, fmt.Errorf("unknown walk type: %v", selector.Type) + } } if err != nil { @@ -253,7 +280,7 @@ func NewReader( // //nolint:ireturn func NewCompositeReader( - walkType Type, + selector Selector, root string, paths []string, db *bolt.DB, @@ -271,13 +298,13 @@ func NewCompositeReader( // if no paths are provided we default to processing the tree root if len(paths) == 0 { - return NewReader(walkType, root, "", db, statz) + return NewReader(selector, root, "", db, statz) } readers := make([]Reader, len(paths)) // check we have received 1 path for the stdin walk type - if walkType == Stdin { + if selector.Custom == nil && selector.Type == Stdin { if len(paths) != 1 { return nil, errors.New("stdin walk requires exactly one path") } @@ -320,10 +347,10 @@ func NewCompositeReader( if info.IsDir() { // for directories, we honour the walk type as we traverse them - readers[idx], err = NewReader(walkType, root, relativePath, db, statz) + readers[idx], err = NewReader(selector, root, relativePath, db, statz) } else { // for files, we enforce a simple filesystem read - readers[idx], err = NewReader(Filesystem, root, relativePath, db, statz) + readers[idx], err = NewReader(BuiltinSelector(Filesystem), root, relativePath, db, statz) } if err != nil { From cb50b1812b32967a83353217c3384758b8298574 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Wed, 15 Apr 2026 11:11:45 +0200 Subject: [PATCH 04/18] Test custom walker integration Cover selecting a custom walker from treefmt.toml, passing configured walker options, and filtering walker output when treefmt is invoked with a directory path argument. --- cmd/root_test.go | 61 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/cmd/root_test.go b/cmd/root_test.go index be520c62..1a58b7b2 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1814,6 +1814,67 @@ func TestJujutsu(t *testing.T) { ) } +func TestCustomWalker(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + configPath := filepath.Join(tempDir, "/treefmt.toml") + + test.ChangeWorkDir(t, tempDir) + + walkerPath := filepath.Join(tempDir, "custom-walker") + as.NoError(os.WriteFile(walkerPath, []byte(`#!/usr/bin/env sh +for path in "$@"; do + printf '%s\n' "$path" +done +`), 0o700)) + + cfg := &config.Config{ + Walk: "myWalker", + WalkerConfigs: map[string]*config.Walker{ + "myWalker": { + Command: "./custom-walker", + Options: []string{ + "go/main.go", + "go/go.mod", + "haskell/Foo.hs", + }, + }, + }, + FormatterConfigs: map[string]*config.Formatter{ + "echo": { + Command: "echo", // will not generate any underlying changes in the file + Includes: []string{"*"}, + }, + }, + } + + test.WriteConfig(t, configPath, cfg) + + treefmt(t, + withConfig(configPath, cfg), + withNoError(t), + withStats(t, map[stats.Type]int{ + stats.Traversed: 3, + stats.Matched: 3, + stats.Formatted: 3, + stats.Changed: 0, + }), + ) + + treefmt(t, + withArgs("-c", "go"), + withConfig(configPath, cfg), + withNoError(t), + withStats(t, map[stats.Type]int{ + stats.Traversed: 2, + stats.Matched: 2, + stats.Formatted: 2, + stats.Changed: 0, + }), + ) +} + func TestTreeRootCmd(t *testing.T) { as := require.New(t) From 232df26aeb5447244338ce7246ebb61106b69d13 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Wed, 15 Apr 2026 11:14:58 +0200 Subject: [PATCH 05/18] Make custom walker tests sandbox-compatible Use bash -c for custom walker test commands instead of temporary scripts with /usr/bin/env shebangs, so the tests also run inside the Nix build sandbox. --- cmd/root_test.go | 14 ++++---------- walk/custom_test.go | 35 ++++++++++++++--------------------- 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/cmd/root_test.go b/cmd/root_test.go index 1a58b7b2..ac572df7 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1815,26 +1815,20 @@ func TestJujutsu(t *testing.T) { } func TestCustomWalker(t *testing.T) { - as := require.New(t) - tempDir := test.TempExamples(t) configPath := filepath.Join(tempDir, "/treefmt.toml") test.ChangeWorkDir(t, tempDir) - walkerPath := filepath.Join(tempDir, "custom-walker") - as.NoError(os.WriteFile(walkerPath, []byte(`#!/usr/bin/env sh -for path in "$@"; do - printf '%s\n' "$path" -done -`), 0o700)) - cfg := &config.Config{ Walk: "myWalker", WalkerConfigs: map[string]*config.Walker{ "myWalker": { - Command: "./custom-walker", + Command: "bash", Options: []string{ + "-c", + `for path in "$@"; do printf '%s\n' "$path"; done`, + "custom-walker", "go/main.go", "go/go.mod", "haskell/Foo.hs", diff --git a/walk/custom_test.go b/walk/custom_test.go index 89d17099..bf988c40 100644 --- a/walk/custom_test.go +++ b/walk/custom_test.go @@ -19,20 +19,16 @@ func TestCustomReader(t *testing.T) { tempDir := test.TempExamples(t) - walkerPath := filepath.Join(tempDir, "walker") - as.NoError(os.WriteFile(walkerPath, []byte(`#!/usr/bin/env sh -for path in "$@"; do - printf '%s\n' "$path" -done -`), 0o700)) - as.NoError(os.Symlink(filepath.Join(tempDir, "go", "main.go"), filepath.Join(tempDir, "link-to-main"))) statz := stats.New() reader, err := walk.NewCustomReader(tempDir, "", &statz, walk.CustomConfig{ Name: "myWalker", - Command: "./walker", + Command: "bash", Options: []string{ + "-c", + `for path in "$@"; do printf '%s\n' "$path"; done`, + "walker", "go/main.go", "go/go.mod", "go", @@ -60,15 +56,14 @@ func TestCustomReaderSubpath(t *testing.T) { tempDir := test.TempExamples(t) - walkerPath := filepath.Join(tempDir, "walker") - as.NoError(os.WriteFile(walkerPath, []byte(`#!/usr/bin/env sh -printf '%s\n' go/main.go haskell/Foo.hs go/go.mod -`), 0o700)) - statz := stats.New() reader, err := walk.NewCustomReader(tempDir, "go", &statz, walk.CustomConfig{ Name: "myWalker", - Command: "./walker", + Command: "bash", + Options: []string{ + "-c", + "printf '%s\n' go/main.go haskell/Foo.hs go/go.mod", + }, }) as.NoError(err) @@ -90,16 +85,14 @@ func TestCustomReaderCommandFailure(t *testing.T) { tempDir := test.TempExamples(t) - walkerPath := filepath.Join(tempDir, "walker") - as.NoError(os.WriteFile(walkerPath, []byte(`#!/usr/bin/env sh -printf '%s\n' go/main.go -exit 7 -`), 0o700)) - statz := stats.New() reader, err := walk.NewCustomReader(tempDir, "", &statz, walk.CustomConfig{ Name: "myWalker", - Command: "./walker", + Command: "bash", + Options: []string{ + "-c", + "printf '%s\n' go/main.go; exit 7", + }, }) as.NoError(err) From 5f6451007fd4d1b510cacefd92d85536b8be1bf7 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Fri, 15 May 2026 08:55:27 +0200 Subject: [PATCH 06/18] Fix custom walker pipe handling Make the custom walker own its stdout and stderr pipes instead of combining StdoutPipe with a concurrent Wait call. This avoids surfacing a spurious file-closed read error at process exit and preserves command failures for Close(). --- walk/custom.go | 56 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/walk/custom.go b/walk/custom.go index a9898d5e..8351ce42 100644 --- a/walk/custom.go +++ b/walk/custom.go @@ -9,6 +9,7 @@ import ( "os/exec" "path/filepath" "strings" + "sync" "github.com/charmbracelet/log" "github.com/numtide/treefmt/v2/stats" @@ -27,6 +28,9 @@ type CustomReader struct { eg *errgroup.Group scanner *bufio.Scanner + + waitMu sync.Mutex + waitErr error } func (c *CustomReader) Read(ctx context.Context, files []*File) (n int, err error) { @@ -133,6 +137,13 @@ func (c *CustomReader) Close() error { return fmt.Errorf("failed to wait for custom walker %s command to complete: %w", c.cfg.Name, err) } + c.waitMu.Lock() + defer c.waitMu.Unlock() + + if c.waitErr != nil { + return fmt.Errorf("failed to wait for custom walker %s command to complete: %w", c.cfg.Name, c.waitErr) + } + return nil } @@ -169,22 +180,35 @@ func NewCustomReader( cmd := exec.CommandContext(context.Background(), executable, cfg.Options...) //nolint:gosec cmd.Dir = root - stdout, err := cmd.StdoutPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stdout pipe for custom walker %s: %w", cfg.Name, err) - } + stdout, stdoutW := io.Pipe() + stderr, stderrW := io.Pipe() - stderr, err := cmd.StderrPipe() - if err != nil { - return nil, fmt.Errorf("failed to create stderr pipe for custom walker %s: %w", cfg.Name, err) - } + cmd.Stdout = stdoutW + cmd.Stderr = stderrW - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start custom walker %s: %w", cfg.Name, err) + reader := &CustomReader{ + root: root, + path: path, + cfg: cfg, + log: log.WithPrefix("walk | custom | " + cfg.Name), + stats: statz, + eg: eg, + scanner: bufio.NewScanner(stdout), } eg.Go(func() error { - return cmd.Wait() //nolint:wrapcheck + err := cmd.Run() + + reader.waitMu.Lock() + reader.waitErr = err + reader.waitMu.Unlock() + + closeErr := stdoutW.Close() + if stderrCloseErr := stderrW.Close(); stderrCloseErr != nil && closeErr == nil { + closeErr = stderrCloseErr + } + + return closeErr }) eg.Go(func() error { @@ -202,13 +226,5 @@ func NewCustomReader( return nil }) - return &CustomReader{ - root: root, - path: path, - cfg: cfg, - log: log.WithPrefix("walk | custom | " + cfg.Name), - stats: statz, - eg: eg, - scanner: bufio.NewScanner(stdout), - }, nil + return reader, nil } From 36f66753b6d3208481bd5d14113c786f750c4d01 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Fri, 15 May 2026 08:58:44 +0200 Subject: [PATCH 07/18] Fix custom walker lint issues Refactor custom walker configuration lookup to satisfy nesting limits, remove an unused gosec suppression, and adjust test spacing to satisfy golangci-lint in the Nix check derivation. --- config/config.go | 30 ++++++++++++++++++------------ walk/custom.go | 2 +- walk/custom_test.go | 3 +++ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/config/config.go b/config/config.go index 83aa09da..34a29f04 100644 --- a/config/config.go +++ b/config/config.go @@ -270,20 +270,16 @@ func FromViper(v *viper.Viper) (*Config, error) { ) } - if _, ok := cfg.WalkerConfigs[cfg.Walk]; !ok { - for name, walkerCfg := range cfg.WalkerConfigs { - if strings.EqualFold(name, cfg.Walk) { - cfg.WalkerConfigs[cfg.Walk] = walkerCfg - ok = true - - break - } - } + walkerCfg, ok := cfg.WalkerConfigs[cfg.Walk] + if !ok { + walkerCfg, ok = findWalkerConfig(cfg.WalkerConfigs, cfg.Walk) + } - if !ok { - return nil, fmt.Errorf("walker %v not found in config", cfg.Walk) - } + if !ok { + return nil, fmt.Errorf("walker %v not found in config", cfg.Walk) } + + cfg.WalkerConfigs[cfg.Walk] = walkerCfg } // filter formatters based on provided names @@ -592,3 +588,13 @@ func fileExists(path string) bool { return fi.Mode().IsRegular() } + +func findWalkerConfig(walkers map[string]*Walker, name string) (*Walker, bool) { + for walkerName, walkerCfg := range walkers { + if strings.EqualFold(walkerName, name) { + return walkerCfg, true + } + } + + return nil, false +} diff --git a/walk/custom.go b/walk/custom.go index 8351ce42..13024607 100644 --- a/walk/custom.go +++ b/walk/custom.go @@ -177,7 +177,7 @@ func NewCustomReader( eg := &errgroup.Group{} - cmd := exec.CommandContext(context.Background(), executable, cfg.Options...) //nolint:gosec + cmd := exec.CommandContext(context.Background(), executable, cfg.Options...) cmd.Dir = root stdout, stdoutW := io.Pipe() diff --git a/walk/custom_test.go b/walk/custom_test.go index bf988c40..5d1f2183 100644 --- a/walk/custom_test.go +++ b/walk/custom_test.go @@ -41,6 +41,7 @@ func TestCustomReader(t *testing.T) { files := make([]*walk.File, 8) ctx, cancel := context.WithTimeout(t.Context(), time.Second) n, err := reader.Read(ctx, files) + cancel() as.Equal(2, n) @@ -70,6 +71,7 @@ func TestCustomReaderSubpath(t *testing.T) { files := make([]*walk.File, 8) ctx, cancel := context.WithTimeout(t.Context(), time.Second) n, err := reader.Read(ctx, files) + cancel() as.Equal(2, n) @@ -99,6 +101,7 @@ func TestCustomReaderCommandFailure(t *testing.T) { files := make([]*walk.File, 8) ctx, cancel := context.WithTimeout(t.Context(), time.Second) n, err := reader.Read(ctx, files) + cancel() as.Equal(1, n) From 4258265ebdbcfc70038967efa405e90eca228ad0 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 11:29:29 +0200 Subject: [PATCH 08/18] Document custom walker path filters --- cmd/init/init.toml | 1 + docs/site/getting-started/configure.md | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/cmd/init/init.toml b/cmd/init/init.toml index d7079872..f5019d9c 100644 --- a/cmd/init/init.toml +++ b/cmd/init/init.toml @@ -55,6 +55,7 @@ # walk = "myWalker" # Custom walkers are configured with [walker.] +# They receive requested file and directory paths as positional args # [walker.myWalker] # command = "command-to-run" # options = [] diff --git a/docs/site/getting-started/configure.md b/docs/site/getting-started/configure.md index f4e91fa3..cf3b5c43 100644 --- a/docs/site/getting-started/configure.md +++ b/docs/site/getting-started/configure.md @@ -460,16 +460,23 @@ options = [] The command to invoke when walking the tree. `treefmt` runs the command from the tree root. -The command must write one path per line to `stdout`. -Each path must be relative to the tree root. -When you pass directory paths to `treefmt`, the walker command still runs for the tree root. -`treefmt` filters the command output to the requested directories. -The walker command doesn't need to implement path argument handling. +When you pass file or directory paths to `treefmt`, `treefmt` passes those paths to the walker command as positional +arguments relative to the tree root. +A custom walker should use those arguments to reduce the paths it emits. +This improves performance when you format a subtree in a large repository. + +`treefmt` still filters the command output to the requested paths. + +The command must write path records to `stdout`. +Records may be separated by newlines or NUL bytes. +Each path must be relative to the tree root or an absolute path inside the tree root. +Paths containing newlines are unsupported. ### `options` An optional list of args to be passed to `command`. +`treefmt` passes these args before any positional path filters. ## Formatter Options From 7b717d2bb3fb586d83a7b0aadfaf6a8b552ccb5c Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 11:29:29 +0200 Subject: [PATCH 09/18] Use explicit walker selectors --- cmd/format/format.go | 2 +- walk/walk.go | 37 +++++++++++++++++++++++++++---------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/cmd/format/format.go b/cmd/format/format.go index 1e9957ae..213e95fe 100644 --- a/cmd/format/format.go +++ b/cmd/format/format.go @@ -113,7 +113,7 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) return fmt.Errorf("invalid walk type: %w", err) } - if walkSelector.Custom == nil && walkSelector.Type == walk.Stdin && len(paths) != 1 { + if walkSelector.Type == walk.Builtin && walkSelector.WalkType == walk.Stdin && len(paths) != 1 { // check we have only received one path arg which we use for the file extension / matching to formatters return errors.New("exactly one path should be specified when using the --stdin flag") } diff --git a/walk/walk.go b/walk/walk.go index 57812fcb..1bcc461c 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -36,20 +36,28 @@ type CustomConfig struct { Options []string } +const ( + Builtin Type = iota + 1000 + Custom +) + type Selector struct { - Type Type - Custom *CustomConfig + Type Type + WalkType Type + Custom CustomConfig } func BuiltinSelector(walkType Type) Selector { return Selector{ - Type: walkType, + Type: Builtin, + WalkType: walkType, } } func CustomSelector(cfg CustomConfig) Selector { return Selector{ - Custom: &cfg, + Type: Custom, + Custom: cfg, } } @@ -233,10 +241,11 @@ func NewReader( reader Reader ) - if selector.Custom != nil { - reader, err = NewCustomReader(root, path, statz, *selector.Custom) - } else { - switch selector.Type { + switch selector.Type { + case Custom: + reader, err = NewCustomReader(root, path, statz, selector.Custom) + case Builtin: + switch selector.WalkType { case Auto: // for now, we keep it simple and try git first, jujutsu second, and filesystem last reader, err = NewReader(BuiltinSelector(Git), root, path, db, statz) @@ -257,9 +266,17 @@ func NewReader( case Jujutsu: reader, err = NewJujutsuReader(root, path, statz) + case Builtin, Custom: + return nil, fmt.Errorf("invalid built-in walk type: %v", selector.WalkType) + default: - return nil, fmt.Errorf("unknown walk type: %v", selector.Type) + return nil, fmt.Errorf("unknown walk type: %v", selector.WalkType) } + case Auto, Stdin, Filesystem, Git, Jujutsu: + return nil, fmt.Errorf("invalid selector type: %v", selector.Type) + + default: + return nil, fmt.Errorf("unknown selector type: %v", selector.Type) } if err != nil { @@ -304,7 +321,7 @@ func NewCompositeReader( readers := make([]Reader, len(paths)) // check we have received 1 path for the stdin walk type - if selector.Custom == nil && selector.Type == Stdin { + if selector.Type == Builtin && selector.WalkType == Stdin { if len(paths) != 1 { return nil, errors.New("stdin walk requires exactly one path") } From f3b518fcfb24d9c510c2ba0512aa252790d834e3 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 11:30:14 +0200 Subject: [PATCH 10/18] Add shared path stream walker --- walk/custom.go | 227 ++---------------------------------- walk/custom_test.go | 67 ++++++++++- walk/path_stream.go | 274 ++++++++++++++++++++++++++++++++++++++++++++ walk/walk.go | 7 +- 4 files changed, 352 insertions(+), 223 deletions(-) create mode 100644 walk/path_stream.go diff --git a/walk/custom.go b/walk/custom.go index 13024607..c20a1fff 100644 --- a/walk/custom.go +++ b/walk/custom.go @@ -1,230 +1,19 @@ package walk -import ( - "bufio" - "context" - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "strings" - "sync" +import "github.com/numtide/treefmt/v2/stats" - "github.com/charmbracelet/log" - "github.com/numtide/treefmt/v2/stats" - "golang.org/x/sync/errgroup" - "mvdan.cc/sh/v3/expand" - "mvdan.cc/sh/v3/interp" -) - -type CustomReader struct { - root string - path string - cfg CustomConfig - - log *log.Logger - stats *stats.Stats - - eg *errgroup.Group - scanner *bufio.Scanner - - waitMu sync.Mutex - waitErr error -} - -func (c *CustomReader) Read(ctx context.Context, files []*File) (n int, err error) { - // ensure we record how many files we traversed - defer func() { - c.stats.Add(stats.Traversed, n) - }() - -LOOP: - for n < len(files) { - select { - // exit early if the context was cancelled - case <-ctx.Done(): - return n, ctx.Err() //nolint:wrapcheck - - default: - if !c.scanner.Scan() { - if err := c.scanner.Err(); err != nil { - return n, fmt.Errorf("failed to read custom walker output: %w", err) - } - - err = io.EOF - - break LOOP - } - - file, ok, err := c.file(c.scanner.Text()) - if err != nil { - return n, err - } else if !ok { - continue - } - - files[n] = file - n++ - } - } - - return n, err -} - -func (c *CustomReader) file(line string) (*File, bool, error) { - line = strings.TrimSuffix(line, "\r") - if line == "" { - return nil, false, nil - } - - relPath, err := c.relativePath(line) - if err != nil { - return nil, false, err - } - - if !containsPath(c.path, relPath) { - return nil, false, nil - } - - path := filepath.Join(c.root, relPath) - - c.log.Debugf("processing file: %s", path) - - info, err := os.Lstat(path) - switch { - case os.IsNotExist(err): - c.log.Warnf( - "Path %s was emitted by custom walker %s but appears to have been removed from the filesystem", - path, c.cfg.Name, - ) - - return nil, false, nil - case err != nil: - return nil, false, fmt.Errorf("failed to stat %s: %w", path, err) - case info.IsDir() || info.Mode()&os.ModeSymlink == os.ModeSymlink: - return nil, false, nil - } - - return &File{ - Path: path, - RelPath: relPath, - Info: info, - }, true, nil -} - -func (c *CustomReader) relativePath(entry string) (string, error) { - entry = filepath.Clean(entry) - if filepath.IsAbs(entry) { - relPath, err := filepath.Rel(c.root, entry) - if err != nil { - return "", fmt.Errorf("failed to determine a relative path for %s: %w", entry, err) - } - - entry = relPath - } - - if entry == ".." || strings.HasPrefix(entry, ".."+string(os.PathSeparator)) || filepath.IsAbs(entry) { - return "", fmt.Errorf("custom walker %s emitted path %s outside the tree root %s", c.cfg.Name, entry, c.root) - } - - return entry, nil -} - -func (c *CustomReader) Close() error { - err := c.eg.Wait() - if err != nil { - return fmt.Errorf("failed to wait for custom walker %s command to complete: %w", c.cfg.Name, err) - } - - c.waitMu.Lock() - defer c.waitMu.Unlock() - - if c.waitErr != nil { - return fmt.Errorf("failed to wait for custom walker %s command to complete: %w", c.cfg.Name, c.waitErr) - } - - return nil -} - -func containsPath(root string, path string) bool { - if root == "" || path == root { - return true - } - - relPath, err := filepath.Rel(root, path) - if err != nil { - return false - } - - return relPath != ".." && - !strings.HasPrefix(relPath, ".."+string(os.PathSeparator)) && - !filepath.IsAbs(relPath) -} +type CustomReader = PathStreamReader func NewCustomReader( root string, - path string, + pathFilters []string, statz *stats.Stats, cfg CustomConfig, ) (*CustomReader, error) { - env := expand.ListEnviron(os.Environ()...) - - executable, err := interp.LookPathDir(root, env, cfg.Command) - if err != nil { - return nil, fmt.Errorf("failed to find custom walker %s command %q: %w", cfg.Name, cfg.Command, err) - } - - eg := &errgroup.Group{} - - cmd := exec.CommandContext(context.Background(), executable, cfg.Options...) - cmd.Dir = root - - stdout, stdoutW := io.Pipe() - stderr, stderrW := io.Pipe() - - cmd.Stdout = stdoutW - cmd.Stderr = stderrW - - reader := &CustomReader{ - root: root, - path: path, - cfg: cfg, - log: log.WithPrefix("walk | custom | " + cfg.Name), - stats: statz, - eg: eg, - scanner: bufio.NewScanner(stdout), - } - - eg.Go(func() error { - err := cmd.Run() - - reader.waitMu.Lock() - reader.waitErr = err - reader.waitMu.Unlock() - - closeErr := stdoutW.Close() - if stderrCloseErr := stderrW.Close(); stderrCloseErr != nil && closeErr == nil { - closeErr = stderrCloseErr - } - - return closeErr + return NewPathStreamReader(root, statz, PathStreamConfig{ + Name: "custom walker " + cfg.Name, + Command: cfg.Command, + Options: cfg.Options, + PathFilters: pathFilters, }) - - eg.Go(func() error { - l := log.WithPrefix("walk | custom | " + cfg.Name + " | stderr") - - scanner := bufio.NewScanner(stderr) - for scanner.Scan() { - l.Debugf("%s", scanner.Text()) - } - - if err := scanner.Err(); err != nil { - return fmt.Errorf("failed to read custom walker %s stderr: %w", cfg.Name, err) - } - - return nil - }) - - return reader, nil } diff --git a/walk/custom_test.go b/walk/custom_test.go index 5d1f2183..40ff6842 100644 --- a/walk/custom_test.go +++ b/walk/custom_test.go @@ -22,7 +22,7 @@ func TestCustomReader(t *testing.T) { as.NoError(os.Symlink(filepath.Join(tempDir, "go", "main.go"), filepath.Join(tempDir, "link-to-main"))) statz := stats.New() - reader, err := walk.NewCustomReader(tempDir, "", &statz, walk.CustomConfig{ + reader, err := walk.NewCustomReader(tempDir, nil, &statz, walk.CustomConfig{ Name: "myWalker", Command: "bash", Options: []string{ @@ -58,7 +58,7 @@ func TestCustomReaderSubpath(t *testing.T) { tempDir := test.TempExamples(t) statz := stats.New() - reader, err := walk.NewCustomReader(tempDir, "go", &statz, walk.CustomConfig{ + reader, err := walk.NewCustomReader(tempDir, []string{"go"}, &statz, walk.CustomConfig{ Name: "myWalker", Command: "bash", Options: []string{ @@ -82,13 +82,74 @@ func TestCustomReaderSubpath(t *testing.T) { as.NoError(reader.Close()) } +func TestCustomReaderPassesPathFilters(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + statz := stats.New() + reader, err := walk.NewCustomReader(tempDir, []string{"go"}, &statz, walk.CustomConfig{ + Name: "myWalker", + Command: "bash", + Options: []string{ + "-c", + `for path in "$@"; do printf '%s\n' "$path/main.go" "$path/go.mod"; done`, + "walker", + }, + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := reader.Read(ctx, files) + + cancel() + + as.Equal(2, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + as.Equal("go/go.mod", files[1].RelPath) + as.Equal(2, statz.Value(stats.Traversed)) + as.NoError(reader.Close()) +} + +func TestPathStreamReaderDelimiters(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + statz := stats.New() + reader, err := walk.NewPathStreamReader(tempDir, &statz, walk.PathStreamConfig{ + Name: "test walker", + Command: "bash", + Options: []string{ + "-c", + "printf 'go/main.go\\0\\nhaskell/Foo.hs\\r\\n'", + }, + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := reader.Read(ctx, files) + + cancel() + + as.Equal(2, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + as.Equal("haskell/Foo.hs", files[1].RelPath) + as.Equal(2, statz.Value(stats.Traversed)) + as.NoError(reader.Close()) +} + func TestCustomReaderCommandFailure(t *testing.T) { as := require.New(t) tempDir := test.TempExamples(t) statz := stats.New() - reader, err := walk.NewCustomReader(tempDir, "", &statz, walk.CustomConfig{ + reader, err := walk.NewCustomReader(tempDir, nil, &statz, walk.CustomConfig{ Name: "myWalker", Command: "bash", Options: []string{ diff --git a/walk/path_stream.go b/walk/path_stream.go new file mode 100644 index 00000000..148b0234 --- /dev/null +++ b/walk/path_stream.go @@ -0,0 +1,274 @@ +package walk + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + + "github.com/charmbracelet/log" + "github.com/numtide/treefmt/v2/stats" + "golang.org/x/sync/errgroup" + "mvdan.cc/sh/v3/expand" + "mvdan.cc/sh/v3/interp" +) + +type PathStreamConfig struct { + Name string + Command string + Options []string + PathFilters []string +} + +type PathStreamReader struct { + root string + filters []string + cfg PathStreamConfig + + log *log.Logger + stats *stats.Stats + + eg *errgroup.Group + scanner *bufio.Scanner + + waitMu sync.Mutex + waitErr error +} + +func (p *PathStreamReader) Read(ctx context.Context, files []*File) (n int, err error) { + // ensure we record how many files we traversed + defer func() { + p.stats.Add(stats.Traversed, n) + }() + +LOOP: + for n < len(files) { + select { + // exit early if the context was cancelled + case <-ctx.Done(): + return n, ctx.Err() //nolint:wrapcheck + + default: + if !p.scanner.Scan() { + if err := p.scanner.Err(); err != nil { + return n, fmt.Errorf("failed to read %s output: %w", p.cfg.Name, err) + } + + err = io.EOF + + break LOOP + } + + file, ok, err := p.file(p.scanner.Text()) + if err != nil { + return n, err + } else if !ok { + continue + } + + files[n] = file + n++ + } + } + + return n, err +} + +func (p *PathStreamReader) file(record string) (*File, bool, error) { + record = strings.TrimSuffix(record, "\r") + if record == "" { + return nil, false, nil + } + + relPath, err := p.relativePath(record) + if err != nil { + return nil, false, err + } + + if !containsAnyPath(p.filters, relPath) { + return nil, false, nil + } + + path := filepath.Join(p.root, relPath) + + p.log.Debugf("processing file: %s", path) + + info, err := os.Lstat(path) + switch { + case os.IsNotExist(err): + p.log.Warnf( + "Path %s was emitted by %s but appears to have been removed from the filesystem", + path, p.cfg.Name, + ) + + return nil, false, nil + case err != nil: + return nil, false, fmt.Errorf("failed to stat %s: %w", path, err) + case info.IsDir() || info.Mode()&os.ModeSymlink == os.ModeSymlink: + return nil, false, nil + } + + return &File{ + Path: path, + RelPath: relPath, + Info: info, + }, true, nil +} + +func (p *PathStreamReader) relativePath(entry string) (string, error) { + entry = filepath.Clean(entry) + if filepath.IsAbs(entry) { + relPath, err := filepath.Rel(p.root, entry) + if err != nil { + return "", fmt.Errorf("failed to determine a relative path for %s: %w", entry, err) + } + + entry = relPath + } + + if entry == ".." || strings.HasPrefix(entry, ".."+string(os.PathSeparator)) || filepath.IsAbs(entry) { + return "", fmt.Errorf("%s emitted path %s outside the tree root %s", p.cfg.Name, entry, p.root) + } + + return entry, nil +} + +func (p *PathStreamReader) Close() error { + err := p.eg.Wait() + if err != nil { + return fmt.Errorf("failed to wait for %s command to complete: %w", p.cfg.Name, err) + } + + p.waitMu.Lock() + defer p.waitMu.Unlock() + + if p.waitErr != nil { + return fmt.Errorf("failed to wait for %s command to complete: %w", p.cfg.Name, p.waitErr) + } + + return nil +} + +func splitPathRecord(data []byte, atEOF bool) (advance int, token []byte, err error) { + for i, b := range data { + if b == '\n' || b == 0 { + return i + 1, data[:i], nil + } + } + + if atEOF && len(data) > 0 { + return len(data), data, nil + } + + return 0, nil, nil +} + +func containsAnyPath(filters []string, path string) bool { + if len(filters) == 0 { + return true + } + + for _, filter := range filters { + if containsPath(filter, path) { + return true + } + } + + return false +} + +func containsPath(root string, path string) bool { + if root == "" || root == "." || path == root { + return true + } + + relPath, err := filepath.Rel(root, path) + if err != nil { + return false + } + + return relPath != ".." && + !strings.HasPrefix(relPath, ".."+string(os.PathSeparator)) && + !filepath.IsAbs(relPath) +} + +func NewPathStreamReader( + root string, + statz *stats.Stats, + cfg PathStreamConfig, +) (*PathStreamReader, error) { + env := expand.ListEnviron(os.Environ()...) + + executable, err := interp.LookPathDir(root, env, cfg.Command) + if err != nil { + return nil, fmt.Errorf("failed to find %s command %q: %w", cfg.Name, cfg.Command, err) + } + + eg := &errgroup.Group{} + + args := append([]string{}, cfg.Options...) + args = append(args, cfg.PathFilters...) + + cmd := exec.CommandContext(context.Background(), executable, args...) + cmd.Dir = root + + // Don't use cmd.StdoutPipe here. Its docs say it is incorrect to call + // Cmd.Wait before all reads from the pipe have completed, and we wait for + // the command in a producer goroutine while Read drains stdout. + // See https://pkg.go.dev/os/exec#Cmd.StdoutPipe. + stdout, stdoutW := io.Pipe() + stderr, stderrW := io.Pipe() + + cmd.Stdout = stdoutW + cmd.Stderr = stderrW + + scanner := bufio.NewScanner(stdout) + scanner.Split(splitPathRecord) + + reader := &PathStreamReader{ + root: root, + filters: cfg.PathFilters, + cfg: cfg, + log: log.WithPrefix("walk | " + cfg.Name), + stats: statz, + eg: eg, + scanner: scanner, + } + + eg.Go(func() error { + err := cmd.Run() + + reader.waitMu.Lock() + reader.waitErr = err + reader.waitMu.Unlock() + + closeErr := stdoutW.Close() + if stderrCloseErr := stderrW.Close(); stderrCloseErr != nil && closeErr == nil { + closeErr = stderrCloseErr + } + + return closeErr + }) + + eg.Go(func() error { + l := log.WithPrefix("walk | " + cfg.Name + " | stderr") + + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + l.Debugf("%s", scanner.Text()) + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to read %s stderr: %w", cfg.Name, err) + } + + return nil + }) + + return reader, nil +} diff --git a/walk/walk.go b/walk/walk.go index 1bcc461c..1950e972 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -243,7 +243,12 @@ func NewReader( switch selector.Type { case Custom: - reader, err = NewCustomReader(root, path, statz, selector.Custom) + var pathFilters []string + if path != "" { + pathFilters = []string{path} + } + + reader, err = NewCustomReader(root, pathFilters, statz, selector.Custom) case Builtin: switch selector.WalkType { case Auto: From 6fc73bac25771e3ebd550135386be75c73e18ee1 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 11:31:25 +0200 Subject: [PATCH 11/18] Use path streams for built-in walkers --- cmd/root_test.go | 12 ++-- walk/git.go | 164 ++----------------------------------------- walk/git_test.go | 4 +- walk/jujutsu.go | 135 ++--------------------------------- walk/jujutsu_test.go | 4 +- walk/walk.go | 65 +++++++---------- 6 files changed, 53 insertions(+), 331 deletions(-) diff --git a/cmd/root_test.go b/cmd/root_test.go index ac572df7..79cf6970 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1827,11 +1827,15 @@ func TestCustomWalker(t *testing.T) { Command: "bash", Options: []string{ "-c", - `for path in "$@"; do printf '%s\n' "$path"; done`, + `if [ "$#" -eq 0 ]; then set -- go haskell; fi +for path in "$@"; do + case "$path" in + go) printf '%s\n' go/main.go go/go.mod ;; + haskell) printf '%s\n' haskell/Foo.hs ;; + *) printf '%s\n' "$path" ;; + esac +done`, "custom-walker", - "go/main.go", - "go/go.mod", - "haskell/Foo.hs", }, }, }, diff --git a/walk/git.go b/walk/git.go index 607ba524..d39fcc5f 100644 --- a/walk/git.go +++ b/walk/git.go @@ -1,143 +1,17 @@ package walk import ( - "bufio" - "context" "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "strconv" - "strings" - "github.com/charmbracelet/log" "github.com/numtide/treefmt/v2/git" "github.com/numtide/treefmt/v2/stats" - "golang.org/x/sync/errgroup" ) -type GitReader struct { - root string - path string - - log *log.Logger - stats *stats.Stats - - eg *errgroup.Group - scanner *bufio.Scanner -} - -func (g *GitReader) Read(ctx context.Context, files []*File) (n int, err error) { - // ensure we record how many files we traversed - defer func() { - g.stats.Add(stats.Traversed, n) - }() - - nextFile := func() (string, error) { - for line := g.scanner.Text(); len(line) > 0; line = g.scanner.Text() { - lineSplit := strings.Split(line, "\t") - - var stage, file string - // Untracked files just show as ``, while tracked files show as ` ` - if len(lineSplit) == 1 { - stage, file = "", lineSplit[0] - } else { - stage, file = lineSplit[0], lineSplit[1] - } - - // 160000 is the mode for submodules, skip them because they are separate projects that may have their own - // formatting rules - if strings.HasPrefix(stage, "160000") { - g.scanner.Scan() - - continue - } - - if file[0] != '"' { - return file, nil - } - - unquoted, err := strconv.Unquote(file) - if err != nil { - return "", fmt.Errorf("failed to unquote file %s: %w", file, err) - } - - return unquoted, nil - } - - return "", io.EOF - } - -LOOP: - - for n < len(files) { - select { - // exit early if the context was cancelled - case <-ctx.Done(): - return n, ctx.Err() //nolint:wrapcheck - - default: - // read the next file - if g.scanner.Scan() { - entry, err := nextFile() - if err != nil { - return n, err - } - - path := filepath.Join(g.root, g.path, entry) - - g.log.Debugf("processing file: %s", path) - - info, err := os.Lstat(path) - - switch { - case os.IsNotExist(err): - // the underlying file might have been removed - g.log.Warnf( - "Path %s is in the worktree but appears to have been removed from the filesystem", path, - ) - - continue - case err != nil: - return n, fmt.Errorf("failed to stat %s: %w", path, err) - case info.Mode()&os.ModeSymlink == os.ModeSymlink: - // we skip reporting symlinks stored in Git, they should - // point to local files which we would list anyway. - continue - } - - files[n] = &File{ - Path: path, - RelPath: filepath.Join(g.path, entry), - Info: info, - } - - n++ - } else { - // nothing more to read - err = io.EOF - - break LOOP - } - } - } - - return n, err -} - -func (g *GitReader) Close() error { - err := g.eg.Wait() - if err != nil { - return fmt.Errorf("failed to wait for git command to complete: %w", err) - } - - return nil -} +type GitReader = PathStreamReader func NewGitReader( root string, - path string, + pathFilters []string, statz *stats.Stats, ) (*GitReader, error) { // check if the root is a git repository @@ -150,34 +24,10 @@ func NewGitReader( return nil, fmt.Errorf("%s is not a git repository", root) } - // create an errgroup for executing in the background - eg := &errgroup.Group{} - - // create a pipe to capture the command output - r, w := io.Pipe() - - // create a command which will execute from the specified sub path within root - cmd := exec.CommandContext( - context.Background(), - "git", "ls-files", "--cached", "--others", "--exclude-standard", "--stage", - ) - cmd.Dir = filepath.Join(root, path) - cmd.Stdout = w - - // execute the command in the background - eg.Go(func() error { - return w.CloseWithError(cmd.Run()) + return NewPathStreamReader(root, statz, PathStreamConfig{ + Name: "git", + Command: "git", + Options: []string{"ls-files", "-z", "--cached", "--others", "--exclude-standard", "--full-name", "--"}, + PathFilters: pathFilters, }) - - // create a new scanner for reading the output - scanner := bufio.NewScanner(r) - - return &GitReader{ - eg: eg, - root: root, - path: path, - stats: statz, - scanner: scanner, - log: log.WithPrefix("walk | git"), - }, nil } diff --git a/walk/git_test.go b/walk/git_test.go index ce4e3d29..00b4bbe0 100644 --- a/walk/git_test.go +++ b/walk/git_test.go @@ -35,7 +35,7 @@ func TestGitReader(t *testing.T) { // read empty worktree statz := stats.New() - reader, err := walk.NewGitReader(tempDir, "", &statz) + reader, err := walk.NewGitReader(tempDir, nil, &statz) as.NoError(err) files := make([]*walk.File, 34) @@ -84,7 +84,7 @@ func TestGitReader(t *testing.T) { as.NoError(cmd.Run(), "failed to add everything to the index") statz = stats.New() - reader, err = walk.NewGitReader(tempDir, "", &statz) + reader, err = walk.NewGitReader(tempDir, nil, &statz) as.NoError(err) count := 0 diff --git a/walk/jujutsu.go b/walk/jujutsu.go index 4ee7a195..b2566be1 100644 --- a/walk/jujutsu.go +++ b/walk/jujutsu.go @@ -1,109 +1,17 @@ package walk import ( - "bufio" - "context" "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "github.com/charmbracelet/log" "github.com/numtide/treefmt/v2/jujutsu" "github.com/numtide/treefmt/v2/stats" - "golang.org/x/sync/errgroup" ) -type JujutsuReader struct { - root string - path string - - log *log.Logger - stats *stats.Stats - - eg *errgroup.Group - scanner *bufio.Scanner -} - -func (j *JujutsuReader) Read(ctx context.Context, files []*File) (n int, err error) { - // ensure we record how many files we traversed - defer func() { - j.stats.Add(stats.Traversed, n) - }() - - nextLine := func() string { - line := j.scanner.Text() - - return line - } - -LOOP: - - for n < len(files) { - select { - // exit early if the context was cancelled - case <-ctx.Done(): - return n, ctx.Err() //nolint:wrapcheck - - default: - // read the next file - if j.scanner.Scan() { - entry := nextLine() - - path := filepath.Join(j.root, entry) - - j.log.Debugf("processing file: %s", path) - - info, err := os.Lstat(path) - - switch { - case os.IsNotExist(err): - // the underlying file might have been removed - j.log.Warnf( - "Path %s is in the worktree but appears to have been removed from the filesystem", path, - ) - - continue - case err != nil: - return n, fmt.Errorf("failed to stat %s: %w", path, err) - case info.Mode()&os.ModeSymlink == os.ModeSymlink: - // we skip reporting symlinks stored in Jujutsu, they should - // point to local files which we would list anyway. - continue - } - - files[n] = &File{ - Path: path, - RelPath: entry, - Info: info, - } - - n++ - } else { - // nothing more to read - err = io.EOF - - break LOOP - } - } - } - - return n, err -} - -func (j *JujutsuReader) Close() error { - err := j.eg.Wait() - if err != nil { - return fmt.Errorf("failed to wait for jujutsu command to complete: %w", err) - } - - return nil -} +type JujutsuReader = PathStreamReader func NewJujutsuReader( root string, - path string, + pathFilters []string, statz *stats.Stats, ) (*JujutsuReader, error) { // check if the root is a jujutsu repository @@ -116,42 +24,13 @@ func NewJujutsuReader( return nil, fmt.Errorf("%s is not a jujutsu repository", root) } - // create an errgroup for async list task - eg := &errgroup.Group{} - - // create a pipe to capture the command output - r, w := io.Pipe() - - // create a command which will execute from root // --ignore-working-copy: Don't snapshot the working copy, and don't update it. This prevents that the user has to // enter a password for signing the commit. New files also won't be added to the index and not displayed in the // output. - // Add the subpath as a fileset displaying only files matching this prefix. If - // the subpath is empty ignore it since it interferes with the command - args := []string{"file", "list", "--ignore-working-copy"} - if path != "" { - args = append(args, path) - } - - // create the jj command - cmd := exec.CommandContext(context.Background(), "jj", args...) - cmd.Dir = root - cmd.Stdout = w - - // execute the command in the background - eg.Go(func() error { - return w.CloseWithError(cmd.Run()) + return NewPathStreamReader(root, statz, PathStreamConfig{ + Name: "jujutsu", + Command: "jj", + Options: []string{"file", "list", "--ignore-working-copy", "--"}, + PathFilters: pathFilters, }) - - // create a new scanner for reading the output - scanner := bufio.NewScanner(r) - - return &JujutsuReader{ - eg: eg, - root: root, - path: path, - stats: statz, - scanner: scanner, - log: log.WithPrefix("walk | jujutsu"), - }, nil } diff --git a/walk/jujutsu_test.go b/walk/jujutsu_test.go index c50b6f11..2d829f03 100644 --- a/walk/jujutsu_test.go +++ b/walk/jujutsu_test.go @@ -27,7 +27,7 @@ func TestJujutsuReader(t *testing.T) { // read empty worktree statz := stats.New() - reader, err := walk.NewJujutsuReader(tempDir, "", &statz) + reader, err := walk.NewJujutsuReader(tempDir, nil, &statz) as.NoError(err) files := make([]*walk.File, 33) // The number of files in `test/examples` used for testing @@ -47,7 +47,7 @@ func TestJujutsuReader(t *testing.T) { as.NoError(cmd.Run(), "failed to update the index") statz = stats.New() - reader, err = walk.NewJujutsuReader(tempDir, "", &statz) + reader, err = walk.NewJujutsuReader(tempDir, nil, &statz) as.NoError(err) count := 0 diff --git a/walk/walk.go b/walk/walk.go index 1950e972..4e66dc77 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -232,7 +232,7 @@ func (c *CompositeReader) Close() error { func NewReader( selector Selector, root string, - path string, + pathFilters []string, db *bolt.DB, statz *stats.Stats, ) (Reader, error) { @@ -243,21 +243,16 @@ func NewReader( switch selector.Type { case Custom: - var pathFilters []string - if path != "" { - pathFilters = []string{path} - } - reader, err = NewCustomReader(root, pathFilters, statz, selector.Custom) case Builtin: switch selector.WalkType { case Auto: // for now, we keep it simple and try git first, jujutsu second, and filesystem last - reader, err = NewReader(BuiltinSelector(Git), root, path, db, statz) + reader, err = NewReader(BuiltinSelector(Git), root, pathFilters, db, statz) if err != nil { - reader, err = NewReader(BuiltinSelector(Jujutsu), root, path, db, statz) + reader, err = NewReader(BuiltinSelector(Jujutsu), root, pathFilters, db, statz) if err != nil { - reader, err = NewReader(BuiltinSelector(Filesystem), root, path, db, statz) + reader, err = NewReader(BuiltinSelector(Filesystem), root, pathFilters, db, statz) } } @@ -265,11 +260,25 @@ func NewReader( case Stdin: return nil, errors.New("stdin walk type is not supported") case Filesystem: - reader = NewFilesystemReader(root, path, statz, BatchSize) + switch len(pathFilters) { + case 0: + reader = NewFilesystemReader(root, "", statz, BatchSize) + case 1: + reader = NewFilesystemReader(root, pathFilters[0], statz, BatchSize) + default: + readers := make([]Reader, len(pathFilters)) + for idx, pathFilter := range pathFilters { + readers[idx] = NewFilesystemReader(root, pathFilter, statz, BatchSize) + } + + reader = &CompositeReader{ + readers: readers, + } + } case Git: - reader, err = NewGitReader(root, path, statz) + reader, err = NewGitReader(root, pathFilters, statz) case Jujutsu: - reader, err = NewJujutsuReader(root, path, statz) + reader, err = NewJujutsuReader(root, pathFilters, statz) case Builtin, Custom: return nil, fmt.Errorf("invalid built-in walk type: %v", selector.WalkType) @@ -320,11 +329,9 @@ func NewCompositeReader( // if no paths are provided we default to processing the tree root if len(paths) == 0 { - return NewReader(selector, root, "", db, statz) + return NewReader(selector, root, nil, db, statz) } - readers := make([]Reader, len(paths)) - // check we have received 1 path for the stdin walk type if selector.Type == Builtin && selector.WalkType == Stdin { if len(paths) != 1 { @@ -340,13 +347,8 @@ func NewCompositeReader( return NewStdinReader(root, path, statz), nil } - // create a reader for each provided path - for idx, path := range paths { - var ( - err error - info os.FileInfo - ) - + pathFilters := make([]string, 0, len(paths)) + for _, path := range paths { resolvedPath, err := resolvePath(path) if err != nil { return nil, fmt.Errorf("error resolving path %s: %w", path, err) @@ -362,27 +364,14 @@ func NewCompositeReader( } // check the path exists - info, err = os.Lstat(resolvedPath) - if err != nil { + if _, err = os.Lstat(resolvedPath); err != nil { return nil, fmt.Errorf("failed to stat %s: %w", resolvedPath, err) } - if info.IsDir() { - // for directories, we honour the walk type as we traverse them - readers[idx], err = NewReader(selector, root, relativePath, db, statz) - } else { - // for files, we enforce a simple filesystem read - readers[idx], err = NewReader(BuiltinSelector(Filesystem), root, relativePath, db, statz) - } - - if err != nil { - return nil, fmt.Errorf("failed to create reader for %s: %w", relativePath, err) - } + pathFilters = append(pathFilters, relativePath) } - return &CompositeReader{ - readers: readers, - }, nil + return NewReader(selector, root, pathFilters, db, statz) } // Resolve a path to an absolute path, resolving any symlinks along the way. From 59274e1f5d367a6adba839decb1364b20a882f20 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 19:01:41 +0200 Subject: [PATCH 12/18] Unblock path stream walkers on close --- walk/custom_test.go | 36 ++++++++++++++++++++++++++++++++++++ walk/path_stream.go | 13 +++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/walk/custom_test.go b/walk/custom_test.go index 40ff6842..4e8e46c3 100644 --- a/walk/custom_test.go +++ b/walk/custom_test.go @@ -2,6 +2,7 @@ package walk_test import ( "context" + "fmt" "io" "os" "path/filepath" @@ -143,6 +144,41 @@ func TestPathStreamReaderDelimiters(t *testing.T) { as.NoError(reader.Close()) } +func TestPathStreamReaderCloseUnblocks(t *testing.T) { + as := require.New(t) + + tempDir := t.TempDir() + + for i := range walk.BatchSize { + path := fmt.Sprintf("f%04d", i) + as.NoError(os.WriteFile(filepath.Join(tempDir, path), nil, 0o600)) + } + + statz := stats.New() + reader, err := walk.NewPathStreamReader(tempDir, &statz, walk.PathStreamConfig{ + Name: "test walker", + Command: "bash", + Options: []string{ + "-c", + `while true; do for path in "$@"; do printf '%s\n' "$path"; done; done`, + "walker", + }, + PathFilters: []string{"."}, + }) + as.NoError(err) + + done := make(chan error, 1) + + go func() { done <- reader.Close() }() + + select { + case err := <-done: + as.NoError(err) + case <-time.After(5 * time.Second): + t.Fatal("Close() did not return; walker command is deadlocked") + } +} + func TestCustomReaderCommandFailure(t *testing.T) { as := require.New(t) diff --git a/walk/path_stream.go b/walk/path_stream.go index 148b0234..8743a510 100644 --- a/walk/path_stream.go +++ b/walk/path_stream.go @@ -3,6 +3,7 @@ package walk import ( "bufio" "context" + "errors" "fmt" "io" "os" @@ -33,6 +34,7 @@ type PathStreamReader struct { log *log.Logger stats *stats.Stats + cancel context.CancelFunc eg *errgroup.Group scanner *bufio.Scanner @@ -139,6 +141,11 @@ func (p *PathStreamReader) relativePath(entry string) (string, error) { } func (p *PathStreamReader) Close() error { + // Unblock the walker command if the caller stopped draining Read() before + // EOF. Without this, the command can block writing stdout while Close waits + // for it to exit. + p.cancel() + err := p.eg.Wait() if err != nil { return fmt.Errorf("failed to wait for %s command to complete: %w", p.cfg.Name, err) @@ -147,7 +154,7 @@ func (p *PathStreamReader) Close() error { p.waitMu.Lock() defer p.waitMu.Unlock() - if p.waitErr != nil { + if p.waitErr != nil && !errors.Is(p.waitErr, context.Canceled) { return fmt.Errorf("failed to wait for %s command to complete: %w", p.cfg.Name, p.waitErr) } @@ -214,7 +221,8 @@ func NewPathStreamReader( args := append([]string{}, cfg.Options...) args = append(args, cfg.PathFilters...) - cmd := exec.CommandContext(context.Background(), executable, args...) + ctx, cancel := context.WithCancel(context.Background()) + cmd := exec.CommandContext(ctx, executable, args...) cmd.Dir = root // Don't use cmd.StdoutPipe here. Its docs say it is incorrect to call @@ -236,6 +244,7 @@ func NewPathStreamReader( cfg: cfg, log: log.WithPrefix("walk | " + cfg.Name), stats: statz, + cancel: cancel, eg: eg, scanner: scanner, } From f0d32deb4637ae4800e922044e5c90c27d5b959f Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 19:39:46 +0200 Subject: [PATCH 13/18] Address custom walker review comments Remove the case-insensitive custom walker lookup and rely on exact configured names. Move the default walk value into NewViper so FromViper no longer silently rewrites an empty walk setting. Use lowercase custom walker names in examples and tests, and use --clear-cache in the custom walker CLI test. --- cmd/init/init.toml | 4 ++-- cmd/root_test.go | 6 +++--- config/config.go | 24 ++---------------------- config/config_test.go | 22 ++++++++++++---------- docs/site/getting-started/configure.md | 8 ++++---- 5 files changed, 23 insertions(+), 41 deletions(-) diff --git a/cmd/init/init.toml b/cmd/init/init.toml index f5019d9c..73182b0f 100644 --- a/cmd/init/init.toml +++ b/cmd/init/init.toml @@ -52,11 +52,11 @@ # You can also set this to the name of a configured custom walker # Env $TREEFMT_WALK # walk = "filesystem" -# walk = "myWalker" +# walk = "mywalker" # Custom walkers are configured with [walker.] # They receive requested file and directory paths as positional args -# [walker.myWalker] +# [walker.mywalker] # command = "command-to-run" # options = [] diff --git a/cmd/root_test.go b/cmd/root_test.go index 79cf6970..efd42b11 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1821,9 +1821,9 @@ func TestCustomWalker(t *testing.T) { test.ChangeWorkDir(t, tempDir) cfg := &config.Config{ - Walk: "myWalker", + Walk: "mywalker", WalkerConfigs: map[string]*config.Walker{ - "myWalker": { + "mywalker": { Command: "bash", Options: []string{ "-c", @@ -1861,7 +1861,7 @@ done`, ) treefmt(t, - withArgs("-c", "go"), + withArgs("--clear-cache", "go"), withConfig(configPath, cfg), withNoError(t), withStats(t, map[stats.Type]int{ diff --git a/config/config.go b/config/config.go index 34a29f04..310e1e87 100644 --- a/config/config.go +++ b/config/config.go @@ -173,6 +173,7 @@ func NewViper() (*viper.Viper, error) { v.SetEnvPrefix("treefmt") v.AutomaticEnv() v.SetEnvKeyReplacer(strings.NewReplacer("-", "_", ".", "_")) + v.SetDefault("walk", walk.Auto.String()) // unset some env variables that we don't want automatically applied if err := os.Unsetenv("TREEFMT_STDIN"); err != nil { @@ -208,10 +209,6 @@ func FromViper(v *viper.Viper) (*Config, error) { return nil, fmt.Errorf("failed to unmarshal config: %w", err) } - if cfg.Walk == "" { - cfg.Walk = walk.Auto.String() - } - // resolve the working directory to an absolute path cfg.WorkingDirectory, err = filepath.Abs(cfg.WorkingDirectory) if err != nil { @@ -270,16 +267,9 @@ func FromViper(v *viper.Viper) (*Config, error) { ) } - walkerCfg, ok := cfg.WalkerConfigs[cfg.Walk] - if !ok { - walkerCfg, ok = findWalkerConfig(cfg.WalkerConfigs, cfg.Walk) - } - - if !ok { + if _, ok := cfg.WalkerConfigs[cfg.Walk]; !ok { return nil, fmt.Errorf("walker %v not found in config", cfg.Walk) } - - cfg.WalkerConfigs[cfg.Walk] = walkerCfg } // filter formatters based on provided names @@ -588,13 +578,3 @@ func fileExists(path string) bool { return fi.Mode().IsRegular() } - -func findWalkerConfig(walkers map[string]*Walker, name string) (*Walker, bool) { - for walkerName, walkerCfg := range walkers { - if strings.EqualFold(walkerName, name) { - return walkerCfg, true - } - } - - return nil, false -} diff --git a/config/config_test.go b/config/config_test.go index f019e0f8..6e527379 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -596,9 +596,9 @@ func TestWalkers(t *testing.T) { as := require.New(t) cfg := &config.Config{ - Walk: "myWalker", + Walk: "mywalker", WalkerConfigs: map[string]*config.Walker{ - "myWalker": { + "mywalker": { Command: "command-to-run", Options: []string{ "--foo", @@ -611,9 +611,9 @@ func TestWalkers(t *testing.T) { v, _ := newViper(t) readValue(t, v, cfg, func(cfg *config.Config) { - walker, ok := cfg.WalkerConfigs["myWalker"] + walker, ok := cfg.WalkerConfigs["mywalker"] as.True(ok, "walker not found") - as.Equal("myWalker", cfg.Walk) + as.Equal("mywalker", cfg.Walk) as.Equal("command-to-run", walker.Command) as.Equal([]string{"--foo", "bar"}, walker.Options) }) @@ -623,13 +623,13 @@ func TestWalkers(t *testing.T) { as := require.New(t) cfg := &config.Config{ - Walk: "myWalker", + Walk: "mywalker", } v, _ := newViper(t) readError(t, v, cfg, func(err error) { - as.ErrorContains(err, "walker myWalker not found in config") + as.ErrorContains(err, "walker mywalker not found in config") }) }) @@ -637,9 +637,9 @@ func TestWalkers(t *testing.T) { as := require.New(t) cfg := &config.Config{ - Walk: "myWalker", + Walk: "mywalker", WalkerConfigs: map[string]*config.Walker{ - "myWalker": {}, + "mywalker": {}, }, } @@ -654,7 +654,7 @@ func TestWalkers(t *testing.T) { as := require.New(t) cfg := &config.Config{ - Walk: "myWalker", + Walk: "mywalker", WalkerConfigs: map[string]*config.Walker{ "my/walker": { Command: "command-to-run", @@ -773,7 +773,9 @@ func TestStdin(t *testing.T) { func TestSampleConfigFile(t *testing.T) { as := require.New(t) - v := viper.New() + v, err := config.NewViper() + as.NoError(err, "failed to create viper") + v.SetConfigFile("../test/examples/treefmt.toml") as.NoError(v.ReadInConfig(), "failed to read config file") diff --git a/docs/site/getting-started/configure.md b/docs/site/getting-started/configure.md index cf3b5c43..f683bc31 100644 --- a/docs/site/getting-started/configure.md +++ b/docs/site/getting-started/configure.md @@ -418,9 +418,9 @@ You can also set this to the name of a configured [custom walker](#walker-option ``` ```toml - walk = "myWalker" + walk = "mywalker" - [walker.myWalker] + [walker.mywalker] command = "command-to-run" options = [] ``` @@ -449,9 +449,9 @@ Custom walkers are configured using a [table](https://toml.io/en/v1.0.0#table) e To use a custom walker, set the global [`walk`](#walk) option to the same name: ```toml -walk = "myWalker" +walk = "mywalker" -[walker.myWalker] +[walker.mywalker] command = "command-to-run" options = [] ``` From 94c0a56178a9034664e790380cb3df3b3b62a5a6 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 20:07:20 +0200 Subject: [PATCH 14/18] Simplify walker selector state Represent selector variants separately from the built-in walker enum. Keep selector fields private and expose IsBuiltin for callers that need to test for stdin. --- cmd/format/format.go | 2 +- walk/walk.go | 45 ++++++++++++++++++++++---------------------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/cmd/format/format.go b/cmd/format/format.go index 213e95fe..31f59cf7 100644 --- a/cmd/format/format.go +++ b/cmd/format/format.go @@ -113,7 +113,7 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) return fmt.Errorf("invalid walk type: %w", err) } - if walkSelector.Type == walk.Builtin && walkSelector.WalkType == walk.Stdin && len(paths) != 1 { + if walkSelector.IsBuiltin(walk.Stdin) && len(paths) != 1 { // check we have only received one path arg which we use for the file extension / matching to formatters return errors.New("exactly one path should be specified when using the --stdin flag") } diff --git a/walk/walk.go b/walk/walk.go index 4e66dc77..bbd75ed0 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -36,31 +36,37 @@ type CustomConfig struct { Options []string } +type selectorKind int + const ( - Builtin Type = iota + 1000 - Custom + builtinSelector selectorKind = iota + customSelector ) type Selector struct { - Type Type - WalkType Type - Custom CustomConfig + kind selectorKind + builtin Type + custom CustomConfig } func BuiltinSelector(walkType Type) Selector { return Selector{ - Type: Builtin, - WalkType: walkType, + kind: builtinSelector, + builtin: walkType, } } func CustomSelector(cfg CustomConfig) Selector { return Selector{ - Type: Custom, - Custom: cfg, + kind: customSelector, + custom: cfg, } } +func (s Selector) IsBuiltin(walkType Type) bool { + return s.kind == builtinSelector && s.builtin == walkType +} + type ReleaseFunc func(ctx context.Context) error // File represents a file object with its path, relative path, file info, and potential cache entry. @@ -241,11 +247,11 @@ func NewReader( reader Reader ) - switch selector.Type { - case Custom: - reader, err = NewCustomReader(root, pathFilters, statz, selector.Custom) - case Builtin: - switch selector.WalkType { + switch selector.kind { + case customSelector: + reader, err = NewCustomReader(root, pathFilters, statz, selector.custom) + case builtinSelector: + switch selector.builtin { case Auto: // for now, we keep it simple and try git first, jujutsu second, and filesystem last reader, err = NewReader(BuiltinSelector(Git), root, pathFilters, db, statz) @@ -280,17 +286,12 @@ func NewReader( case Jujutsu: reader, err = NewJujutsuReader(root, pathFilters, statz) - case Builtin, Custom: - return nil, fmt.Errorf("invalid built-in walk type: %v", selector.WalkType) - default: - return nil, fmt.Errorf("unknown walk type: %v", selector.WalkType) + return nil, fmt.Errorf("unknown walk type: %v", selector.builtin) } - case Auto, Stdin, Filesystem, Git, Jujutsu: - return nil, fmt.Errorf("invalid selector type: %v", selector.Type) default: - return nil, fmt.Errorf("unknown selector type: %v", selector.Type) + return nil, fmt.Errorf("unknown selector type: %v", selector.kind) } if err != nil { @@ -333,7 +334,7 @@ func NewCompositeReader( } // check we have received 1 path for the stdin walk type - if selector.Type == Builtin && selector.WalkType == Stdin { + if selector.IsBuiltin(Stdin) { if len(paths) != 1 { return nil, errors.New("stdin walk requires exactly one path") } From 5b81b3b814ddada45a154314f415ff4dca216e12 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 20:07:54 +0200 Subject: [PATCH 15/18] Split walker reader construction Build the uncached reader in a helper and apply the cache wrapper once in NewReader. Keep auto fallback inside uncached construction so trying built-in walkers does not recurse through cache setup. --- walk/walk.go | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/walk/walk.go b/walk/walk.go index bbd75ed0..8d1cf295 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -241,6 +241,27 @@ func NewReader( pathFilters []string, db *bolt.DB, statz *stats.Stats, +) (Reader, error) { + reader, err := newUncachedReader(selector, root, pathFilters, statz) + if err != nil { + return nil, err + } + + if db != nil { + // wrap with cached reader + // db will be null if --no-cache is enabled + reader, err = NewCachedReader(db, BatchSize, reader) + } + + return reader, err +} + +//nolint:ireturn +func newUncachedReader( + selector Selector, + root string, + pathFilters []string, + statz *stats.Stats, ) (Reader, error) { var ( err error @@ -254,11 +275,11 @@ func NewReader( switch selector.builtin { case Auto: // for now, we keep it simple and try git first, jujutsu second, and filesystem last - reader, err = NewReader(BuiltinSelector(Git), root, pathFilters, db, statz) + reader, err = newUncachedReader(BuiltinSelector(Git), root, pathFilters, statz) if err != nil { - reader, err = NewReader(BuiltinSelector(Jujutsu), root, pathFilters, db, statz) + reader, err = newUncachedReader(BuiltinSelector(Jujutsu), root, pathFilters, statz) if err != nil { - reader, err = NewReader(BuiltinSelector(Filesystem), root, pathFilters, db, statz) + reader, err = newUncachedReader(BuiltinSelector(Filesystem), root, pathFilters, statz) } } @@ -298,12 +319,6 @@ func NewReader( return nil, err } - if db != nil { - // wrap with cached reader - // db will be null if --no-cache is enabled - reader, err = NewCachedReader(db, BatchSize, reader) - } - return reader, err } From f23148253b7957316c8736805e9a12ffce0e73d6 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 20:08:16 +0200 Subject: [PATCH 16/18] Avoid duplicate path stream filter state Use PathStreamConfig.PathFilters directly when post-filtering emitted paths. Remove the redundant filters field from PathStreamReader. --- walk/path_stream.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/walk/path_stream.go b/walk/path_stream.go index 8743a510..e449395a 100644 --- a/walk/path_stream.go +++ b/walk/path_stream.go @@ -27,9 +27,8 @@ type PathStreamConfig struct { } type PathStreamReader struct { - root string - filters []string - cfg PathStreamConfig + root string + cfg PathStreamConfig log *log.Logger stats *stats.Stats @@ -92,7 +91,7 @@ func (p *PathStreamReader) file(record string) (*File, bool, error) { return nil, false, err } - if !containsAnyPath(p.filters, relPath) { + if !containsAnyPath(p.cfg.PathFilters, relPath) { return nil, false, nil } @@ -240,7 +239,6 @@ func NewPathStreamReader( reader := &PathStreamReader{ root: root, - filters: cfg.PathFilters, cfg: cfg, log: log.WithPrefix("walk | " + cfg.Name), stats: statz, From 69c5a225974714c5b9d211e617626057b4255397 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Sun, 31 May 2026 20:08:36 +0200 Subject: [PATCH 17/18] Let filesystem walker handle multiple paths Store filesystem path filters as a slice and walk them sequentially in the filesystem reader. Remove the generic CompositeReader that was only used to fan out multiple filesystem paths. Add coverage for multi-path filesystem walking. --- walk/filesystem.go | 25 ++++++++++++--- walk/filesystem_test.go | 26 ++++++++++++++-- walk/walk.go | 68 ++--------------------------------------- 3 files changed, 47 insertions(+), 72 deletions(-) diff --git a/walk/filesystem.go b/walk/filesystem.go index dc7b301c..67298c70 100644 --- a/walk/filesystem.go +++ b/walk/filesystem.go @@ -19,7 +19,7 @@ import ( type FilesystemReader struct { log *log.Logger root string - path string + paths []string batchSize int eg *errgroup.Group @@ -35,9 +35,24 @@ func (f *FilesystemReader) process() error { close(f.filesCh) }() - // f.path is relative to the root, so we create a fully qualified version + paths := f.paths + if len(paths) == 0 { + paths = []string{""} + } + + for _, path := range paths { + if err := f.processPath(path); err != nil { + return err + } + } + + return nil +} + +func (f *FilesystemReader) processPath(pathFilter string) error { + // pathFilter is relative to the root, so we create a fully qualified version // we also clean the path up in case there are any ../../ components etc. - path := filepath.Clean(filepath.Join(f.root, f.path)) + path := filepath.Clean(filepath.Join(f.root, pathFilter)) // ensure the path is within the root if !strings.HasPrefix(path, f.root) { @@ -130,7 +145,7 @@ func (f *FilesystemReader) Close() error { // and root. func NewFilesystemReader( root string, - path string, + paths []string, statz *stats.Stats, batchSize int, ) *FilesystemReader { @@ -140,7 +155,7 @@ func NewFilesystemReader( r := FilesystemReader{ log: log.WithPrefix("walk | filesystem"), root: root, - path: path, + paths: paths, batchSize: batchSize, eg: &eg, diff --git a/walk/filesystem_test.go b/walk/filesystem_test.go index 45d7d17b..16565b33 100644 --- a/walk/filesystem_test.go +++ b/walk/filesystem_test.go @@ -56,7 +56,7 @@ func TestFilesystemReaderCancellation(t *testing.T) { tempDir := test.TempExamples(t) statz := stats.New() - r := walk.NewFilesystemReader(tempDir, "", &statz, 1024) + r := walk.NewFilesystemReader(tempDir, nil, &statz, 1024) ctx, cancel := context.WithCancel(t.Context()) cancel() @@ -72,7 +72,7 @@ func TestFilesystemReader(t *testing.T) { tempDir := test.TempExamples(t) statz := stats.New() - r := walk.NewFilesystemReader(tempDir, "", &statz, 1024) + r := walk.NewFilesystemReader(tempDir, nil, &statz, 1024) count := 0 @@ -101,3 +101,25 @@ func TestFilesystemReader(t *testing.T) { as.Equal(0, statz.Value(stats.Formatted)) as.Equal(0, statz.Value(stats.Changed)) } + +func TestFilesystemReaderSubpaths(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + statz := stats.New() + + r := walk.NewFilesystemReader(tempDir, []string{"go", "haskell/Foo.hs"}, &statz, 1024) + + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) + defer cancel() + + files := make([]*walk.File, 8) + n, err := r.Read(ctx, files) + + as.ErrorIs(err, io.EOF) + as.Equal(3, n) + as.Equal("go/go.mod", files[0].RelPath) + as.Equal("go/main.go", files[1].RelPath) + as.Equal("haskell/Foo.hs", files[2].RelPath) + as.Equal(3, statz.Value(stats.Traversed)) +} diff --git a/walk/walk.go b/walk/walk.go index 8d1cf295..57819f53 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -5,7 +5,6 @@ import ( "crypto/md5" //nolint:gosec "errors" "fmt" - "io" "io/fs" "os" "path/filepath" @@ -187,53 +186,6 @@ type Reader interface { Close() error } -// CompositeReader combines multiple Readers into one. -// It iterates over the given readers, reading each until completion. -type CompositeReader struct { - idx int - current Reader - readers []Reader -} - -func (c *CompositeReader) Read(ctx context.Context, files []*File) (n int, err error) { - if c.current == nil { - // check if we have exhausted all the readers - if c.idx >= len(c.readers) { - return 0, io.EOF - } - - // if not, select the next reader - c.current = c.readers[c.idx] - c.idx++ - } - - // attempt a read - n, err = c.current.Read(ctx, files) - - // check if the current reader has been exhausted - if errors.Is(err, io.EOF) { - // reset the error if it's EOF - err = nil - // set the current reader to nil so we try to read from the next reader on the next call - c.current = nil - } else if err != nil { - err = fmt.Errorf("failed to read from current reader: %w", err) - } - - // return the number of files read in this call and any error - return n, err -} - -func (c *CompositeReader) Close() error { - for _, reader := range c.readers { - if err := reader.Close(); err != nil { - return fmt.Errorf("failed to close reader: %w", err) - } - } - - return nil -} - //nolint:ireturn func NewReader( selector Selector, @@ -287,21 +239,7 @@ func newUncachedReader( case Stdin: return nil, errors.New("stdin walk type is not supported") case Filesystem: - switch len(pathFilters) { - case 0: - reader = NewFilesystemReader(root, "", statz, BatchSize) - case 1: - reader = NewFilesystemReader(root, pathFilters[0], statz, BatchSize) - default: - readers := make([]Reader, len(pathFilters)) - for idx, pathFilter := range pathFilters { - readers[idx] = NewFilesystemReader(root, pathFilter, statz, BatchSize) - } - - reader = &CompositeReader{ - readers: readers, - } - } + reader = NewFilesystemReader(root, pathFilters, statz, BatchSize) case Git: reader, err = NewGitReader(root, pathFilters, statz) case Jujutsu: @@ -322,8 +260,8 @@ func newUncachedReader( return reader, err } -// NewCompositeReader returns a composite reader for the `root` and all `paths`. It -// never follows symlinks. +// NewCompositeReader returns a reader for the `root` and all `paths`. It never +// follows symlinks. // //nolint:ireturn func NewCompositeReader( From 63f77641e70542c34e55520341c81c3a31da0fb8 Mon Sep 17 00:00:00 2001 From: Tobias Mayer Date: Mon, 1 Jun 2026 10:09:24 +0200 Subject: [PATCH 18/18] Handle signal exits after walker close Record whether the walker command returned after its context was canceled. Treat those wait errors as expected during Close so platforms that report cancellation as signal termination do not fail the unblock test. --- walk/path_stream.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/walk/path_stream.go b/walk/path_stream.go index e449395a..93e99065 100644 --- a/walk/path_stream.go +++ b/walk/path_stream.go @@ -37,8 +37,9 @@ type PathStreamReader struct { eg *errgroup.Group scanner *bufio.Scanner - waitMu sync.Mutex - waitErr error + waitMu sync.Mutex + waitErr error + waitCanceled bool } func (p *PathStreamReader) Read(ctx context.Context, files []*File) (n int, err error) { @@ -153,7 +154,7 @@ func (p *PathStreamReader) Close() error { p.waitMu.Lock() defer p.waitMu.Unlock() - if p.waitErr != nil && !errors.Is(p.waitErr, context.Canceled) { + if p.waitErr != nil && !errors.Is(p.waitErr, context.Canceled) && !p.waitCanceled { return fmt.Errorf("failed to wait for %s command to complete: %w", p.cfg.Name, p.waitErr) } @@ -252,6 +253,7 @@ func NewPathStreamReader( reader.waitMu.Lock() reader.waitErr = err + reader.waitCanceled = ctx.Err() != nil reader.waitMu.Unlock() closeErr := stdoutW.Close()