diff --git a/AGENTS.md b/AGENTS.md index 191593e..403ea63 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -240,8 +240,11 @@ Use in custom generator: ```go // tools/gendi/main.go func main() { + // Always-included passes passes := []di.Pass{&MyPass{}} - cmd.Run(flag.CommandLine, passes) + // Selectable passes (filtered by --enable-pass flag) + selectablePasses := []di.Pass{} + cmd.Run(flag.CommandLine, passes, selectablePasses) } ``` diff --git a/README.md b/README.md index 97cf56d..5cc58d2 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,8 @@ Flags: --container string Container struct name (default: "Container") --strict Enable strict validation (default: true) --build-tags string Build tags for generated file + --enable-pass string + Enable an optional compiler pass (repeatable) --verbose Enable verbose logging ``` diff --git a/cmd/cli.go b/cmd/cli.go index 3802500..db55bb7 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -9,9 +9,17 @@ import ( di "github.com/asp24/gendi" "github.com/asp24/gendi/pipeline" "github.com/asp24/gendi/srcloc" + "github.com/asp24/gendi/stdlib" "github.com/asp24/gendi/yaml" ) +func BuiltinSelectablePasses() []di.Pass { + return []di.Pass{ + &stdlib.SLogPass{}, + &di.ExposeAllPass{}, + } +} + // WriteTargetFile writes data to the specified file path func WriteTargetFile(path string, data []byte) error { if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { @@ -56,8 +64,8 @@ func Generate(cfg Config, passes []di.Pass) error { return nil } -// Run executes the full gendi workflow with optional compiler passes -func Run(flags *flag.FlagSet, passes []di.Pass) error { +// Run executes the full gendi workflow with compiler passes +func Run(flags *flag.FlagSet, passes, selectablePasses []di.Pass) error { var cfg Config cfg.RegisterFlags(flags) @@ -65,7 +73,12 @@ func Run(flags *flag.FlagSet, passes []di.Pass) error { return fmt.Errorf("parse flags: %w", err) } - return Generate(cfg, passes) + resolvedPasses, err := cfg.resolvePasses(passes, selectablePasses) + if err != nil { + return fmt.Errorf("resolve passes: %w", err) + } + + return Generate(cfg, resolvedPasses) } func PrintErrorAndExit(err error) { @@ -79,6 +92,6 @@ func PrintErrorAndExit(err error) { os.Exit(1) } -func MustRun(flags *flag.FlagSet, passes []di.Pass) { - PrintErrorAndExit(Run(flags, passes)) +func MustRun(flags *flag.FlagSet, passes, selectablePasses []di.Pass) { + PrintErrorAndExit(Run(flags, passes, selectablePasses)) } diff --git a/cmd/config.go b/cmd/config.go index 7fce479..beae42b 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -4,13 +4,15 @@ import ( "flag" "fmt" + di "github.com/asp24/gendi" "github.com/asp24/gendi/pipeline" ) // Config holds CLI configuration type Config struct { - ConfigPath string - Options pipeline.Options + ConfigPath string + Options pipeline.Options + EnabledPasses map[string]struct{} } func (c *Config) RegisterFlags(flags *flag.FlagSet) { @@ -21,6 +23,57 @@ func (c *Config) RegisterFlags(flags *flag.FlagSet) { flags.BoolVar(&c.Options.Strict, "strict", true, "Enable strict validation") flags.StringVar(&c.Options.BuildTags, "build-tags", "", "Go build tags") flags.BoolVar(&c.Options.Verbose, "verbose", false, "Verbose logging") + flags.Var(&stringSetFlag{values: &c.EnabledPasses}, "enable-pass", "Enable a specific compiler pass (can be specified multiple times)") +} + +func (c *Config) validatePasses(selectablePasses []di.Pass) error { + known := make(map[string]struct{}, len(selectablePasses)) + for _, p := range selectablePasses { + known[p.Name()] = struct{}{} + } + + for name := range c.EnabledPasses { + if _, ok := known[name]; !ok { + return fmt.Errorf("--enable-pass: unknown pass %q", name) + } + } + + return nil +} + +func (c *Config) resolvePasses(passes, selectablePasses []di.Pass) ([]di.Pass, error) { + if err := c.validatePasses(selectablePasses); err != nil { + return nil, err + } + + result := make([]di.Pass, 0, len(passes)+len(selectablePasses)) + included := make(map[string]struct{}, len(passes)+len(selectablePasses)) + for _, p := range passes { + name := p.Name() + if _, ok := included[name]; ok { + continue + } + + result = append(result, p) + included[name] = struct{}{} + } + + for _, p := range selectablePasses { + name := p.Name() + if _, ok := included[name]; ok { + continue + } + + _, enabled := c.EnabledPasses[name] + if !enabled { + continue + } + + result = append(result, p) + included[name] = struct{}{} + } + + return result, nil } // Finalize validates and finalizes the configuration diff --git a/cmd/config_test.go b/cmd/config_test.go new file mode 100644 index 0000000..86e44a8 --- /dev/null +++ b/cmd/config_test.go @@ -0,0 +1,116 @@ +package cmd + +import ( + "testing" + + di "github.com/asp24/gendi" +) + +type testPass struct { + name string +} + +func (p *testPass) Name() string { return p.name } +func (p *testPass) Process(cfg *di.Config) (*di.Config, error) { return cfg, nil } + +func makePass(name string) *testPass { + return &testPass{name: name} +} + +func TestConfig_ResolvePasses(t *testing.T) { + cases := []struct { + name string + enabled map[string]struct{} + passes []di.Pass + selectablePasses []di.Pass + wantNames []string + }{ + { + name: "always-included pass is included", + passes: []di.Pass{makePass("a")}, + wantNames: []string{"a"}, + }, + { + name: "selectable pass without enable flag is excluded", + selectablePasses: []di.Pass{makePass("a")}, + wantNames: []string{}, + }, + { + name: "selectable pass with enable flag is included", + enabled: map[string]struct{}{"a": {}}, + selectablePasses: []di.Pass{makePass("a")}, + wantNames: []string{"a"}, + }, + { + name: "duplicate always-included pass name runs only once", + passes: []di.Pass{makePass("a"), makePass("a")}, + wantNames: []string{"a"}, + }, + { + name: "duplicate selectable pass name runs only once", + enabled: map[string]struct{}{"a": {}}, + selectablePasses: []di.Pass{makePass("a"), makePass("a")}, + wantNames: []string{"a"}, + }, + { + name: "always-included pass wins over selectable pass with same name", + enabled: map[string]struct{}{"a": {}}, + passes: []di.Pass{makePass("a")}, + selectablePasses: []di.Pass{makePass("a")}, + wantNames: []string{"a"}, + }, + { + name: "selectable passes are appended after always-included passes", + enabled: map[string]struct{}{"b": {}}, + passes: []di.Pass{makePass("a")}, + selectablePasses: []di.Pass{makePass("b")}, + wantNames: []string{"a", "b"}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := Config{EnabledPasses: tc.enabled} + result, err := cfg.resolvePasses(tc.passes, tc.selectablePasses) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(result) != len(tc.wantNames) { + t.Fatalf("got %d passes, want %d", len(result), len(tc.wantNames)) + } + for i, p := range result { + if p.Name() != tc.wantNames[i] { + t.Errorf("pass[%d] name = %q, want %q", i, p.Name(), tc.wantNames[i]) + } + } + }) + } +} + +func TestConfig_ResolvePasses_Errors(t *testing.T) { + cases := []struct { + name string + enabled map[string]struct{} + passes []di.Pass + selectablePasses []di.Pass + }{ + { + name: "unknown name in --enable-pass", + enabled: map[string]struct{}{"unknown": {}}, + selectablePasses: []di.Pass{makePass("foo")}, + }, + { + name: "always-included pass is not selectable", + enabled: map[string]struct{}{"foo": {}}, + passes: []di.Pass{makePass("foo")}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cfg := Config{EnabledPasses: tc.enabled} + _, err := cfg.resolvePasses(tc.passes, tc.selectablePasses) + if err == nil { + t.Error("expected error, got nil") + } + }) + } +} diff --git a/cmd/gendi/main.go b/cmd/gendi/main.go index c5ffd36..6ee4317 100644 --- a/cmd/gendi/main.go +++ b/cmd/gendi/main.go @@ -7,5 +7,5 @@ import ( ) func main() { - cmd.MustRun(flag.CommandLine, nil) + cmd.MustRun(flag.CommandLine, nil, cmd.BuiltinSelectablePasses()) } diff --git a/cmd/string_set_flag.go b/cmd/string_set_flag.go new file mode 100644 index 0000000..cc25a7c --- /dev/null +++ b/cmd/string_set_flag.go @@ -0,0 +1,31 @@ +package cmd + +import ( + "sort" + "strings" +) + +// stringSetFlag is a flag.Value that collects multiple values into a set. +type stringSetFlag struct { + values *map[string]struct{} +} + +func (f *stringSetFlag) String() string { + if f.values == nil || *f.values == nil { + return "" + } + values := make([]string, 0, len(*f.values)) + for value := range *f.values { + values = append(values, value) + } + sort.Strings(values) + return strings.Join(values, ",") +} + +func (f *stringSetFlag) Set(s string) error { + if *f.values == nil { + *f.values = make(map[string]struct{}) + } + (*f.values)[s] = struct{}{} + return nil +} diff --git a/cmd/string_set_flag_test.go b/cmd/string_set_flag_test.go new file mode 100644 index 0000000..d3e9104 --- /dev/null +++ b/cmd/string_set_flag_test.go @@ -0,0 +1,39 @@ +package cmd + +import "testing" + +func TestStringSetFlag_String(t *testing.T) { + cases := []struct { + name string + m map[string]struct{} + want string + }{ + {"nil map", nil, ""}, + {"multiple values sorted", map[string]struct{}{"beta": {}, "alpha": {}, "gamma": {}}, "alpha,beta,gamma"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + f := &stringSetFlag{values: &tc.m} + if got := f.String(); got != tc.want { + t.Errorf("got %q, want %q", got, tc.want) + } + }) + } +} + +func TestStringSetFlag_Set(t *testing.T) { + var m map[string]struct{} + f := &stringSetFlag{values: &m} + f.Set("a") + f.Set("b") + f.Set("a") // duplicate + want := map[string]struct{}{"a": {}, "b": {}} + if got := len(*f.values); got != len(want) { + t.Fatalf("len = %d, want %d", got, len(want)) + } + for k := range want { + if _, ok := (*f.values)[k]; !ok { + t.Errorf("missing key %q", k) + } + } +} diff --git a/doc/LLM.md b/doc/LLM.md index 4924661..a1b5453 100644 --- a/doc/LLM.md +++ b/doc/LLM.md @@ -8,7 +8,7 @@ Short, stable facts for tooling and assistants. go tool gendi --config=gendi.yaml --out=./di --pkg=di ``` -Flags: `--config`, `--out`, `--pkg`, `--container`, `--strict`, `--build-tags`, `--verbose`. +Flags: `--config`, `--out`, `--pkg`, `--container`, `--strict`, `--build-tags`, `--enable-pass`, `--verbose`. ## YAML Syntax diff --git a/doc/custom-passes.md b/doc/custom-passes.md index 2ce21a6..bb58dde 100644 --- a/doc/custom-passes.md +++ b/doc/custom-passes.md @@ -6,6 +6,7 @@ Compiler passes transform configuration before code generation, enabling project - [Overview](#overview) - [Pass Interface](#pass-interface) +- [Optional CLI Passes](#optional-cli-passes) - [Creating a Pass](#creating-a-pass) - [Building a Custom Generator](#building-a-custom-generator) - [Common Use Cases](#common-use-cases) @@ -66,6 +67,19 @@ type Pass interface { - Returns an error if transformation fails - **Must not modify the input config** (create a copy if needed) +## CLI Passes + +Custom generator binaries built with `cmd.Run` or `cmd.MustRun` register two types of passes: + +- **Always-included passes**: Passed as the first `passes` parameter, always run +- **Selectable passes**: Passed as the second `selectablePasses` parameter, filtered by `--enable-pass` flag + +Pass names come from `Name()`. If the same pass name is registered more than once, only the first included pass runs. + +`cmd.Run` validates pass flags before generation and returns an error if a name passed to `--enable-pass` does not match any registered selectable pass. + +Use `di.Pass` when calling `di.ApplyPasses`, `cmd.Generate`, `cmd.Run`, or `cmd.MustRun`. + ## Creating a Pass ### Basic Pass Structure @@ -130,7 +144,7 @@ func (p *MyPass) Process(cfg *di.Config) (*di.Config, error) { ```go func (p *MyPass) Process(cfg *di.Config) (*di.Config, error) { cfg.Services["new_service"] = di.Service{ - Constructor: &di.Constructor{ + Constructor: di.Constructor{ Func: "github.com/myapp.NewService", Args: []di.Argument{ {Kind: di.ArgLiteral, Literal: di.NewStringLiteral("value")}, @@ -190,14 +204,18 @@ import ( ) func main() { - // Define custom compiler passes + // Define always-included custom passes customPasses := []di.Pass{ &passes.AutoTagPass{}, + } + + // Define selectable passes (filtered by --enable-pass flag) + selectablePasses := []di.Pass{ &passes.ValidationPass{}, } // Run gendi with custom passes - if err := cmd.Run(flag.CommandLine, customPasses); err != nil { + if err := cmd.Run(flag.CommandLine, customPasses, selectablePasses); err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) os.Exit(1) } @@ -210,9 +228,12 @@ func main() { # Build custom generator go build -o bin/gendi ./tools/gendi -# Run custom generator +# Run custom generator (AutoTagPass always runs) ./bin/gendi --config=gendi.yaml --out=./di --pkg=di +# Enable ValidationPass via flag +./bin/gendi --config=gendi.yaml --out=./di --pkg=di --enable-pass=validation + # Or use go run go run ./tools/gendi --config=gendi.yaml --out=./di --pkg=di ``` @@ -278,13 +299,11 @@ func (p *LoggingPass) Process(cfg *di.Config) (*di.Config, error) { } // Add logger as first argument - if svc.Constructor != nil { - svc.Constructor.Args = append( - []di.Argument{{Kind: di.ArgServiceRef, Value: "logger"}}, - svc.Constructor.Args..., - ) - cfg.Services[id] = svc - } + svc.Constructor.Args = append( + []di.Argument{{Kind: di.ArgServiceRef, Value: "logger"}}, + svc.Constructor.Args..., + ) + cfg.Services[id] = svc } return cfg, nil @@ -417,7 +436,7 @@ type AutoTagPass struct{} Return descriptive errors: ```go -if svc.Constructor == nil { +if svc.Constructor.Func == "" && svc.Constructor.Method == "" { return nil, fmt.Errorf( "service %q: missing constructor (required by auto-tag pass)", id, @@ -460,7 +479,7 @@ func TestAutoTagPass(t *testing.T) { cfg := &di.Config{ Services: map[string]di.Service{ "home.handler": { - Constructor: &di.Constructor{ + Constructor: di.Constructor{ Func: "app.NewHomeHandler", }, }, @@ -493,6 +512,8 @@ customPasses := []di.Pass{ } ``` +If these passes are registered with `cmd.Run`, use `[]di.Pass`. + ## Complete Example See [examples/custom-pass](../examples/custom-pass) for a production-ready implementation featuring: @@ -511,7 +532,7 @@ func (p *ChannelLoggerPass) Name() string { func (p *ChannelLoggerPass) Process(cfg *di.Config) (*di.Config, error) { for id, svc := range cfg.Services { // Only process method constructors - if svc.Constructor == nil || svc.Constructor.Method == "" { + if svc.Constructor.Method == "" { continue } @@ -587,7 +608,7 @@ type Config struct { // Service represents a service definition type Service struct { Type string - Constructor *Constructor + Constructor Constructor Alias string Shared bool Public bool diff --git a/doc/spec/cli.md b/doc/spec/cli.md index 45e4491..339c56a 100644 --- a/doc/spec/cli.md +++ b/doc/spec/cli.md @@ -16,6 +16,7 @@ gendi | `--container` | Container struct name | | `--strict` | Enable strict validation (default: true) | | `--build-tags` | Go build tags | +| `--enable-pass` | Enable a selectable compiler pass by name; repeat for multiple passes; errors on unknown name or if pass is not registered as selectable | | `--verbose` | Verbose logging | ## go:generate diff --git a/examples/custom-pass/README.md b/examples/custom-pass/README.md index b2ce855..49c2192 100644 --- a/examples/custom-pass/README.md +++ b/examples/custom-pass/README.md @@ -21,7 +21,6 @@ examples/custom-pass/ │ │ └── gendi.yaml # Service definitions │ └── di/ # Custom compiler passes │ ├── autotag_pass.go -│ └── slog_pass.go ``` ## Custom Passes @@ -68,16 +67,16 @@ This pass demonstrates **variadic function support** - `slog.Logger.With(args .. ### 1. Define Custom Passes -Implement the `di.Pass` interface: +Implement the `di.Pass` interface for project-specific behavior. This example defines `AutoTagPass` locally and reuses `stdlib.SLogPass` for structured logger wiring: ```go -type SLogPass struct{} +type AutoTagPass struct{} -func (s *SLogPass) Name() string { - return "slog" +func (p *AutoTagPass) Name() string { + return "auto-tag" } -func (s *SLogPass) Process(cfg *di.Config) (*di.Config, error) { +func (p *AutoTagPass) Process(cfg *di.Config) (*di.Config, error) { // Transform config and return modified version return cfg, nil } @@ -88,15 +87,13 @@ func (s *SLogPass) Process(cfg *di.Config) (*di.Config, error) { `tools/gendi/main.go`: ```go func main() { + // Always-included passes passes := []gendi.Pass{ &di.AutoTagPass{}, - &di.SLogPass{}, + &stdlib.SLogPass{}, } - if err := cmd.Run(flag.CommandLine, passes); err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) - os.Exit(1) - } + cmd.MustRun(flag.CommandLine, passes, nil) } ``` diff --git a/examples/custom-pass/tools/gendi/main.go b/examples/custom-pass/tools/gendi/main.go index 18c7c85..ad58909 100644 --- a/examples/custom-pass/tools/gendi/main.go +++ b/examples/custom-pass/tools/gendi/main.go @@ -6,14 +6,15 @@ import ( gendi "github.com/asp24/gendi" "github.com/asp24/gendi/cmd" "github.com/asp24/gendi/examples/custom-pass/internal/di" + "github.com/asp24/gendi/stdlib" ) func main() { // Register custom compiler passes passes := []gendi.Pass{ &di.AutoTagPass{}, - &di.SLogPass{}, + &stdlib.SLogPass{}, } - cmd.MustRun(flag.CommandLine, passes) + cmd.MustRun(flag.CommandLine, passes, nil) } diff --git a/integration/integration_test.go b/integration/integration_test.go index 7902194..6568dbf 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -63,7 +63,7 @@ func runEmbeddedTest(t *testing.T, testName string, expectedOutput string, wantC } // Compile the code - compileCmd := exec.Command("go", "build", "-o", "app") + compileCmd := exec.Command("go", "build", "-buildvcs=false", "-o", "app") compileCmd.Dir = tmpDir compileOutput, err := compileCmd.CombinedOutput() if err != nil { diff --git a/public_pass.go b/public_pass.go new file mode 100644 index 0000000..24db38c --- /dev/null +++ b/public_pass.go @@ -0,0 +1,25 @@ +package di + +// ExposeAllPass promotes every service to public, causing the generator to emit +// a public getter for each one. Intended for test containers that need direct +// access to all services regardless of how they are declared in the YAML config. +// +// Enable via: --enable-pass=expose-all +// +// Avoid using in production containers — it overrides explicit `public: false` +// declarations and disables unreachable-service pruning (all services become +// reachable roots), so every imported service gets a generated getter. +type ExposeAllPass struct{} + +func (p *ExposeAllPass) Name() string { + return "expose-all" +} + +func (p *ExposeAllPass) Process(cfg *Config) (*Config, error) { + for id, svc := range cfg.Services { + svc.Public = true + cfg.Services[id] = svc + } + + return cfg, nil +} diff --git a/stdlib/README.md b/stdlib/README.md index 5c40afa..586f67b 100644 --- a/stdlib/README.md +++ b/stdlib/README.md @@ -549,6 +549,45 @@ services: shared: true ``` +## Compiler Passes + +The stdlib package also provides compiler passes for use in custom generator binaries. + +### SLogPass + +**Pass name:** `slog` + +Automatically wires structured logging into services that follow the slog naming convention. Use it in a custom generator built with `cmd.Run` or `cmd.MustRun`: + +```go +import ( + "flag" + gendi "github.com/asp24/gendi" + "github.com/asp24/gendi/cmd" + "github.com/asp24/gendi/stdlib" +) + +func main() { + // Always-included passes + passes := []gendi.Pass{ + &stdlib.SLogPass{}, + } + cmd.MustRun(flag.CommandLine, passes, nil) +} +``` + +To make SLogPass selectable via `--enable-pass=slog`, put it in the second parameter: + +```go +func main() { + passes := []gendi.Pass{} + selectablePasses := []gendi.Pass{ + &stdlib.SLogPass{}, + } + cmd.MustRun(flag.CommandLine, passes, selectablePasses) +} +``` + ## See Also - [Configuration Reference](../doc/configuration.md) diff --git a/examples/custom-pass/internal/di/slog_pass.go b/stdlib/slog_pass.go similarity index 97% rename from examples/custom-pass/internal/di/slog_pass.go rename to stdlib/slog_pass.go index 56d434f..e6e6cef 100644 --- a/examples/custom-pass/internal/di/slog_pass.go +++ b/stdlib/slog_pass.go @@ -1,4 +1,4 @@ -package di +package stdlib import ( "strings" @@ -6,8 +6,7 @@ import ( di "github.com/asp24/gendi" ) -type SLogPass struct { -} +type SLogPass struct{} func (s *SLogPass) Name() string { return "slog"