Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,11 @@ Use in custom generator:
```go
// tools/gendi/main.go
func main() {
// Always-included passes
passes := []di.Pass{&MyPass{}}
cmd.Run(flag.CommandLine, passes)
// Selectable passes (filtered by --enable-pass flag)
selectablePasses := []di.Pass{}
cmd.Run(flag.CommandLine, passes, selectablePasses)
}
```

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ Flags:
--container string Container struct name (default: "Container")
--strict Enable strict validation (default: true)
--build-tags string Build tags for generated file
--enable-pass string
Enable an optional compiler pass (repeatable)
--verbose Enable verbose logging
```

Expand Down
23 changes: 18 additions & 5 deletions cmd/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,17 @@ import (
di "github.com/asp24/gendi"
"github.com/asp24/gendi/pipeline"
"github.com/asp24/gendi/srcloc"
"github.com/asp24/gendi/stdlib"
"github.com/asp24/gendi/yaml"
)

func BuiltinSelectablePasses() []di.Pass {
return []di.Pass{
&stdlib.SLogPass{},
&di.ExposeAllPass{},
}
}

// WriteTargetFile writes data to the specified file path
func WriteTargetFile(path string, data []byte) error {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
Expand Down Expand Up @@ -56,16 +64,21 @@ func Generate(cfg Config, passes []di.Pass) error {
return nil
}

// Run executes the full gendi workflow with optional compiler passes
func Run(flags *flag.FlagSet, passes []di.Pass) error {
// Run executes the full gendi workflow with compiler passes
func Run(flags *flag.FlagSet, passes, selectablePasses []di.Pass) error {
var cfg Config
cfg.RegisterFlags(flags)

if err := flags.Parse(os.Args[1:]); err != nil {
return fmt.Errorf("parse flags: %w", err)
}

return Generate(cfg, passes)
resolvedPasses, err := cfg.resolvePasses(passes, selectablePasses)
if err != nil {
return fmt.Errorf("resolve passes: %w", err)
}

return Generate(cfg, resolvedPasses)
}

func PrintErrorAndExit(err error) {
Expand All @@ -79,6 +92,6 @@ func PrintErrorAndExit(err error) {
os.Exit(1)
}

func MustRun(flags *flag.FlagSet, passes []di.Pass) {
PrintErrorAndExit(Run(flags, passes))
func MustRun(flags *flag.FlagSet, passes, selectablePasses []di.Pass) {
PrintErrorAndExit(Run(flags, passes, selectablePasses))
}
57 changes: 55 additions & 2 deletions cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"flag"
"fmt"

di "github.com/asp24/gendi"
"github.com/asp24/gendi/pipeline"
)

// Config holds CLI configuration
type Config struct {
ConfigPath string
Options pipeline.Options
ConfigPath string
Options pipeline.Options
EnabledPasses map[string]struct{}
}

func (c *Config) RegisterFlags(flags *flag.FlagSet) {
Expand All @@ -21,6 +23,57 @@ func (c *Config) RegisterFlags(flags *flag.FlagSet) {
flags.BoolVar(&c.Options.Strict, "strict", true, "Enable strict validation")
flags.StringVar(&c.Options.BuildTags, "build-tags", "", "Go build tags")
flags.BoolVar(&c.Options.Verbose, "verbose", false, "Verbose logging")
flags.Var(&stringSetFlag{values: &c.EnabledPasses}, "enable-pass", "Enable a specific compiler pass (can be specified multiple times)")
}

func (c *Config) validatePasses(selectablePasses []di.Pass) error {
known := make(map[string]struct{}, len(selectablePasses))
for _, p := range selectablePasses {
known[p.Name()] = struct{}{}
}

for name := range c.EnabledPasses {
if _, ok := known[name]; !ok {
return fmt.Errorf("--enable-pass: unknown pass %q", name)
}
}

return nil
}

func (c *Config) resolvePasses(passes, selectablePasses []di.Pass) ([]di.Pass, error) {
if err := c.validatePasses(selectablePasses); err != nil {
return nil, err
}

result := make([]di.Pass, 0, len(passes)+len(selectablePasses))
included := make(map[string]struct{}, len(passes)+len(selectablePasses))
for _, p := range passes {
name := p.Name()
if _, ok := included[name]; ok {
continue
}

result = append(result, p)
included[name] = struct{}{}
}

for _, p := range selectablePasses {
name := p.Name()
if _, ok := included[name]; ok {
continue
}

_, enabled := c.EnabledPasses[name]
if !enabled {
continue
}

result = append(result, p)
included[name] = struct{}{}
}

return result, nil
}

// Finalize validates and finalizes the configuration
Expand Down
116 changes: 116 additions & 0 deletions cmd/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package cmd

import (
"testing"

di "github.com/asp24/gendi"
)

type testPass struct {
name string
}

func (p *testPass) Name() string { return p.name }
func (p *testPass) Process(cfg *di.Config) (*di.Config, error) { return cfg, nil }

func makePass(name string) *testPass {
return &testPass{name: name}
}

func TestConfig_ResolvePasses(t *testing.T) {
cases := []struct {
name string
enabled map[string]struct{}
passes []di.Pass
selectablePasses []di.Pass
wantNames []string
}{
{
name: "always-included pass is included",
passes: []di.Pass{makePass("a")},
wantNames: []string{"a"},
},
{
name: "selectable pass without enable flag is excluded",
selectablePasses: []di.Pass{makePass("a")},
wantNames: []string{},
},
{
name: "selectable pass with enable flag is included",
enabled: map[string]struct{}{"a": {}},
selectablePasses: []di.Pass{makePass("a")},
wantNames: []string{"a"},
},
{
name: "duplicate always-included pass name runs only once",
passes: []di.Pass{makePass("a"), makePass("a")},
wantNames: []string{"a"},
},
{
name: "duplicate selectable pass name runs only once",
enabled: map[string]struct{}{"a": {}},
selectablePasses: []di.Pass{makePass("a"), makePass("a")},
wantNames: []string{"a"},
},
{
name: "always-included pass wins over selectable pass with same name",
enabled: map[string]struct{}{"a": {}},
passes: []di.Pass{makePass("a")},
selectablePasses: []di.Pass{makePass("a")},
wantNames: []string{"a"},
},
{
name: "selectable passes are appended after always-included passes",
enabled: map[string]struct{}{"b": {}},
passes: []di.Pass{makePass("a")},
selectablePasses: []di.Pass{makePass("b")},
wantNames: []string{"a", "b"},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
cfg := Config{EnabledPasses: tc.enabled}
result, err := cfg.resolvePasses(tc.passes, tc.selectablePasses)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != len(tc.wantNames) {
t.Fatalf("got %d passes, want %d", len(result), len(tc.wantNames))
}
for i, p := range result {
if p.Name() != tc.wantNames[i] {
t.Errorf("pass[%d] name = %q, want %q", i, p.Name(), tc.wantNames[i])
}
}
})
}
}

func TestConfig_ResolvePasses_Errors(t *testing.T) {
cases := []struct {
name string
enabled map[string]struct{}
passes []di.Pass
selectablePasses []di.Pass
}{
{
name: "unknown name in --enable-pass",
enabled: map[string]struct{}{"unknown": {}},
selectablePasses: []di.Pass{makePass("foo")},
},
{
name: "always-included pass is not selectable",
enabled: map[string]struct{}{"foo": {}},
passes: []di.Pass{makePass("foo")},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
cfg := Config{EnabledPasses: tc.enabled}
_, err := cfg.resolvePasses(tc.passes, tc.selectablePasses)
if err == nil {
t.Error("expected error, got nil")
}
})
}
}
2 changes: 1 addition & 1 deletion cmd/gendi/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ import (
)

func main() {
cmd.MustRun(flag.CommandLine, nil)
cmd.MustRun(flag.CommandLine, nil, cmd.BuiltinSelectablePasses())
}
31 changes: 31 additions & 0 deletions cmd/string_set_flag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package cmd

import (
"sort"
"strings"
)

// stringSetFlag is a flag.Value that collects multiple values into a set.
type stringSetFlag struct {
values *map[string]struct{}
}

func (f *stringSetFlag) String() string {
if f.values == nil || *f.values == nil {
return ""
}
values := make([]string, 0, len(*f.values))
for value := range *f.values {
values = append(values, value)
}
sort.Strings(values)
return strings.Join(values, ",")
}

func (f *stringSetFlag) Set(s string) error {
if *f.values == nil {
*f.values = make(map[string]struct{})
}
(*f.values)[s] = struct{}{}
return nil
}
39 changes: 39 additions & 0 deletions cmd/string_set_flag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package cmd

import "testing"

func TestStringSetFlag_String(t *testing.T) {
cases := []struct {
name string
m map[string]struct{}
want string
}{
{"nil map", nil, ""},
{"multiple values sorted", map[string]struct{}{"beta": {}, "alpha": {}, "gamma": {}}, "alpha,beta,gamma"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
f := &stringSetFlag{values: &tc.m}
if got := f.String(); got != tc.want {
t.Errorf("got %q, want %q", got, tc.want)
}
})
}
}

func TestStringSetFlag_Set(t *testing.T) {
var m map[string]struct{}
f := &stringSetFlag{values: &m}
f.Set("a")
f.Set("b")
f.Set("a") // duplicate
want := map[string]struct{}{"a": {}, "b": {}}
if got := len(*f.values); got != len(want) {
t.Fatalf("len = %d, want %d", got, len(want))
}
for k := range want {
if _, ok := (*f.values)[k]; !ok {
t.Errorf("missing key %q", k)
}
}
}
2 changes: 1 addition & 1 deletion doc/LLM.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Short, stable facts for tooling and assistants.
go tool gendi --config=gendi.yaml --out=./di --pkg=di
```

Flags: `--config`, `--out`, `--pkg`, `--container`, `--strict`, `--build-tags`, `--verbose`.
Flags: `--config`, `--out`, `--pkg`, `--container`, `--strict`, `--build-tags`, `--enable-pass`, `--verbose`.

## YAML Syntax

Expand Down
Loading