diff --git a/cmd/format/format.go b/cmd/format/format.go index 2efef587..31f59cf7 100644 --- a/cmd/format/format.go +++ b/cmd/format/format.go @@ -107,13 +107,13 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) cancel() }() - // parse the walk type - walkType, err := walk.TypeString(cfg.Walk) + // parse the walk selector + walkSelector, err := newWalkSelector(cfg) if err != nil { return fmt.Errorf("invalid walk type: %w", err) } - if walkType == walk.Stdin && len(paths) != 1 { + if walkSelector.IsBuiltin(walk.Stdin) && len(paths) != 1 { // check we have only received one path arg which we use for the file extension / matching to formatters return errors.New("exactly one path should be specified when using the --stdin flag") } @@ -125,7 +125,7 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) } // create a new walker for traversing the paths - walker, err := walk.NewCompositeReader(walkType, cfg.TreeRoot, paths, db, statz) + walker, err := walk.NewCompositeReader(walkSelector, cfg.TreeRoot, paths, db, statz) if err != nil { return fmt.Errorf("failed to create walker: %w", err) } @@ -205,3 +205,21 @@ func Run(v *viper.Viper, statz *stats.Stats, cmd *cobra.Command, paths []string) return nil } + +func newWalkSelector(cfg *config.Config) (walk.Selector, error) { + walkType, err := walk.TypeString(cfg.Walk) + if err == nil { + return walk.BuiltinSelector(walkType), nil + } + + walkerCfg, ok := cfg.WalkerConfigs[cfg.Walk] + if !ok { + return walk.Selector{}, fmt.Errorf("walker %v not found in config", cfg.Walk) + } + + return walk.CustomSelector(walk.CustomConfig{ + Name: cfg.Walk, + Command: walkerCfg.Command, + Options: walkerCfg.Options, + }), nil +} diff --git a/cmd/init/init.toml b/cmd/init/init.toml index 375686b3..73182b0f 100644 --- a/cmd/init/init.toml +++ b/cmd/init/init.toml @@ -48,9 +48,17 @@ # verbose = 2 # The method used to traverse the files within the tree root -# Currently, we support 
'auto', 'git', 'jujutsu', or 'filesystem' +# Built-in values are 'auto', 'git', 'jujutsu', and 'filesystem' +# You can also set this to the name of a configured custom walker # Env $TREEFMT_WALK # walk = "filesystem" +# walk = "mywalker" + +# Custom walkers are configured with [walker.<name>] +# They receive requested file and directory paths as positional args +# [walker.mywalker] +# command = "command-to-run" +# options = [] [formatter.mylanguage] # Command to execute diff --git a/cmd/root_test.go b/cmd/root_test.go index be520c62..efd42b11 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1814,6 +1814,65 @@ func TestJujutsu(t *testing.T) { ) } +func TestCustomWalker(t *testing.T) { + tempDir := test.TempExamples(t) + configPath := filepath.Join(tempDir, "/treefmt.toml") + + test.ChangeWorkDir(t, tempDir) + + cfg := &config.Config{ + Walk: "mywalker", + WalkerConfigs: map[string]*config.Walker{ + "mywalker": { + Command: "bash", + Options: []string{ + "-c", + `if [ "$#" -eq 0 ]; then set -- go haskell; fi +for path in "$@"; do + case "$path" in + go) printf '%s\n' go/main.go go/go.mod ;; + haskell) printf '%s\n' haskell/Foo.hs ;; + *) printf '%s\n' "$path" ;; + esac +done`, + "custom-walker", + }, + }, + }, + FormatterConfigs: map[string]*config.Formatter{ + "echo": { + Command: "echo", // will not generate any underlying changes in the file + Includes: []string{"*"}, + }, + }, + } + + test.WriteConfig(t, configPath, cfg) + + treefmt(t, + withConfig(configPath, cfg), + withNoError(t), + withStats(t, map[stats.Type]int{ + stats.Traversed: 3, + stats.Matched: 3, + stats.Formatted: 3, + stats.Changed: 0, + }), + ) + + treefmt(t, + withArgs("--clear-cache", "go"), + withConfig(configPath, cfg), + withNoError(t), + withStats(t, map[stats.Type]int{ + stats.Traversed: 2, + stats.Matched: 2, + stats.Formatted: 2, + stats.Changed: 0, + }), + ) +} + func TestTreeRootCmd(t *testing.T) { as := require.New(t) diff --git a/config/config.go b/config/config.go index 
4c97de11..310e1e87 100644 --- a/config/config.go +++ b/config/config.go @@ -45,6 +45,7 @@ type Config struct { Stdin bool `mapstructure:"stdin" toml:"-"` // not allowed in config FormatterConfigs map[string]*Formatter `mapstructure:"formatter" toml:"formatter,omitempty"` + WalkerConfigs map[string]*Walker `mapstructure:"walker" toml:"walker,omitempty"` Global struct { // Deprecated: Use Excludes @@ -68,6 +69,13 @@ type Formatter struct { NoPositionalArgSupport *bool `mapstructure:"no-positional-arg-support" toml:"no-positional-arg-support"` } +type Walker struct { + // Command is the command to invoke when walking the tree. + Command string `mapstructure:"command" toml:"command"` + // Options are an optional list of args to be passed to Command. + Options []string `mapstructure:"options,omitempty" toml:"options,omitempty"` +} + // SetFlags appends our flags to the provided flag set. // We have a flag matching most entries in Config, taking care to ensure the name matches the field name defined in the // mapstructure tag. 
@@ -165,6 +173,7 @@ func NewViper() (*viper.Viper, error) { v.SetEnvPrefix("treefmt") v.AutomaticEnv() v.SetEnvKeyReplacer(strings.NewReplacer("-", "_", ".", "_")) + v.SetDefault("walk", walk.Auto.String()) // unset some env variables that we don't want automatically applied if err := os.Unsetenv("TREEFMT_STDIN"); err != nil { @@ -233,6 +242,36 @@ func FromViper(v *viper.Viper) (*Config, error) { } } + for name, walkerCfg := range cfg.WalkerConfigs { + if !nameRegex.MatchString(name) { + return nil, fmt.Errorf( + "walker name %q is invalid, must be of the form %s", + name, nameRegex.String(), + ) + } + + if _, err := walk.TypeString(name); err == nil { + return nil, fmt.Errorf("walker name %q is reserved for a built-in walk type", name) + } + + if walkerCfg.Command == "" { + return nil, fmt.Errorf("walker %v has no command", name) + } + } + + if _, err := walk.TypeString(cfg.Walk); err != nil { + if !nameRegex.MatchString(cfg.Walk) { + return nil, fmt.Errorf( + "walk value %q is invalid, must be a built-in walk type or a walker name of the form %s", + cfg.Walk, nameRegex.String(), + ) + } + + if _, ok := cfg.WalkerConfigs[cfg.Walk]; !ok { + return nil, fmt.Errorf("walker %v not found in config", cfg.Walk) + } + } + // filter formatters based on provided names if len(cfg.Formatters) > 0 { filtered := make(map[string]*Formatter) diff --git a/config/config_test.go b/config/config_test.go index 218b87e7..6e527379 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -591,6 +591,117 @@ func TestWalk(t *testing.T) { checkValue("auto") } +func TestWalkers(t *testing.T) { + t.Run("configured walker", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "mywalker", + WalkerConfigs: map[string]*config.Walker{ + "mywalker": { + Command: "command-to-run", + Options: []string{ + "--foo", + "bar", + }, + }, + }, + } + + v, _ := newViper(t) + + readValue(t, v, cfg, func(cfg *config.Config) { + walker, ok := cfg.WalkerConfigs["mywalker"] + 
as.True(ok, "walker not found") + as.Equal("mywalker", cfg.Walk) + as.Equal("command-to-run", walker.Command) + as.Equal([]string{"--foo", "bar"}, walker.Options) + }) + }) + + t.Run("missing walker", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "mywalker", + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walker mywalker not found in config") + }) + }) + + t.Run("empty command", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "mywalker", + WalkerConfigs: map[string]*config.Walker{ + "mywalker": {}, + }, + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walker mywalker has no command") + }) + }) + + t.Run("invalid walker name", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "mywalker", + WalkerConfigs: map[string]*config.Walker{ + "my/walker": { + Command: "command-to-run", + }, + }, + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walker name \"my/walker\" is invalid") + }) + }) + + t.Run("reserved walker name", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + WalkerConfigs: map[string]*config.Walker{ + "git": { + Command: "command-to-run", + }, + }, + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walker name \"git\" is reserved for a built-in walk type") + }) + }) + + t.Run("invalid walk value", func(t *testing.T) { + as := require.New(t) + + cfg := &config.Config{ + Walk: "my.walker", + } + + v, _ := newViper(t) + + readError(t, v, cfg, func(err error) { + as.ErrorContains(err, "walk value \"my.walker\" is invalid") + }) + }) +} + func TestWorkingDirectory(t *testing.T) { as := require.New(t) @@ -662,7 +773,9 @@ func TestStdin(t *testing.T) { func TestSampleConfigFile(t *testing.T) { as := require.New(t) - v := viper.New() + v, err := config.NewViper() + 
as.NoError(err, "failed to create viper") + v.SetConfigFile("../test/examples/treefmt.toml") as.NoError(v.ReadInConfig(), "failed to read config file") diff --git a/docs/site/getting-started/configure.md b/docs/site/getting-started/configure.md index 77f55adb..f683bc31 100644 --- a/docs/site/getting-started/configure.md +++ b/docs/site/getting-started/configure.md @@ -396,7 +396,8 @@ Set the verbosity level of logs: ### `walk` The method used to traverse the files within the tree root. -Currently, we support 'auto', 'git', 'jujutsu' or 'filesystem' +Built-in values are `auto`, `git`, `jujutsu`, and `filesystem`. +You can also set this to the name of a configured [custom walker](#walker-options). === "Flag" @@ -416,6 +417,14 @@ Currently, we support 'auto', 'git', 'jujutsu' or 'filesystem' walk = "filesystem" ``` + ```toml + walk = "mywalker" + + [walker.mywalker] + command = "command-to-run" + options = [] + ``` + ### `working-dir` Run as if `treefmt` was started in the specified working directory instead of the current working directory. @@ -433,6 +442,42 @@ Run as if `treefmt` was started in the specified working directory instead of th TREEFMT_WORKING_DIR=/tmp/foo treefmt ``` +## Walker Options + +Custom walkers are configured using a [table](https://toml.io/en/v1.0.0#table) entry in `treefmt.toml` of the form +`[walker.<name>]`. +To use a custom walker, set the global [`walk`](#walk) option to the same name: + +```toml +walk = "mywalker" + +[walker.mywalker] +command = "command-to-run" +options = [] +``` + +### `command` + +The command to invoke when walking the tree. +`treefmt` runs the command from the tree root. + +When you pass file or directory paths to `treefmt`, `treefmt` passes those paths to the walker command as positional +arguments relative to the tree root. +A custom walker should use those arguments to reduce the paths it emits. +This improves performance when you format a subtree in a large repository. 
+ +`treefmt` still filters the command output to the requested paths. + +The command must write path records to `stdout`. +Records may be separated by newlines or NUL bytes. +Each path must be relative to the tree root or an absolute path inside the tree root. +Paths containing newlines are unsupported. + +### `options` + +An optional list of args to be passed to `command`. +`treefmt` passes these args before any positional path filters. + ## Formatter Options Formatters are configured using a [table](https://toml.io/en/v1.0.0#table) entry in `treefmt.toml` of the form diff --git a/walk/custom.go b/walk/custom.go new file mode 100644 index 00000000..c20a1fff --- /dev/null +++ b/walk/custom.go @@ -0,0 +1,19 @@ +package walk + +import "github.com/numtide/treefmt/v2/stats" + +type CustomReader = PathStreamReader + +func NewCustomReader( + root string, + pathFilters []string, + statz *stats.Stats, + cfg CustomConfig, +) (*CustomReader, error) { + return NewPathStreamReader(root, statz, PathStreamConfig{ + Name: "custom walker " + cfg.Name, + Command: cfg.Command, + Options: cfg.Options, + PathFilters: pathFilters, + }) +} diff --git a/walk/custom_test.go b/walk/custom_test.go new file mode 100644 index 00000000..4e8e46c3 --- /dev/null +++ b/walk/custom_test.go @@ -0,0 +1,212 @@ +package walk_test + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "testing" + "time" + + "github.com/numtide/treefmt/v2/stats" + "github.com/numtide/treefmt/v2/test" + "github.com/numtide/treefmt/v2/walk" + "github.com/stretchr/testify/require" +) + +func TestCustomReader(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + as.NoError(os.Symlink(filepath.Join(tempDir, "go", "main.go"), filepath.Join(tempDir, "link-to-main"))) + + statz := stats.New() + reader, err := walk.NewCustomReader(tempDir, nil, &statz, walk.CustomConfig{ + Name: "myWalker", + Command: "bash", + Options: []string{ + "-c", + `for path in "$@"; do printf '%s\n' "$path"; done`, + 
"walker", + "go/main.go", + "go/go.mod", + "go", + "missing.txt", + "link-to-main", + }, + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := reader.Read(ctx, files) + + cancel() + + as.Equal(2, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + as.Equal("go/go.mod", files[1].RelPath) + as.Equal(2, statz.Value(stats.Traversed)) + as.NoError(reader.Close()) +} + +func TestCustomReaderSubpath(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + statz := stats.New() + reader, err := walk.NewCustomReader(tempDir, []string{"go"}, &statz, walk.CustomConfig{ + Name: "myWalker", + Command: "bash", + Options: []string{ + "-c", + "printf '%s\n' go/main.go haskell/Foo.hs go/go.mod", + }, + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := reader.Read(ctx, files) + + cancel() + + as.Equal(2, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + as.Equal("go/go.mod", files[1].RelPath) + as.Equal(2, statz.Value(stats.Traversed)) + as.NoError(reader.Close()) +} + +func TestCustomReaderPassesPathFilters(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + statz := stats.New() + reader, err := walk.NewCustomReader(tempDir, []string{"go"}, &statz, walk.CustomConfig{ + Name: "myWalker", + Command: "bash", + Options: []string{ + "-c", + `for path in "$@"; do printf '%s\n' "$path/main.go" "$path/go.mod"; done`, + "walker", + }, + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := reader.Read(ctx, files) + + cancel() + + as.Equal(2, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + as.Equal("go/go.mod", files[1].RelPath) + as.Equal(2, statz.Value(stats.Traversed)) + as.NoError(reader.Close()) +} + +func 
TestPathStreamReaderDelimiters(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + statz := stats.New() + reader, err := walk.NewPathStreamReader(tempDir, &statz, walk.PathStreamConfig{ + Name: "test walker", + Command: "bash", + Options: []string{ + "-c", + "printf 'go/main.go\\0\\nhaskell/Foo.hs\\r\\n'", + }, + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := reader.Read(ctx, files) + + cancel() + + as.Equal(2, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + as.Equal("haskell/Foo.hs", files[1].RelPath) + as.Equal(2, statz.Value(stats.Traversed)) + as.NoError(reader.Close()) +} + +func TestPathStreamReaderCloseUnblocks(t *testing.T) { + as := require.New(t) + + tempDir := t.TempDir() + + for i := range walk.BatchSize { + path := fmt.Sprintf("f%04d", i) + as.NoError(os.WriteFile(filepath.Join(tempDir, path), nil, 0o600)) + } + + statz := stats.New() + reader, err := walk.NewPathStreamReader(tempDir, &statz, walk.PathStreamConfig{ + Name: "test walker", + Command: "bash", + Options: []string{ + "-c", + `while true; do for path in "$@"; do printf '%s\n' "$path"; done; done`, + "walker", + }, + PathFilters: []string{"."}, + }) + as.NoError(err) + + done := make(chan error, 1) + + go func() { done <- reader.Close() }() + + select { + case err := <-done: + as.NoError(err) + case <-time.After(5 * time.Second): + t.Fatal("Close() did not return; walker command is deadlocked") + } +} + +func TestCustomReaderCommandFailure(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + + statz := stats.New() + reader, err := walk.NewCustomReader(tempDir, nil, &statz, walk.CustomConfig{ + Name: "myWalker", + Command: "bash", + Options: []string{ + "-c", + "printf '%s\n' go/main.go; exit 7", + }, + }) + as.NoError(err) + + files := make([]*walk.File, 8) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + n, err := 
reader.Read(ctx, files) + + cancel() + + as.Equal(1, n) + as.ErrorIs(err, io.EOF) + as.Equal("go/main.go", files[0].RelPath) + + err = reader.Close() + as.Error(err) + as.ErrorContains(err, "failed to wait for custom walker myWalker command to complete") + as.ErrorContains(err, "exit status 7") +} diff --git a/walk/filesystem.go b/walk/filesystem.go index dc7b301c..67298c70 100644 --- a/walk/filesystem.go +++ b/walk/filesystem.go @@ -19,7 +19,7 @@ import ( type FilesystemReader struct { log *log.Logger root string - path string + paths []string batchSize int eg *errgroup.Group @@ -35,9 +35,24 @@ func (f *FilesystemReader) process() error { close(f.filesCh) }() - // f.path is relative to the root, so we create a fully qualified version + paths := f.paths + if len(paths) == 0 { + paths = []string{""} + } + + for _, path := range paths { + if err := f.processPath(path); err != nil { + return err + } + } + + return nil +} + +func (f *FilesystemReader) processPath(pathFilter string) error { + // pathFilter is relative to the root, so we create a fully qualified version // we also clean the path up in case there are any ../../ components etc. - path := filepath.Clean(filepath.Join(f.root, f.path)) + path := filepath.Clean(filepath.Join(f.root, pathFilter)) // ensure the path is within the root if !strings.HasPrefix(path, f.root) { @@ -130,7 +145,7 @@ func (f *FilesystemReader) Close() error { // and root. 
func NewFilesystemReader( root string, - path string, + paths []string, statz *stats.Stats, batchSize int, ) *FilesystemReader { @@ -140,7 +155,7 @@ func NewFilesystemReader( r := FilesystemReader{ log: log.WithPrefix("walk | filesystem"), root: root, - path: path, + paths: paths, batchSize: batchSize, eg: &eg, diff --git a/walk/filesystem_test.go b/walk/filesystem_test.go index 45d7d17b..16565b33 100644 --- a/walk/filesystem_test.go +++ b/walk/filesystem_test.go @@ -56,7 +56,7 @@ func TestFilesystemReaderCancellation(t *testing.T) { tempDir := test.TempExamples(t) statz := stats.New() - r := walk.NewFilesystemReader(tempDir, "", &statz, 1024) + r := walk.NewFilesystemReader(tempDir, nil, &statz, 1024) ctx, cancel := context.WithCancel(t.Context()) cancel() @@ -72,7 +72,7 @@ func TestFilesystemReader(t *testing.T) { tempDir := test.TempExamples(t) statz := stats.New() - r := walk.NewFilesystemReader(tempDir, "", &statz, 1024) + r := walk.NewFilesystemReader(tempDir, nil, &statz, 1024) count := 0 @@ -101,3 +101,25 @@ func TestFilesystemReader(t *testing.T) { as.Equal(0, statz.Value(stats.Formatted)) as.Equal(0, statz.Value(stats.Changed)) } + +func TestFilesystemReaderSubpaths(t *testing.T) { + as := require.New(t) + + tempDir := test.TempExamples(t) + statz := stats.New() + + r := walk.NewFilesystemReader(tempDir, []string{"go", "haskell/Foo.hs"}, &statz, 1024) + + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) + defer cancel() + + files := make([]*walk.File, 8) + n, err := r.Read(ctx, files) + + as.ErrorIs(err, io.EOF) + as.Equal(3, n) + as.Equal("go/go.mod", files[0].RelPath) + as.Equal("go/main.go", files[1].RelPath) + as.Equal("haskell/Foo.hs", files[2].RelPath) + as.Equal(3, statz.Value(stats.Traversed)) +} diff --git a/walk/git.go b/walk/git.go index 607ba524..d39fcc5f 100644 --- a/walk/git.go +++ b/walk/git.go @@ -1,143 +1,17 @@ package walk import ( - "bufio" - "context" "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "strconv" - 
"strings" - "github.com/charmbracelet/log" "github.com/numtide/treefmt/v2/git" "github.com/numtide/treefmt/v2/stats" - "golang.org/x/sync/errgroup" ) -type GitReader struct { - root string - path string - - log *log.Logger - stats *stats.Stats - - eg *errgroup.Group - scanner *bufio.Scanner -} - -func (g *GitReader) Read(ctx context.Context, files []*File) (n int, err error) { - // ensure we record how many files we traversed - defer func() { - g.stats.Add(stats.Traversed, n) - }() - - nextFile := func() (string, error) { - for line := g.scanner.Text(); len(line) > 0; line = g.scanner.Text() { - lineSplit := strings.Split(line, "\t") - - var stage, file string - // Untracked files just show as ``, while tracked files show as ` ` - if len(lineSplit) == 1 { - stage, file = "", lineSplit[0] - } else { - stage, file = lineSplit[0], lineSplit[1] - } - - // 160000 is the mode for submodules, skip them because they are separate projects that may have their own - // formatting rules - if strings.HasPrefix(stage, "160000") { - g.scanner.Scan() - - continue - } - - if file[0] != '"' { - return file, nil - } - - unquoted, err := strconv.Unquote(file) - if err != nil { - return "", fmt.Errorf("failed to unquote file %s: %w", file, err) - } - - return unquoted, nil - } - - return "", io.EOF - } - -LOOP: - - for n < len(files) { - select { - // exit early if the context was cancelled - case <-ctx.Done(): - return n, ctx.Err() //nolint:wrapcheck - - default: - // read the next file - if g.scanner.Scan() { - entry, err := nextFile() - if err != nil { - return n, err - } - - path := filepath.Join(g.root, g.path, entry) - - g.log.Debugf("processing file: %s", path) - - info, err := os.Lstat(path) - - switch { - case os.IsNotExist(err): - // the underlying file might have been removed - g.log.Warnf( - "Path %s is in the worktree but appears to have been removed from the filesystem", path, - ) - - continue - case err != nil: - return n, fmt.Errorf("failed to stat %s: %w", path, err) - 
case info.Mode()&os.ModeSymlink == os.ModeSymlink: - // we skip reporting symlinks stored in Git, they should - // point to local files which we would list anyway. - continue - } - - files[n] = &File{ - Path: path, - RelPath: filepath.Join(g.path, entry), - Info: info, - } - - n++ - } else { - // nothing more to read - err = io.EOF - - break LOOP - } - } - } - - return n, err -} - -func (g *GitReader) Close() error { - err := g.eg.Wait() - if err != nil { - return fmt.Errorf("failed to wait for git command to complete: %w", err) - } - - return nil -} +type GitReader = PathStreamReader func NewGitReader( root string, - path string, + pathFilters []string, statz *stats.Stats, ) (*GitReader, error) { // check if the root is a git repository @@ -150,34 +24,10 @@ func NewGitReader( return nil, fmt.Errorf("%s is not a git repository", root) } - // create an errgroup for executing in the background - eg := &errgroup.Group{} - - // create a pipe to capture the command output - r, w := io.Pipe() - - // create a command which will execute from the specified sub path within root - cmd := exec.CommandContext( - context.Background(), - "git", "ls-files", "--cached", "--others", "--exclude-standard", "--stage", - ) - cmd.Dir = filepath.Join(root, path) - cmd.Stdout = w - - // execute the command in the background - eg.Go(func() error { - return w.CloseWithError(cmd.Run()) + return NewPathStreamReader(root, statz, PathStreamConfig{ + Name: "git", + Command: "git", + Options: []string{"ls-files", "-z", "--cached", "--others", "--exclude-standard", "--full-name", "--"}, + PathFilters: pathFilters, }) - - // create a new scanner for reading the output - scanner := bufio.NewScanner(r) - - return &GitReader{ - eg: eg, - root: root, - path: path, - stats: statz, - scanner: scanner, - log: log.WithPrefix("walk | git"), - }, nil } diff --git a/walk/git_test.go b/walk/git_test.go index ce4e3d29..00b4bbe0 100644 --- a/walk/git_test.go +++ b/walk/git_test.go @@ -35,7 +35,7 @@ func 
TestGitReader(t *testing.T) { // read empty worktree statz := stats.New() - reader, err := walk.NewGitReader(tempDir, "", &statz) + reader, err := walk.NewGitReader(tempDir, nil, &statz) as.NoError(err) files := make([]*walk.File, 34) @@ -84,7 +84,7 @@ func TestGitReader(t *testing.T) { as.NoError(cmd.Run(), "failed to add everything to the index") statz = stats.New() - reader, err = walk.NewGitReader(tempDir, "", &statz) + reader, err = walk.NewGitReader(tempDir, nil, &statz) as.NoError(err) count := 0 diff --git a/walk/jujutsu.go b/walk/jujutsu.go index 4ee7a195..b2566be1 100644 --- a/walk/jujutsu.go +++ b/walk/jujutsu.go @@ -1,109 +1,17 @@ package walk import ( - "bufio" - "context" "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "github.com/charmbracelet/log" "github.com/numtide/treefmt/v2/jujutsu" "github.com/numtide/treefmt/v2/stats" - "golang.org/x/sync/errgroup" ) -type JujutsuReader struct { - root string - path string - - log *log.Logger - stats *stats.Stats - - eg *errgroup.Group - scanner *bufio.Scanner -} - -func (j *JujutsuReader) Read(ctx context.Context, files []*File) (n int, err error) { - // ensure we record how many files we traversed - defer func() { - j.stats.Add(stats.Traversed, n) - }() - - nextLine := func() string { - line := j.scanner.Text() - - return line - } - -LOOP: - - for n < len(files) { - select { - // exit early if the context was cancelled - case <-ctx.Done(): - return n, ctx.Err() //nolint:wrapcheck - - default: - // read the next file - if j.scanner.Scan() { - entry := nextLine() - - path := filepath.Join(j.root, entry) - - j.log.Debugf("processing file: %s", path) - - info, err := os.Lstat(path) - - switch { - case os.IsNotExist(err): - // the underlying file might have been removed - j.log.Warnf( - "Path %s is in the worktree but appears to have been removed from the filesystem", path, - ) - - continue - case err != nil: - return n, fmt.Errorf("failed to stat %s: %w", path, err) - case info.Mode()&os.ModeSymlink == 
os.ModeSymlink: - // we skip reporting symlinks stored in Jujutsu, they should - // point to local files which we would list anyway. - continue - } - - files[n] = &File{ - Path: path, - RelPath: entry, - Info: info, - } - - n++ - } else { - // nothing more to read - err = io.EOF - - break LOOP - } - } - } - - return n, err -} - -func (j *JujutsuReader) Close() error { - err := j.eg.Wait() - if err != nil { - return fmt.Errorf("failed to wait for jujutsu command to complete: %w", err) - } - - return nil -} +type JujutsuReader = PathStreamReader func NewJujutsuReader( root string, - path string, + pathFilters []string, statz *stats.Stats, ) (*JujutsuReader, error) { // check if the root is a jujutsu repository @@ -116,42 +24,13 @@ func NewJujutsuReader( return nil, fmt.Errorf("%s is not a jujutsu repository", root) } - // create an errgroup for async list task - eg := &errgroup.Group{} - - // create a pipe to capture the command output - r, w := io.Pipe() - - // create a command which will execute from root // --ignore-working-copy: Don't snapshot the working copy, and don't update it. This prevents that the user has to // enter a password for signing the commit. New files also won't be added to the index and not displayed in the // output. - // Add the subpath as a fileset displaying only files matching this prefix. If - // the subpath is empty ignore it since it interferes with the command - args := []string{"file", "list", "--ignore-working-copy"} - if path != "" { - args = append(args, path) - } - - // create the jj command - cmd := exec.CommandContext(context.Background(), "jj", args...) 
- cmd.Dir = root - cmd.Stdout = w - - // execute the command in the background - eg.Go(func() error { - return w.CloseWithError(cmd.Run()) + return NewPathStreamReader(root, statz, PathStreamConfig{ + Name: "jujutsu", + Command: "jj", + Options: []string{"file", "list", "--ignore-working-copy", "--"}, + PathFilters: pathFilters, }) - - // create a new scanner for reading the output - scanner := bufio.NewScanner(r) - - return &JujutsuReader{ - eg: eg, - root: root, - path: path, - stats: statz, - scanner: scanner, - log: log.WithPrefix("walk | jujutsu"), - }, nil } diff --git a/walk/jujutsu_test.go b/walk/jujutsu_test.go index c50b6f11..2d829f03 100644 --- a/walk/jujutsu_test.go +++ b/walk/jujutsu_test.go @@ -27,7 +27,7 @@ func TestJujutsuReader(t *testing.T) { // read empty worktree statz := stats.New() - reader, err := walk.NewJujutsuReader(tempDir, "", &statz) + reader, err := walk.NewJujutsuReader(tempDir, nil, &statz) as.NoError(err) files := make([]*walk.File, 33) // The number of files in `test/examples` used for testing @@ -47,7 +47,7 @@ func TestJujutsuReader(t *testing.T) { as.NoError(cmd.Run(), "failed to update the index") statz = stats.New() - reader, err = walk.NewJujutsuReader(tempDir, "", &statz) + reader, err = walk.NewJujutsuReader(tempDir, nil, &statz) as.NoError(err) count := 0 diff --git a/walk/path_stream.go b/walk/path_stream.go new file mode 100644 index 00000000..93e99065 --- /dev/null +++ b/walk/path_stream.go @@ -0,0 +1,283 @@ +package walk + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + + "github.com/charmbracelet/log" + "github.com/numtide/treefmt/v2/stats" + "golang.org/x/sync/errgroup" + "mvdan.cc/sh/v3/expand" + "mvdan.cc/sh/v3/interp" +) + +type PathStreamConfig struct { + Name string + Command string + Options []string + PathFilters []string +} + +type PathStreamReader struct { + root string + cfg PathStreamConfig + + log *log.Logger + stats *stats.Stats + + 
cancel context.CancelFunc + eg *errgroup.Group + scanner *bufio.Scanner + + waitMu sync.Mutex + waitErr error + waitCanceled bool +} + +func (p *PathStreamReader) Read(ctx context.Context, files []*File) (n int, err error) { + // ensure we record how many files we traversed + defer func() { + p.stats.Add(stats.Traversed, n) + }() + +LOOP: + for n < len(files) { + select { + // exit early if the context was cancelled + case <-ctx.Done(): + return n, ctx.Err() //nolint:wrapcheck + + default: + if !p.scanner.Scan() { + if err := p.scanner.Err(); err != nil { + return n, fmt.Errorf("failed to read %s output: %w", p.cfg.Name, err) + } + + err = io.EOF + + break LOOP + } + + file, ok, err := p.file(p.scanner.Text()) + if err != nil { + return n, err + } else if !ok { + continue + } + + files[n] = file + n++ + } + } + + return n, err +} + +func (p *PathStreamReader) file(record string) (*File, bool, error) { + record = strings.TrimSuffix(record, "\r") + if record == "" { + return nil, false, nil + } + + relPath, err := p.relativePath(record) + if err != nil { + return nil, false, err + } + + if !containsAnyPath(p.cfg.PathFilters, relPath) { + return nil, false, nil + } + + path := filepath.Join(p.root, relPath) + + p.log.Debugf("processing file: %s", path) + + info, err := os.Lstat(path) + switch { + case os.IsNotExist(err): + p.log.Warnf( + "Path %s was emitted by %s but appears to have been removed from the filesystem", + path, p.cfg.Name, + ) + + return nil, false, nil + case err != nil: + return nil, false, fmt.Errorf("failed to stat %s: %w", path, err) + case info.IsDir() || info.Mode()&os.ModeSymlink == os.ModeSymlink: + return nil, false, nil + } + + return &File{ + Path: path, + RelPath: relPath, + Info: info, + }, true, nil +} + +func (p *PathStreamReader) relativePath(entry string) (string, error) { + entry = filepath.Clean(entry) + if filepath.IsAbs(entry) { + relPath, err := filepath.Rel(p.root, entry) + if err != nil { + return "", fmt.Errorf("failed to 
determine a relative path for %s: %w", entry, err) + } + + entry = relPath + } + + if entry == ".." || strings.HasPrefix(entry, ".."+string(os.PathSeparator)) || filepath.IsAbs(entry) { + return "", fmt.Errorf("%s emitted path %s outside the tree root %s", p.cfg.Name, entry, p.root) + } + + return entry, nil +} + +func (p *PathStreamReader) Close() error { + // Unblock the walker command if the caller stopped draining Read() before + // EOF. Without this, the command can block writing stdout while Close waits + // for it to exit. + p.cancel() + + err := p.eg.Wait() + if err != nil { + return fmt.Errorf("failed to wait for %s command to complete: %w", p.cfg.Name, err) + } + + p.waitMu.Lock() + defer p.waitMu.Unlock() + + if p.waitErr != nil && !errors.Is(p.waitErr, context.Canceled) && !p.waitCanceled { + return fmt.Errorf("failed to wait for %s command to complete: %w", p.cfg.Name, p.waitErr) + } + + return nil +} + +func splitPathRecord(data []byte, atEOF bool) (advance int, token []byte, err error) { + for i, b := range data { + if b == '\n' || b == 0 { + return i + 1, data[:i], nil + } + } + + if atEOF && len(data) > 0 { + return len(data), data, nil + } + + return 0, nil, nil +} + +func containsAnyPath(filters []string, path string) bool { + if len(filters) == 0 { + return true + } + + for _, filter := range filters { + if containsPath(filter, path) { + return true + } + } + + return false +} + +func containsPath(root string, path string) bool { + if root == "" || root == "." || path == root { + return true + } + + relPath, err := filepath.Rel(root, path) + if err != nil { + return false + } + + return relPath != ".." && + !strings.HasPrefix(relPath, ".."+string(os.PathSeparator)) && + !filepath.IsAbs(relPath) +} + +func NewPathStreamReader( + root string, + statz *stats.Stats, + cfg PathStreamConfig, +) (*PathStreamReader, error) { + env := expand.ListEnviron(os.Environ()...) 
+ + executable, err := interp.LookPathDir(root, env, cfg.Command) + if err != nil { + return nil, fmt.Errorf("failed to find %s command %q: %w", cfg.Name, cfg.Command, err) + } + + eg := &errgroup.Group{} + + args := append([]string{}, cfg.Options...) + args = append(args, cfg.PathFilters...) + + ctx, cancel := context.WithCancel(context.Background()) + cmd := exec.CommandContext(ctx, executable, args...) + cmd.Dir = root + + // Don't use cmd.StdoutPipe here. Its docs say it is incorrect to call + // Cmd.Wait before all reads from the pipe have completed, and we wait for + // the command in a producer goroutine while Read drains stdout. + // See https://pkg.go.dev/os/exec#Cmd.StdoutPipe. + stdout, stdoutW := io.Pipe() + stderr, stderrW := io.Pipe() + + cmd.Stdout = stdoutW + cmd.Stderr = stderrW + + scanner := bufio.NewScanner(stdout) + scanner.Split(splitPathRecord) + + reader := &PathStreamReader{ + root: root, + cfg: cfg, + log: log.WithPrefix("walk | " + cfg.Name), + stats: statz, + cancel: cancel, + eg: eg, + scanner: scanner, + } + + eg.Go(func() error { + err := cmd.Run() + + reader.waitMu.Lock() + reader.waitErr = err + reader.waitCanceled = ctx.Err() != nil + reader.waitMu.Unlock() + + closeErr := stdoutW.Close() + if stderrCloseErr := stderrW.Close(); stderrCloseErr != nil && closeErr == nil { + closeErr = stderrCloseErr + } + + return closeErr + }) + + eg.Go(func() error { + l := log.WithPrefix("walk | " + cfg.Name + " | stderr") + + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + l.Debugf("%s", scanner.Text()) + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to read %s stderr: %w", cfg.Name, err) + } + + return nil + }) + + return reader, nil +} diff --git a/walk/walk.go b/walk/walk.go index 00c6aa0a..57819f53 100644 --- a/walk/walk.go +++ b/walk/walk.go @@ -5,7 +5,6 @@ import ( "crypto/md5" //nolint:gosec "errors" "fmt" - "io" "io/fs" "os" "path/filepath" @@ -30,6 +29,43 @@ const ( const BatchSize = 1024 +type 
CustomConfig struct { + Name string + Command string + Options []string +} + +type selectorKind int + +const ( + builtinSelector selectorKind = iota + customSelector +) + +type Selector struct { + kind selectorKind + builtin Type + custom CustomConfig +} + +func BuiltinSelector(walkType Type) Selector { + return Selector{ + kind: builtinSelector, + builtin: walkType, + } +} + +func CustomSelector(cfg CustomConfig) Selector { + return Selector{ + kind: customSelector, + custom: cfg, + } +} + +func (s Selector) IsBuiltin(walkType Type) bool { + return s.kind == builtinSelector && s.builtin == walkType +} + type ReleaseFunc func(ctx context.Context) error // File represents a file object with its path, relative path, file info, and potential cache entry. @@ -150,59 +186,33 @@ type Reader interface { Close() error } -// CompositeReader combines multiple Readers into one. -// It iterates over the given readers, reading each until completion. -type CompositeReader struct { - idx int - current Reader - readers []Reader -} - -func (c *CompositeReader) Read(ctx context.Context, files []*File) (n int, err error) { - if c.current == nil { - // check if we have exhausted all the readers - if c.idx >= len(c.readers) { - return 0, io.EOF - } - - // if not, select the next reader - c.current = c.readers[c.idx] - c.idx++ - } - - // attempt a read - n, err = c.current.Read(ctx, files) - - // check if the current reader has been exhausted - if errors.Is(err, io.EOF) { - // reset the error if it's EOF - err = nil - // set the current reader to nil so we try to read from the next reader on the next call - c.current = nil - } else if err != nil { - err = fmt.Errorf("failed to read from current reader: %w", err) +//nolint:ireturn +func NewReader( + selector Selector, + root string, + pathFilters []string, + db *bolt.DB, + statz *stats.Stats, +) (Reader, error) { + reader, err := newUncachedReader(selector, root, pathFilters, statz) + if err != nil { + return nil, err } - // return the 
number of files read in this call and any error - return n, err -} - -func (c *CompositeReader) Close() error { - for _, reader := range c.readers { - if err := reader.Close(); err != nil { - return fmt.Errorf("failed to close reader: %w", err) - } + if db != nil { + // wrap with cached reader + // db will be null if --no-cache is enabled + reader, err = NewCachedReader(db, BatchSize, reader) } - return nil + return reader, err } //nolint:ireturn -func NewReader( - walkType Type, +func newUncachedReader( + selector Selector, root string, - path string, - db *bolt.DB, + pathFilters []string, statz *stats.Stats, ) (Reader, error) { var ( @@ -210,50 +220,52 @@ func NewReader( reader Reader ) - switch walkType { - case Auto: - // for now, we keep it simple and try git first, jujutsu second, and filesystem last - reader, err = NewReader(Git, root, path, db, statz) - if err != nil { - reader, err = NewReader(Jujutsu, root, path, db, statz) + switch selector.kind { + case customSelector: + reader, err = NewCustomReader(root, pathFilters, statz, selector.custom) + case builtinSelector: + switch selector.builtin { + case Auto: + // for now, we keep it simple and try git first, jujutsu second, and filesystem last + reader, err = newUncachedReader(BuiltinSelector(Git), root, pathFilters, statz) if err != nil { - reader, err = NewReader(Filesystem, root, path, db, statz) + reader, err = newUncachedReader(BuiltinSelector(Jujutsu), root, pathFilters, statz) + if err != nil { + reader, err = newUncachedReader(BuiltinSelector(Filesystem), root, pathFilters, statz) + } } - } - return reader, err - case Stdin: - return nil, errors.New("stdin walk type is not supported") - case Filesystem: - reader = NewFilesystemReader(root, path, statz, BatchSize) - case Git: - reader, err = NewGitReader(root, path, statz) - case Jujutsu: - reader, err = NewJujutsuReader(root, path, statz) + return reader, err + case Stdin: + return nil, errors.New("stdin walk type is not supported") + case 
Filesystem: + reader = NewFilesystemReader(root, pathFilters, statz, BatchSize) + case Git: + reader, err = NewGitReader(root, pathFilters, statz) + case Jujutsu: + reader, err = NewJujutsuReader(root, pathFilters, statz) + + default: + return nil, fmt.Errorf("unknown walk type: %v", selector.builtin) + } default: - return nil, fmt.Errorf("unknown walk type: %v", walkType) + return nil, fmt.Errorf("unknown selector type: %v", selector.kind) } if err != nil { return nil, err } - if db != nil { - // wrap with cached reader - // db will be null if --no-cache is enabled - reader, err = NewCachedReader(db, BatchSize, reader) - } - return reader, err } -// NewCompositeReader returns a composite reader for the `root` and all `paths`. It -// never follows symlinks. +// NewCompositeReader returns a reader for the `root` and all `paths`. It never +// follows symlinks. // //nolint:ireturn func NewCompositeReader( - walkType Type, + selector Selector, root string, paths []string, db *bolt.DB, @@ -271,13 +283,11 @@ func NewCompositeReader( // if no paths are provided we default to processing the tree root if len(paths) == 0 { - return NewReader(walkType, root, "", db, statz) + return NewReader(selector, root, nil, db, statz) } - readers := make([]Reader, len(paths)) - // check we have received 1 path for the stdin walk type - if walkType == Stdin { + if selector.IsBuiltin(Stdin) { if len(paths) != 1 { return nil, errors.New("stdin walk requires exactly one path") } @@ -291,13 +301,8 @@ func NewCompositeReader( return NewStdinReader(root, path, statz), nil } - // create a reader for each provided path - for idx, path := range paths { - var ( - err error - info os.FileInfo - ) - + pathFilters := make([]string, 0, len(paths)) + for _, path := range paths { resolvedPath, err := resolvePath(path) if err != nil { return nil, fmt.Errorf("error resolving path %s: %w", path, err) @@ -313,27 +318,14 @@ func NewCompositeReader( } // check the path exists - info, err = 
os.Lstat(resolvedPath) - if err != nil { + if _, err = os.Lstat(resolvedPath); err != nil { return nil, fmt.Errorf("failed to stat %s: %w", resolvedPath, err) } - if info.IsDir() { - // for directories, we honour the walk type as we traverse them - readers[idx], err = NewReader(walkType, root, relativePath, db, statz) - } else { - // for files, we enforce a simple filesystem read - readers[idx], err = NewReader(Filesystem, root, relativePath, db, statz) - } - - if err != nil { - return nil, fmt.Errorf("failed to create reader for %s: %w", relativePath, err) - } + pathFilters = append(pathFilters, relativePath) } - return &CompositeReader{ - readers: readers, - }, nil + return NewReader(selector, root, pathFilters, db, statz) } // Resolve a path to an absolute path, resolving any symlinks along the way.