From be3cf9d0fdcdb722dcc0d46a8503e54c408f4da7 Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Fri, 5 Jun 2026 22:11:56 +0200 Subject: [PATCH 1/6] Remove MemoryWatch from go-pipe Move the memory-watching API and Linux process-tree RSS reader out of go-pipe so the library no longer owns gitrpcd-specific memory policy. Expose commandStage.Process as the minimal hook needed by external memory watchers, while keeping commandStage.Kill for process-group teardown and context cancellation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- internal/ptree/ptree.go | 303 --------------------------------- internal/ptree/ptree_linux.go | 21 --- internal/ptree/ptree_test.go | 260 ---------------------------- pipe/command.go | 15 +- pipe/command_linux.go | 20 --- pipe/memorylimit.go | 273 ----------------------------- pipe/memorylimit_panic_test.go | 156 ----------------- pipe/memorylimit_test.go | 278 ------------------------------ 8 files changed, 12 insertions(+), 1314 deletions(-) delete mode 100644 internal/ptree/ptree.go delete mode 100644 internal/ptree/ptree_linux.go delete mode 100644 internal/ptree/ptree_test.go delete mode 100644 pipe/command_linux.go delete mode 100644 pipe/memorylimit.go delete mode 100644 pipe/memorylimit_panic_test.go delete mode 100644 pipe/memorylimit_test.go diff --git a/internal/ptree/ptree.go b/internal/ptree/ptree.go deleted file mode 100644 index 8ce22da..0000000 --- a/internal/ptree/ptree.go +++ /dev/null @@ -1,303 +0,0 @@ -// Package ptree contains utilities for dealing with Linux process trees. -package ptree - -import ( - "bytes" - "errors" - "io" - "os" - "strconv" - "strings" - "sync" -) - -const ( - // initialReadBufSize is the starting capacity of the buffer used for - // reading /proc files. /proc//status is typically ~1 KiB and the - // per-task "children" files are smaller, so 4 KiB covers the common - // case in a single read. - initialReadBufSize = 4 * 1024 -) - -var errNoRss = errors.New("RssAnon was not found") - -type ProcessTree struct { - path string -} - -func NewProcessTree(path string) ProcessTree { - return ProcessTree{ - path: path, - } -} - -// readBufPool reuses the byte buffer that holds the contents of a /proc file -// across calls to getProcessRSSAnon / walkChildrenFile, so that the -// per-poll work doesn't allocate (and then garbage-collect) a fresh buffer -// for every process in the tree. -var readBufPool = sync.Pool{ - New: func() any { - b := make([]byte, 0, initialReadBufSize) - return &b - }, -} - -// readProcFile reads all of path into a buffer borrowed from readBufPool. -// -// On success, the returned slice is only valid until bufPtr is returned to -// the pool, which the caller MUST do (typically with -// `defer readBufPool.Put(bufPtr)`, placed after the error check). -// -// On error, bufPtr is nil and the buffer has already been released, so the -// caller must not Put it back. -// -// Compared to os.ReadFile, this skips the (useless for /proc) Stat call -// used to pre-size the buffer and reuses the buffer across calls. It also -// returns the underlying *[]byte rather than a closure so the release path -// allocates nothing. -func readProcFile(path string) (data []byte, bufPtr *[]byte, err error) { - f, err := os.Open(path) - if err != nil { - return nil, nil, err - } - defer f.Close() - - bufPtr = readBufPool.Get().(*[]byte) - buf := (*bufPtr)[:0] - for { - if len(buf) == cap(buf) { - // Grow via append; this only allocates if the pooled - // buffer was too small. Subsequent calls will reuse - // the grown buffer because we store it back below. - buf = append(buf, 0)[:len(buf)] - } - n, rerr := f.Read(buf[len(buf):cap(buf)]) - buf = buf[:len(buf)+n] - if rerr == io.EOF { - break - } - if rerr != nil { - *bufPtr = buf - readBufPool.Put(bufPtr) - return nil, nil, rerr - } - if n == 0 { - // Defensive: io.Reader allows (0, nil) returns; treat - // as EOF rather than spinning. Real *os.File on /proc - // shouldn't hit this, but mocks or future runtime - // behavior might. - break - } - } - *bufPtr = buf - return buf, bufPtr, nil -} - -// Return the RSSAnon of a single process `pid`. -func (pt ProcessTree) GetProcessRSSAnon(pid int) (uint64, error) { - status := pt.path + "/" + strconv.Itoa(pid) + "/status" - data, bufPtr, err := readProcFile(status) - if os.IsNotExist(err) { - // process is already gone - return 0, nil - } - if err != nil { - return 0, err - } - defer readBufPool.Put(bufPtr) - - prefix := []byte("RssAnon:") - rest := data - for len(rest) > 0 { - var line []byte - if nl := bytes.IndexByte(rest, '\n'); nl >= 0 { - line, rest = rest[:nl], rest[nl+1:] - } else { - line, rest = rest, nil - } - // Fast prefix check before paying for the string conversion. - if !bytes.HasPrefix(line, prefix) { - continue - } - if rss, ok := ParseRSSAnon(string(line)); ok { - return rss, nil - } - } - return 0, errNoRss -} - -// Return the total RSS of the tree of processes rooted at `pid`. -// -// If the passed root pid is that of a kernel thread, as a special case, we -// return zero and no error. -// -// Errors encountered while walking the children are ignored, since it can -// change while traversing it. -func (pt ProcessTree) GetProcessTreeRSSAnon(pid int) (uint64, error) { - total, err := pt.GetProcessRSSAnon(pid) - if err != nil { - if err == errNoRss { - // these are typically kernel threads, which don't have an address space to measure - return 0, nil - } - return 0, err - } - - pt.WalkChildren(pid, func(pid int) { - mem, err := pt.GetProcessRSSAnon(pid) - if err != nil { - return - } - total += mem - }) - - return total, nil -} - -func (pt ProcessTree) WalkChildren(pid int, walkFn func(int)) { - pt.walkChildPids(pid, walkFn, map[int]bool{pid: true}) -} - -func (pt ProcessTree) walkChildPids(pid int, walkFn func(int), visited map[int]bool) { - // List the per-thread directories under /proc//task and read each - // task's "children" file directly. This avoids filepath.Glob, which - // would Stat every match on top of the readdir we already need. - taskDir := pt.path + "/" + strconv.Itoa(pid) + "/task" - entries, err := os.ReadDir(taskDir) - if err != nil { - return - } - - for _, entry := range entries { - // task/ should only contain numeric TID directories. Skip - // anything else defensively; this mirrors the implicit - // filtering that filepath.Glob("*/children") provided. - // A byte-range check avoids the error allocation that - // strconv.Atoi would incur for non-numeric names. - if !isAllDigits(entry.Name()) { - continue - } - pt.walkChildrenFile(taskDir+"/"+entry.Name()+"/children", walkFn, visited) - } -} - -func (pt ProcessTree) walkChildrenFile(filename string, walkFn func(int), visited map[int]bool) { - data, bufPtr, err := readProcFile(filename) - if err != nil { - return - } - defer readBufPool.Put(bufPtr) - - // children is a whitespace-separated list of decimal PIDs. Parse it in - // place to avoid the string(data) conversion and the []string allocated - // by strings.Fields. - i := 0 - for i < len(data) { - for i < len(data) && isASCIISpace(data[i]) { - i++ - } - if i >= len(data) { - return - } - pid := 0 - start := i - for i < len(data) && data[i] >= '0' && data[i] <= '9' { - pid = pid*10 + int(data[i]-'0') - i++ - } - if i == start { - // Not a digit; skip until next whitespace to stay in sync. - for i < len(data) && !isASCIISpace(data[i]) { - i++ - } - continue - } - if i-start > 10 { - // Realistic Linux PIDs fit in well under 10 digits - // (PID_MAX is 2^22). A longer digit run can't be a - // real PID and would risk silently overflowing the - // int accumulator, so skip it. - continue - } - if visited[pid] { - continue - } - walkFn(pid) - visited[pid] = true - pt.walkChildPids(pid, walkFn, visited) - } -} - -// parseRSSAnon parses an "RssAnon" line from /proc/*/status and returns the size. -// The entire line should be passed in, with or without the line ending. If the -// line looks like "RssAnon: 1234 kB", the byte size will be returned. If the -// line isn't parseable, (0, false) will be returned. -func ParseRSSAnon(s string) (uint64, bool) { - const prefix = "RssAnon:" - if !strings.HasPrefix(s, prefix) { - return 0, false - } - s = s[len(prefix):] - - // Optional whitespace before the number. - i := 0 - for i < len(s) && isASCIISpace(s[i]) { - i++ - } - - // One or more digits. - digitsStart := i - for i < len(s) && s[i] >= '0' && s[i] <= '9' { - i++ - } - if i == digitsStart { - return 0, false - } - kb, err := strconv.ParseUint(s[digitsStart:i], 10, 64) - if err != nil { - return 0, false - } - - // At least one whitespace between the number and "kB". - if i >= len(s) || !isASCIISpace(s[i]) { - return 0, false - } - for i < len(s) && isASCIISpace(s[i]) { - i++ - } - - // Literal "kB", then either end-of-string or whitespace. - const unit = "kB" - if !strings.HasPrefix(s[i:], unit) { - return 0, false - } - i += len(unit) - if i < len(s) && !isASCIISpace(s[i]) { - return 0, false - } - return kb * 1024, true -} - -// isASCIISpace matches the character class that Go's regexp engine uses for -// \s in non-Unicode mode: [\t\n\f\r ]. -func isASCIISpace(b byte) bool { - switch b { - case ' ', '\t', '\n', '\f', '\r': - return true - } - return false -} - -// isAllDigits reports whether s is non-empty and consists entirely of ASCII -// decimal digits. Used as a cheap allocation-free numeric-name filter. -func isAllDigits(s string) bool { - if len(s) == 0 { - return false - } - for i := 0; i < len(s); i++ { - if s[i] < '0' || s[i] > '9' { - return false - } - } - return true -} diff --git a/internal/ptree/ptree_linux.go b/internal/ptree/ptree_linux.go deleted file mode 100644 index 7019693..0000000 --- a/internal/ptree/ptree_linux.go +++ /dev/null @@ -1,21 +0,0 @@ -package ptree - -var DefaultProcessTree = ProcessTree{ - path: "/proc", -} - -// Walk the child processes of the specified root process. walkFn will be called -// for each child found. It will not be called for the root process. Any errors -// will be ignored, since they may be just a consequence of the process tree -// changing during traversal. -func WalkChildren(pid int, walkFn func(int)) { - DefaultProcessTree.WalkChildren(pid, walkFn) -} - -func GetProcessRSSAnon(pid int) (uint64, error) { - return DefaultProcessTree.GetProcessRSSAnon(pid) -} - -func GetProcessTreeRSSAnon(pid int) (uint64, error) { - return DefaultProcessTree.GetProcessTreeRSSAnon(pid) -} diff --git a/internal/ptree/ptree_test.go b/internal/ptree/ptree_test.go deleted file mode 100644 index 5c014c2..0000000 --- a/internal/ptree/ptree_test.go +++ /dev/null @@ -1,260 +0,0 @@ -package ptree_test - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "sort" - "strconv" - "testing" - - "github.com/github/go-pipe/v2/internal/ptree" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// writeStatus creates only what GetProcessRSSAnon reads: //status. -// If rssKB is zero, the RssAnon line is omitted (mimicking kernel threads). -func writeStatus(t testing.TB, root string, pid int, rssKB uint64) { - t.Helper() - pidDir := filepath.Join(root, strconv.Itoa(pid)) - require.NoError(t, os.MkdirAll(pidDir, 0o755)) - var status string - if rssKB > 0 { - status = fmt.Sprintf("Name:\tfake\nRssAnon:\t%d kB\nVmSize:\t1000 kB\n", rssKB) - } else { - status = "Name:\tfake\nVmSize:\t1000 kB\n" - } - require.NoError(t, os.WriteFile(filepath.Join(pidDir, "status"), []byte(status), 0o600)) -} - -// writeChildren creates //task//children containing the -// space-separated child pids. Only call this for processes that actually -// have children; getProcessTreeRSSAnon copes fine with the task/ directory -// being absent for leaves. -func writeChildren(t testing.TB, root string, pid int, children []int) { - t.Helper() - writeThreadChildren(t, root, pid, pid, children) -} - -// writeThreadChildren creates //task//children for a specific -// thread id, so tests can exercise multi-threaded processes. -func writeThreadChildren(t testing.TB, root string, pid, tid int, children []int) { - t.Helper() - taskDir := filepath.Join(root, strconv.Itoa(pid), "task", strconv.Itoa(tid)) - require.NoError(t, os.MkdirAll(taskDir, 0o755)) - var buf bytes.Buffer - for _, c := range children { - fmt.Fprintf(&buf, "%d ", c) - } - require.NoError(t, os.WriteFile(filepath.Join(taskDir, "children"), buf.Bytes(), 0o600)) -} - -func TestGetProcessRSSAnon(t *testing.T) { - const kb = 1024 - root := t.TempDir() - writeStatus(t, root, 100, 15032) - writeStatus(t, root, 101, 0) // kernel-thread-like: no RssAnon line - - pt := ptree.NewProcessTree(root) - - t.Run("reads RssAnon", func(t *testing.T) { - rss, err := pt.GetProcessRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64(15032*kb), rss) - }) - - t.Run("missing RssAnon line returns an error", func(t *testing.T) { - _, err := pt.GetProcessRSSAnon(101) - assert.Error(t, err) - }) - - t.Run("missing pid returns (0, nil)", func(t *testing.T) { - // A process that has already exited disappears from /proc; the function - // treats that as a non-error zero. - rss, err := pt.GetProcessRSSAnon(999) - require.NoError(t, err) - assert.Equal(t, uint64(0), rss) - }) -} - -func TestGetProcessTreeRSSAnon(t *testing.T) { - const kb = 1024 - - t.Run("leaf process returns its own RssAnon", func(t *testing.T) { - root := t.TempDir() - writeStatus(t, root, 100, 1000) - - pt := ptree.NewProcessTree(root) - - total, err := pt.GetProcessTreeRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64(1000*kb), total) - }) - - t.Run("sums root and descendants", func(t *testing.T) { - // 100 -> {101 -> 103, 102} - root := t.TempDir() - writeStatus(t, root, 100, 1000) - writeStatus(t, root, 101, 200) - writeStatus(t, root, 102, 50) - writeStatus(t, root, 103, 7) - writeChildren(t, root, 100, []int{101, 102}) - writeChildren(t, root, 101, []int{103}) - - pt := ptree.NewProcessTree(root) - - total, err := pt.GetProcessTreeRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64((1000+200+50+7)*kb), total) - }) - - t.Run("kernel-thread root returns (0, nil)", func(t *testing.T) { - // Root has no RssAnon line; the function maps errNoRss to (0, nil). - root := t.TempDir() - writeStatus(t, root, 100, 0) - - pt := ptree.NewProcessTree(root) - - total, err := pt.GetProcessTreeRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64(0), total) - }) -} - -func TestWalkChildren(t *testing.T) { - t.Run("walks all descendants", func(t *testing.T) { - // 100 -> {101, 102 -> 103}. Verifies the callback fires for - // every descendant (not just direct children) and is not - // invoked for the root. - root := t.TempDir() - writeChildren(t, root, 100, []int{101, 102}) - writeChildren(t, root, 102, []int{103}) - - pt := ptree.NewProcessTree(root) - - var seen []int - pt.WalkChildren(100, func(pid int) { seen = append(seen, pid) }) - sort.Ints(seen) - assert.Equal(t, []int{101, 102, 103}, seen) - }) - - t.Run("iterates every thread under task/ and dedups", func(t *testing.T) { - // 100 has two threads (100 and 200); each thread reports a - // different set of children, with 102 listed by both threads - // to exercise the visited dedup. - root := t.TempDir() - writeThreadChildren(t, root, 100, 100, []int{101, 102}) - writeThreadChildren(t, root, 100, 200, []int{102, 103}) - - pt := ptree.NewProcessTree(root) - - var seen []int - pt.WalkChildren(100, func(pid int) { seen = append(seen, pid) }) - sort.Ints(seen) - assert.Equal(t, []int{101, 102, 103}, seen) - }) -} - -// BenchmarkGetProcessTreeRSSAnon measures the cost of a single poll over a -// small process tree (a root plus a few direct children). -func BenchmarkGetProcessTreeRSSAnon(b *testing.B) { - const rootPid = 100 - root := b.TempDir() - writeStatus(b, root, rootPid, 1000) - children := []int{101, 102, 103} - writeChildren(b, root, rootPid, children) - for _, c := range children { - writeStatus(b, root, c, 200) - } - - pt := ptree.NewProcessTree(root) - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := pt.GetProcessTreeRSSAnon(rootPid) - if err != nil { - b.Fatal(err) - } - } -} - -func TestParseRss(t *testing.T) { - const kb = 1024 - - okExamples := []struct { - input string - result uint64 - }{ - { - input: "RssAnon:\t 15032 kB", - result: 15032 * kb, - }, - { - input: "RssAnon:\t99915032 kB", - result: 99915032 * kb, - }, - { - input: "RssAnon:\t 1 kB", - result: kb, - }, - // Exactly what the kernel emits via SEQ_PUT_DEC: "RssAnon:\t" + - // 8-wide right-justified decimal + " kB\n". See fs/proc/task_mmu.c - // (task_mem). The trailing newline must be tolerated. - { - input: "RssAnon:\t 15032 kB\n", - result: 15032 * kb, - }, - // A value wider than the 8-char padding (no leading spaces). - { - input: "RssAnon:\t12345678 kB\n", - result: 12345678 * kb, - }, - { - input: "RssAnon:\t 0 kB\n", - result: 0, - }, - } - - for _, example := range okExamples { - rss, ok := ptree.ParseRSSAnon(example.input) - if assert.Truef(t, ok, "should be able to parse %q", example.input) { - assert.Equalf(t, example.result, rss, "value of %q", example.input) - } - } - - badExamples := []string{ - "", - "\n", - "RssAnon:\t 123", - "RssAnonn:\t 123 kB", - "RssAno:\t 123 kB", - "Blah:\t 123 kB", - "Blah:", - "123", - } - - for _, example := range badExamples { - _, ok := ptree.ParseRSSAnon(example) - assert.Falsef(t, ok, "should not be able to parse %q", example) - } -} - -func BenchmarkParseRss(b *testing.B) { - b.Run("match", func(b *testing.B) { - for i := 0; i < b.N; i++ { - rss, ok := ptree.ParseRSSAnon("RssAnon:\t 15032 kB") - require.True(b, ok) - require.EqualValues(b, 15032*1024, rss) - } - }) - - b.Run("no match", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _, ok := ptree.ParseRSSAnon("Other:\t 15032 kB") - require.False(b, ok) - } - }) -} diff --git a/pipe/command.go b/pipe/command.go index 4dbe82b..ed12f9b 100644 --- a/pipe/command.go +++ b/pipe/command.go @@ -15,8 +15,6 @@ import ( "golang.org/x/sync/errgroup" ) -var errProcessInfoMissing = errors.New("cmd.Process is nil") - // commandStage is a pipeline `Stage` based on running an external // command and piping the data through its stdin and stdout. type commandStage struct { @@ -37,9 +35,16 @@ type commandStage struct { } var ( - _ Stage = (*commandStage)(nil) + _ Stage = (*commandStage)(nil) + _ processProvider = (*commandStage)(nil) ) +// processProvider is the hook external memory-watchers use to find the running +// process so they can sample its RSS. +type processProvider interface { + Process() *os.Process +} + // Command returns a pipeline `Stage` based on the specified external // `command`, run with the given command-line `args`. Its stdin and // stdout are handled as usual, and its stderr is collected and @@ -69,6 +74,10 @@ func (s *commandStage) Name() string { return s.name } +func (s *commandStage) Process() *os.Process { + return s.cmd.Process +} + func (s *commandStage) Requirements() StageRequirements { return StageRequirements{ StdinNeedsFile: true, diff --git a/pipe/command_linux.go b/pipe/command_linux.go deleted file mode 100644 index 987bffb..0000000 --- a/pipe/command_linux.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build linux - -package pipe - -import ( - "context" - - "github.com/github/go-pipe/v2/internal/ptree" -) - -// On linux, we can limit or observe memory usage in command stages. -var _ LimitableStage = (*commandStage)(nil) - -func (s *commandStage) GetRSSAnon(_ context.Context) (uint64, error) { - if s.cmd.Process == nil { - return 0, errProcessInfoMissing - } - - return ptree.GetProcessTreeRSSAnon(s.cmd.Process.Pid) -} diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go deleted file mode 100644 index 8a02400..0000000 --- a/pipe/memorylimit.go +++ /dev/null @@ -1,273 +0,0 @@ -package pipe - -import ( - "context" - "errors" - "fmt" - "io" - "sync" - "time" -) - -const memoryPollInterval = time.Second - -// ErrMemoryLimitExceeded is the error that will be used to kill a -// process, if necessary, from a MemoryWatch with WithMemoryLimit. -var ErrMemoryLimitExceeded = errors.New("memory limit exceeded") - -// LimitableStage is the superset of `Stage` that must be implemented -// by stages passed to MemoryWatch. -type LimitableStage interface { - Stage - - GetRSSAnon(context.Context) (uint64, error) - Kill(error) -} - -// MemoryWatchOption configures a MemoryWatch stage. -type MemoryWatchOption func(*memoryWatchStage) - -// WithMemoryLimit makes MemoryWatch kill the stage when its RSS exceeds -// byteLimit. -func WithMemoryLimit(byteLimit uint64) MemoryWatchOption { - return func(m *memoryWatchStage) { - m.limit = &byteLimit - m.nameSuffix = " with memory limit" - } -} - -// WithPeakUsageLogging makes MemoryWatch log the peak RSS when the stage -// exits. -func WithPeakUsageLogging() MemoryWatchOption { - return func(m *memoryWatchStage) { - m.observe = true - } -} - -// MemoryWatch watches the memory usage of the stage and reports via -// eventHandler. With WithMemoryLimit it kills the stage when the limit is -// exceeded; with WithPeakUsageLogging it logs the peak RSS when the stage -// exits. At least one of the two options is required. -// -// If the event handler panics while reporting the over-limit event, the -// stage is still killed. A panic in any other event-handler call (an -// RSS-read error, or the peak-usage report) is recovered via -// StageOptions.PanicHandler and the stage keeps running unmonitored; see -// StageOptions.PanicHandler. -func MemoryWatch(stage Stage, eventHandler func(e *Event), opts ...MemoryWatchOption) Stage { - limitableStage, ok := stage.(LimitableStage) - if !ok { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "invalid pipe.MemoryWatch usage", - Err: fmt.Errorf("invalid pipe.MemoryWatch usage"), - }) - return stage - } - - m := memoryWatchStage{ - stage: limitableStage, - eventHandler: eventHandler, - } - for _, opt := range opts { - opt(&m) - } - - if m.limit == nil && !m.observe { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "invalid pipe.MemoryWatch usage", - Err: fmt.Errorf( - "pipe.MemoryWatch requires WithMemoryLimit and/or WithPeakUsageLogging", - ), - }) - return stage - } - - return &m -} - -type memoryWatchStage struct { - nameSuffix string - stage LimitableStage - eventHandler func(e *Event) - - limit *uint64 // non-nil enables kill-at-limit - observe bool // log peak RSS when the stage exits - - maxRSS uint64 - samples int - errCount int - consecutiveErrors int - - cancel context.CancelFunc - wg sync.WaitGroup - watchErr error -} - -var _ LimitableStage = (*memoryWatchStage)(nil) - -func (m *memoryWatchStage) Name() string { - return m.stage.Name() + m.nameSuffix -} - -func (m *memoryWatchStage) Requirements() StageRequirements { - return m.stage.Requirements() -} - -func (m *memoryWatchStage) Start( - ctx context.Context, opts StageOptions, - stdin io.Reader, closeStdin bool, - stdout io.Writer, closeStdout bool, -) error { - if err := m.stage.Start(ctx, opts, stdin, closeStdin, stdout, closeStdout); err != nil { - return err - } - - m.monitor(ctx, opts.PanicHandler) - - return nil -} - -func (m *memoryWatchStage) Wait() error { - err := m.stage.Wait() - m.stopWatching() - if err == nil { - err = m.watchErr // non-nil if panicHandler() returned anything - } - return err -} - -func (m *memoryWatchStage) GetRSSAnon(ctx context.Context) (uint64, error) { - return m.stage.GetRSSAnon(ctx) -} - -func (m *memoryWatchStage) Kill(err error) { - m.stage.Kill(err) - m.stopWatching() -} - -// monitor starts up a goroutine that monitors the memory of `m`. If -// panicHandler is set, any panic that escapes the user-supplied event handler -// (via m.watch) is recovered. -func (m *memoryWatchStage) monitor(ctx context.Context, panicHandler StagePanicHandler) { - ctx, cancel := context.WithCancel(ctx) - m.cancel = cancel - m.wg.Add(1) - - go func() { - defer m.wg.Done() - - if panicHandler != nil { - defer func() { - if p := recover(); p != nil { - m.watchErr = panicHandler(p) - } - }() - } - - m.watch(ctx) - }() -} - -func (m *memoryWatchStage) stopWatching() { - m.cancel() - m.wg.Wait() -} - -// watch is a `memoryWatchFunc` that watches the memory usage of the -// specified `stage`. -func (m *memoryWatchStage) watch(ctx context.Context) { - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - -watchLoop: - for { - select { - case <-ctx.Done(): - break watchLoop - case <-t.C: - if m.update(ctx) { - // The stage was killed. - break watchLoop - } - } - } - - if m.observe { - <-ctx.Done() - m.reportPeakUsage() - } -} - -// update samples the current memory usage and updates internal stats. -// Return true if the stage was killed for exceeding the memory limit. -func (m *memoryWatchStage) update(ctx context.Context) bool { - rss, err := m.stage.GetRSSAnon(ctx) - if err != nil { - m.handleGetRSSError(err) - return false - } - - m.consecutiveErrors = 0 - m.samples++ - if rss > m.maxRSS { - m.maxRSS = rss - } - - if m.limit != nil && rss >= *m.limit { - m.killStage(rss) - return true - } - - return false -} - -// handleGetRSSError deals with error `err` that happened when trying -// to get `stage`'s memory usage. -func (m *memoryWatchStage) handleGetRSSError(err error) { - if !errors.Is(err, errProcessInfoMissing) { - m.errCount++ - m.consecutiveErrors++ - if m.consecutiveErrors == 2 { - m.eventHandler(&Event{ - Command: m.stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - } else { - m.consecutiveErrors = 0 - } -} - -// killStage kills the stage and reports and event saying what it did. -func (m *memoryWatchStage) killStage(rss uint64) { - // Guarantee the over-limit stage is killed even if - // the user's event handler panics. - defer m.stage.Kill(ErrMemoryLimitExceeded) - - m.eventHandler(&Event{ - Command: m.stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]any{ - "limit": *m.limit, - "used": rss, - }, - }) -} - -// reportPeakUsage sends an event reporting the peak usage that has -// been seen for `stage`. -func (m *memoryWatchStage) reportPeakUsage() { - m.eventHandler(&Event{ - Command: m.stage.Name(), - Msg: "peak memory usage", - Context: map[string]any{ - "max_rss_bytes": m.maxRSS, - "samples": m.samples, - "errors": m.errCount, - }, - }) -} diff --git a/pipe/memorylimit_panic_test.go b/pipe/memorylimit_panic_test.go deleted file mode 100644 index 1c0c32e..0000000 --- a/pipe/memorylimit_panic_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package pipe - -import ( - "context" - "fmt" - "io" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -const memWatchPanicSentinel = "memwatch-panic-sentinel" - -// fakeLimitableStage is a minimal LimitableStage whose `GetRSSAnon()` -// method panics, and whose `Wait()` method returns after that panic -// has been issued. -type fakeLimitableStage struct { - done chan struct{} -} - -func (fakeLimitableStage) Name() string { return "fake" } -func (fakeLimitableStage) Requirements() StageRequirements { return StageRequirements{} } -func (fakeLimitableStage) Start( - context.Context, StageOptions, io.Reader, bool, io.Writer, bool, -) error { - return nil -} -func (stage fakeLimitableStage) Wait() error { - <-stage.done - return nil -} - -func (stage fakeLimitableStage) GetRSSAnon(context.Context) (uint64, error) { - close(stage.done) - panic(memWatchPanicSentinel) -} - -func (fakeLimitableStage) Kill(error) {} - -func panickingWatchStage() Stage { - stage := fakeLimitableStage{ - done: make(chan struct{}), - } - return MemoryWatch(stage, func(*Event) {}, WithMemoryLimit(1)) -} - -// TestMemoryWatchStagePanicWithHandlerSurfaced verifies that a panic -// escaping the memory-watch goroutine (where the user-supplied event -// handler runs) is recovered via the configured panic handler and -// surfaced as the stage's Wait error. -func TestMemoryWatchStagePanicWithHandlerSurfaced(t *testing.T) { - ms := panickingWatchStage() - opts := StageOptions{ - PanicHandler: func(p any) error { return fmt.Errorf("recovered: %v", p) }, - } - - if err := ms.Start(context.Background(), opts, nil, false, nil, false); err != nil { - t.Fatalf("Start returned unexpected error: %v", err) - } - - err := ms.Wait() - if err == nil { - t.Fatal("expected Wait to surface the recovered panic, got nil") - } - if !strings.Contains(err.Error(), memWatchPanicSentinel) { - t.Fatalf("expected error to mention %q, got: %v", memWatchPanicSentinel, err) - } -} - -// TestMemoryWatchStagePanicWithoutHandlerPropagates verifies that the -// memory-watch sampling path does not swallow a panic. The monitor -// goroutine only installs a recover when a handler is present (see -// memoryWatchStage.monitor), so we exercise update() directly and assert -// that it propagates the panic. update() is used rather than watch() so the -// assertion is synchronous and ticker-independent: a regression that stopped -// the panic would fail the test rather than hang on the ticker loop. -func TestMemoryWatchStagePanicWithoutHandlerPropagates(t *testing.T) { - limit := uint64(1) - mw := memoryWatchStage{ - stage: fakeLimitableStage{done: make(chan struct{})}, - eventHandler: func(*Event) {}, - limit: &limit, - } - - assert.PanicsWithValue(t, memWatchPanicSentinel, func() { - mw.update(context.Background()) - }) -} - -// killTrackingStage is a LimitableStage that reports an over-limit RSS -// and blocks in Wait until it is killed, recording that the kill -// happened. It lets a test assert that the memory limit is enforced. -type killTrackingStage struct { - killed chan struct{} - done chan struct{} -} - -func newKillTrackingStage() *killTrackingStage { - return &killTrackingStage{ - killed: make(chan struct{}), - done: make(chan struct{}), - } -} - -func (*killTrackingStage) Name() string { return "kill-tracking" } -func (*killTrackingStage) Requirements() StageRequirements { return StageRequirements{} } -func (*killTrackingStage) Start( - context.Context, StageOptions, io.Reader, bool, io.Writer, bool, -) error { - return nil -} -func (s *killTrackingStage) Wait() error { <-s.done; return ErrMemoryLimitExceeded } -func (*killTrackingStage) GetRSSAnon(context.Context) (uint64, error) { - return 1 << 30, nil -} - -func (s *killTrackingStage) Kill(error) { - select { - case <-s.killed: - // already killed - default: - close(s.killed) - close(s.done) - } -} - -// TestMemoryLimitKillsEvenIfEventHandlerPanics verifies that an over-limit -// stage is still killed (the limit enforced) even when the user's event -// handler panics and that panic is recovered by the configured handler. -// Without the kill being guaranteed during unwinding, the runaway stage -// would never be killed and Wait would hang. -func TestMemoryLimitKillsEvenIfEventHandlerPanics(t *testing.T) { - stage := newKillTrackingStage() - eventHandler := func(*Event) { panic(memWatchPanicSentinel) } - ms := MemoryWatch(stage, eventHandler, WithMemoryLimit(1)) - opts := StageOptions{ - PanicHandler: func(p any) error { return fmt.Errorf("recovered: %v", p) }, - } - - if err := ms.Start(context.Background(), opts, nil, false, nil, false); err != nil { - t.Fatalf("Start returned unexpected error: %v", err) - } - - select { - case <-stage.killed: - // expected: the limit was enforced despite the handler panic. - case <-time.After(5 * time.Second): - t.Fatal("over-limit stage was not killed after the event handler panicked") - } - - if err := ms.Wait(); err != ErrMemoryLimitExceeded { - t.Fatalf("Wait = %v, want %v", err, ErrMemoryLimitExceeded) - } -} diff --git a/pipe/memorylimit_test.go b/pipe/memorylimit_test.go deleted file mode 100644 index 4e46945..0000000 --- a/pipe/memorylimit_test.go +++ /dev/null @@ -1,278 +0,0 @@ -package pipe_test - -import ( - "bytes" - "context" - "fmt" - "io" - "log" - "os" - "strings" - "syscall" - "testing" - "time" - - "github.com/github/go-pipe/v2/pipe" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func LogEventHandler(logger *log.Logger) func(*pipe.Event) { - return func(e *pipe.Event) { - ctx := "" - for k, v := range e.Context { - ctx = fmt.Sprintf("%s,%s=%v", ctx, k, v) - } - - logger.Printf("Command %s failed with message %s and error %s. Context: %s", e.Command, e.Msg, e.Command, ctx) - } -} - -func TestMemoryObserverSimple(t *testing.T) { - t.Parallel() - rss := testMemoryObserver(t, 400, pipe.Command("less")) - require.Greater(t, rss, 400_000_000) -} - -func TestMemoryObserverTreeMem(t *testing.T) { - t.Parallel() - - // create a process tree like: - // /tmp/go-build3166037414/b001/pipe.test -test.paniconexit0 -test.timeout=10m0s -test.count=1 -test.run=Tree -test.v=true - // \_ head -c 3G - // \_ sh -c less; : - // \_ less - // so that MemoryObserver is watching the parent `sh` proc and doesn't detect less's mem usage. - // less should buffer whatever we send it via stdin, giving us some level of control over its - // memory usage. - rss := testMemoryObserver(t, 400, pipe.Command("sh", "-c", "less; :")) - require.Greater(t, rss, 400_000_000) -} - -func testMemoryObserver(t *testing.T, mbs int, stage pipe.Stage) int { - ctx := context.Background() - - stdinReader, stdinWriter := io.Pipe() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithPeakUsageLogging())) - require.NoError(t, p.Start(ctx)) - - // Write some nonsense data to less, but don't close stdin until we want it - // to exit. - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - n, err := stdinWriter.Write(bytes[:]) - require.NoError(t, err) - require.Equal(t, len(bytes), n) - } - - // MemoryObserver polls every one second, so this should make sure we catch - // at least one. - time.Sleep(2 * time.Second) - - // Close stdin and wait for the pipeline to exit. - require.NoError(t, stdinWriter.Close()) - require.NoError(t, p.Wait()) - - return maxBytes(buf.String()) -} - -func maxBytes(s string) int { - idx := strings.Index(s, "max_rss_bytes=") - if idx < 0 { - return idx - } - var maxRSS int - n, err := fmt.Sscanf(s[idx:], "max_rss_bytes=%d", &maxRSS) - if n != 1 || err != nil { - return -1 - } - return maxRSS -} - -func TestMemoryLimitSimple(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimit(t, 400, 10_000_000, pipe.Command("less")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitTreeMem(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimit(t, 400, 10_000_000, pipe.Command("sh", "-c", "less; :")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverSimple(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverTreeMem(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("sh", "-c", "less; :")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverLogsPeakOnKill(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "peak memory usage") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverBelowLimit(t *testing.T) { - t.Parallel() - rss := testMemoryLimitWithObserverBelowLimit(t, 400, pipe.Command("less")) - require.Greater(t, rss, 400_000_000) -} - -func TestMemoryLimitWithObserverBelowLimitTreeMem(t *testing.T) { - t.Parallel() - rss := testMemoryLimitWithObserverBelowLimit(t, 400, pipe.Command("sh", "-c", "less; :")) - require.Greater(t, rss, 400_000_000) -} - -// testMemoryLimitWithObserverBelowLimit exercises the observer half of -// `MemoryLimitWithObserver` when the memory limit is never hit: with a -// 100GiB limit, less should never be killed, but the wrapper should -// still poll RSS and emit a "peak memory usage" event when the stage -// exits normally. Mirrors `testMemoryObserver` in structure — we hold -// stdin open across at least one poll interval so RSS samples are -// guaranteed to be taken before the stage is allowed to exit. -func testMemoryLimitWithObserverBelowLimit(t *testing.T, mbs int, stage pipe.Stage) int { - ctx := context.Background() - - stdinReader, stdinWriter := io.Pipe() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryLimitWithObserverBelowLimit", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(100*1024*1024*1024), pipe.WithPeakUsageLogging())) - require.NoError(t, p.Start(ctx)) - - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - n, err := stdinWriter.Write(bytes[:]) - require.NoError(t, err) - require.Equal(t, len(bytes), n) - } - - // Wrapper polls once per second; sleep long enough to guarantee at - // least one sample is taken before the stage is allowed to exit. - time.Sleep(2 * time.Second) - - require.NoError(t, stdinWriter.Close()) - require.NoError(t, p.Wait()) - - output := buf.String() - assert.Contains(t, output, "peak memory usage") - - return maxBytes(output) -} - -func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { - ctx := context.Background() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdoutCloser(devNull)) - p.Add( - pipe.Function( - "write-to-less", - func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { - // Write some nonsense data to less. - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdout.Write(bytes[:]) - if err != nil { - assert.ErrorIs(t, err, syscall.EPIPE) - return nil - } - } - - return nil - }, - ), - pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(limit)), - ) - require.NoError(t, p.Start(ctx)) - - err = p.Wait() - - return buf.String(), err -} - -func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { - ctx := context.Background() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryLimitWithObserver", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdoutCloser(devNull)) - p.Add( - pipe.Function( - "write-to-less", - func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdout.Write(bytes[:]) - if err != nil { - assert.ErrorIs(t, err, syscall.EPIPE) - return nil - } - } - return nil - }, - ), - pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(limit), pipe.WithPeakUsageLogging()), - ) - require.NoError(t, p.Start(ctx)) - - err = p.Wait() - - return buf.String(), err -} - -// TestMemoryWatchRequiresAnOption verifies that MemoryWatch without -// WithMemoryLimit or WithPeakUsageLogging is rejected: it reports an -// invalid-usage event and returns the stage unwrapped (no watcher). -func TestMemoryWatchRequiresAnOption(t *testing.T) { - stage := pipe.Command("true") - - var events []*pipe.Event - got := pipe.MemoryWatch(stage, func(e *pipe.Event) { - events = append(events, e) - }) - - require.Same(t, stage, got, "expected the input stage returned unwrapped") - require.Len(t, events, 1) - require.Contains(t, events[0].Msg, "invalid pipe.MemoryWatch usage") -} From 5af021e18c3ee9cb209b7b1af111977eee4e8255 Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Fri, 5 Jun 2026 22:25:57 +0200 Subject: [PATCH 2/6] Add stage-scoped environment wrapper Add WithExtraEnv so callers can append environment variables for one wrapped stage without changing the pipeline-wide environment. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pipe/env_stage.go | 42 +++++++++++++++++++++++++++ pipe/env_stage_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 pipe/env_stage.go create mode 100644 pipe/env_stage_test.go diff --git a/pipe/env_stage.go b/pipe/env_stage.go new file mode 100644 index 0000000..60348ab --- /dev/null +++ b/pipe/env_stage.go @@ -0,0 +1,42 @@ +package pipe + +import ( + "context" + "io" +) + +// WithExtraEnv returns a Stage that adds env to the environment seen by inner. +func WithExtraEnv(inner Stage, env []EnvVar) Stage { + return &stageWithExtraEnv{ + inner: inner, + env: env, + } +} + +type stageWithExtraEnv struct { + inner Stage + env []EnvVar +} + +func (s *stageWithExtraEnv) Name() string { + return s.inner.Name() + " (with extra env vars)" +} + +func (s *stageWithExtraEnv) Requirements() StageRequirements { + return s.inner.Requirements() +} + +func (s *stageWithExtraEnv) Start( + ctx context.Context, opts StageOptions, + stdin io.Reader, closeStdin bool, + stdout io.Writer, closeStdout bool, +) error { + opts.Vars = append(opts.Vars, func(_ context.Context, vars []EnvVar) []EnvVar { + return append(vars, s.env...) + }) + return s.inner.Start(ctx, opts, stdin, closeStdin, stdout, closeStdout) +} + +func (s *stageWithExtraEnv) Wait() error { + return s.inner.Wait() +} diff --git a/pipe/env_stage_test.go b/pipe/env_stage_test.go new file mode 100644 index 0000000..efd50a0 --- /dev/null +++ b/pipe/env_stage_test.go @@ -0,0 +1,65 @@ +package pipe + +import ( + "context" + "io" + "reflect" + "testing" +) + +func collectEnvVars(ctx context.Context, env Env) []EnvVar { + var vars []EnvVar + for _, fn := range env.Vars { + vars = fn(ctx, vars) + } + return vars +} + +func TestWithExtraEnvAddsStageLocalVars(t *testing.T) { + ctx := context.Background() + var firstStageVars []EnvVar + var secondStageVars []EnvVar + + p := New(WithEnvVar("PIPELINE", "present")) + p.Add( + WithExtraEnv( + Function("first", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + firstStageVars = collectEnvVars(ctx, env) + return nil + }), + []EnvVar{ + {Key: "STAGE", Value: "first"}, + }, + ), + Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + secondStageVars = collectEnvVars(ctx, env) + return nil + }), + ) + + if err := p.Run(ctx); err != nil { + t.Fatal(err) + } + + if want := []EnvVar{{Key: "PIPELINE", Value: "present"}, {Key: "STAGE", Value: "first"}}; !reflect.DeepEqual(firstStageVars, want) { + t.Fatalf("first stage vars = %#v, want %#v", firstStageVars, want) + } + if want := []EnvVar{{Key: "PIPELINE", Value: "present"}}; !reflect.DeepEqual(secondStageVars, want) { + t.Fatalf("second stage vars = %#v, want %#v", secondStageVars, want) + } +} + +func TestWithExtraEnvPreservesStageMetadata(t *testing.T) { + inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { + return nil + }, ForbidStdin(), ForbidStdout()) + + stage := WithExtraEnv(inner, nil) + + if got, want := stage.Name(), "inner (with extra env vars)"; got != want { + t.Fatalf("Name() = %q, want %q", got, want) + } + if got, want := stage.Requirements(), inner.Requirements(); got != want { + t.Fatalf("Requirements() = %#v, want %#v", got, want) + } +} From ffcc348fcc20572412ae1ad635ff1013828865ce Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Mon, 8 Jun 2026 13:30:27 +0200 Subject: [PATCH 3/6] Fix stage env option slice aliasing WithExtraEnv appended to the StageOptions Env.Vars slice directly. The StageOptions value is copied per stage, but the slice backing array can still be shared with the pipeline-level environment hooks, so a wrapped stage could overwrite another stage's stage-local hook when the shared slice had spare capacity. Fix, and add tests. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pipe/env_stage.go | 2 +- pipe/env_stage_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/pipe/env_stage.go b/pipe/env_stage.go index 60348ab..bee791b 100644 --- a/pipe/env_stage.go +++ b/pipe/env_stage.go @@ -31,7 +31,7 @@ func (s *stageWithExtraEnv) Start( stdin io.Reader, closeStdin bool, stdout io.Writer, closeStdout bool, ) error { - opts.Vars = append(opts.Vars, func(_ context.Context, vars []EnvVar) []EnvVar { + opts.Vars = append(opts.Vars[:len(opts.Vars):len(opts.Vars)], func(_ context.Context, vars []EnvVar) []EnvVar { return append(vars, s.env...) }) return s.inner.Start(ctx, opts, stdin, closeStdin, stdout, closeStdout) diff --git a/pipe/env_stage_test.go b/pipe/env_stage_test.go index efd50a0..0bc4873 100644 --- a/pipe/env_stage_test.go +++ b/pipe/env_stage_test.go @@ -1,10 +1,12 @@ package pipe import ( + "bytes" "context" "io" "reflect" "testing" + "time" ) func collectEnvVars(ctx context.Context, env Env) []EnvVar { @@ -49,6 +51,88 @@ func TestWithExtraEnvAddsStageLocalVars(t *testing.T) { } } +func TestWithExtraEnvDoesNotShareVarsBackingArray(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + allowFirstStage := make(chan struct{}) + var firstStageVars []EnvVar + var secondStageVars []EnvVar + + baseVars := make([]AppendVars, 0, 4) + for _, env := range []EnvVar{ + {Key: "PIPELINE1", Value: "present"}, + {Key: "PIPELINE2", Value: "present"}, + {Key: "PIPELINE3", Value: "present"}, + } { + env := env + baseVars = append(baseVars, func(_ context.Context, vars []EnvVar) []EnvVar { + return append(vars, env) + }) + } + + p := New(func(p *Pipeline) { + p.env.Vars = baseVars + }) + p.Add( + WithExtraEnv( + Function("first", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + select { + case <-allowFirstStage: + case <-ctx.Done(): + return ctx.Err() + } + firstStageVars = collectEnvVars(ctx, env) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "first"}}, + ), + WithExtraEnv( + Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + secondStageVars = collectEnvVars(ctx, env) + close(allowFirstStage) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "second"}}, + ), + ) + + if err := p.Run(ctx); err != nil { + t.Fatal(err) + } + + wantBase := []EnvVar{ + {Key: "PIPELINE1", Value: "present"}, + {Key: "PIPELINE2", Value: "present"}, + {Key: "PIPELINE3", Value: "present"}, + } + if want := append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "first"}); !reflect.DeepEqual(firstStageVars, want) { + t.Fatalf("first stage vars = %#v, want %#v", firstStageVars, want) + } + if want := append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "second"}); !reflect.DeepEqual(secondStageVars, want) { + t.Fatalf("second stage vars = %#v, want %#v", secondStageVars, want) + } +} + +func TestWithExtraEnvAddsCommandEnv(t *testing.T) { + ctx := context.Background() + stdout := &bytes.Buffer{} + + p := New(WithStdout(stdout)) + p.Add(WithExtraEnv( + Command("sh", "-c", "printf %s \"$STAGE\""), + []EnvVar{{Key: "STAGE", Value: "command"}}, + )) + + if err := p.Run(ctx); err != nil { + t.Fatal(err) + } + + if got, want := stdout.String(), "command"; got != want { + t.Fatalf("stdout = %q, want %q", got, want) + } +} + func TestWithExtraEnvPreservesStageMetadata(t *testing.T) { inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { return nil From f89e7736aa7ba73657fce7af4dfdf5128ea7cb46 Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Mon, 8 Jun 2026 13:40:10 +0200 Subject: [PATCH 4/6] env_stage.go: Preserve command hooks through stage env wrapper WithExtraEnv wraps the inner stage, so it would otherwise hide optional command hooks such as Process and Kill. Wrap the inner stage in a way that faithfully preserves whether both are present. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pipe/env_stage.go | 19 ++++++++++++++++++- pipe/env_stage_test.go | 20 ++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/pipe/env_stage.go b/pipe/env_stage.go index bee791b..99f5372 100644 --- a/pipe/env_stage.go +++ b/pipe/env_stage.go @@ -7,10 +7,27 @@ import ( // WithExtraEnv returns a Stage that adds env to the environment seen by inner. func WithExtraEnv(inner Stage, env []EnvVar) Stage { - return &stageWithExtraEnv{ + stage := &stageWithExtraEnv{ inner: inner, env: env, } + if processKiller, ok := inner.(processKiller); ok { + return &processStageWithExtraEnv{ + stageWithExtraEnv: stage, + processKiller: processKiller, + } + } + return stage +} + +type processKiller interface { + processProvider + Kill(error) +} + +type processStageWithExtraEnv struct { + *stageWithExtraEnv + processKiller } type stageWithExtraEnv struct { diff --git a/pipe/env_stage_test.go b/pipe/env_stage_test.go index 0bc4873..f232d8a 100644 --- a/pipe/env_stage_test.go +++ b/pipe/env_stage_test.go @@ -133,6 +133,26 @@ func TestWithExtraEnvAddsCommandEnv(t *testing.T) { } } +func TestWithExtraEnvPreservesProcessHooks(t *testing.T) { + stage := WithExtraEnv(Command("true"), nil) + + if _, ok := stage.(processKiller); !ok { + t.Fatal("WithExtraEnv(Command(...)) does not implement processKiller") + } +} + +func TestWithExtraEnvDoesNotAddProcessHooks(t *testing.T) { + inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { + return nil + }) + + stage := WithExtraEnv(inner, nil) + + if _, ok := stage.(processKiller); ok { + t.Fatal("WithExtraEnv(Function(...)) unexpectedly implements processKiller") + } +} + func TestWithExtraEnvPreservesStageMetadata(t *testing.T) { inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { return nil From 1ae6d3b00e59b0bc85dc8f29714189229183e841 Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Mon, 8 Jun 2026 13:46:18 +0200 Subject: [PATCH 5/6] lint fix: copyloopvar --- pipe/env_stage_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pipe/env_stage_test.go b/pipe/env_stage_test.go index f232d8a..6d996e1 100644 --- a/pipe/env_stage_test.go +++ b/pipe/env_stage_test.go @@ -65,7 +65,6 @@ func TestWithExtraEnvDoesNotShareVarsBackingArray(t *testing.T) { {Key: "PIPELINE2", Value: "present"}, {Key: "PIPELINE3", Value: "present"}, } { - env := env baseVars = append(baseVars, func(_ context.Context, vars []EnvVar) []EnvVar { return append(vars, env) }) From 145c7205b81efc021edb97179a9dfdeb95448ccd Mon Sep 17 00:00:00 2001 From: Jason Lunz Date: Mon, 8 Jun 2026 15:15:06 +0200 Subject: [PATCH 6/6] Normalize env stage tests Use testify assertions consistently in env_stage_test and add coverage for stage-local env values overriding pipeline values without leaking to later stages. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pipe/env_stage_test.go | 116 +++++++++++++++++++++++++---------------- 1 file changed, 72 insertions(+), 44 deletions(-) diff --git a/pipe/env_stage_test.go b/pipe/env_stage_test.go index 6d996e1..43117ae 100644 --- a/pipe/env_stage_test.go +++ b/pipe/env_stage_test.go @@ -4,9 +4,11 @@ import ( "bytes" "context" "io" - "reflect" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func collectEnvVars(ctx context.Context, env Env) []EnvVar { @@ -17,7 +19,23 @@ func collectEnvVars(ctx context.Context, env Env) []EnvVar { return vars } +func lastEnvValue(ctx context.Context, env Env, key string) (string, bool) { + var ( + value string + ok bool + ) + for _, envVar := range collectEnvVars(ctx, env) { + if envVar.Key == key { + value = envVar.Value + ok = true + } + } + return value, ok +} + func TestWithExtraEnvAddsStageLocalVars(t *testing.T) { + t.Parallel() + ctx := context.Background() var firstStageVars []EnvVar var secondStageVars []EnvVar @@ -29,9 +47,7 @@ func TestWithExtraEnvAddsStageLocalVars(t *testing.T) { firstStageVars = collectEnvVars(ctx, env) return nil }), - []EnvVar{ - {Key: "STAGE", Value: "first"}, - }, + []EnvVar{{Key: "STAGE", Value: "first"}}, ), Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { secondStageVars = collectEnvVars(ctx, env) @@ -39,19 +55,45 @@ func TestWithExtraEnvAddsStageLocalVars(t *testing.T) { }), ) - if err := p.Run(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, p.Run(ctx)) + assert.Equal(t, []EnvVar{{Key: "PIPELINE", Value: "present"}, {Key: "STAGE", Value: "first"}}, firstStageVars) + assert.Equal(t, []EnvVar{{Key: "PIPELINE", Value: "present"}}, secondStageVars) +} - if want := []EnvVar{{Key: "PIPELINE", Value: "present"}, {Key: "STAGE", Value: "first"}}; !reflect.DeepEqual(firstStageVars, want) { - t.Fatalf("first stage vars = %#v, want %#v", firstStageVars, want) - } - if want := []EnvVar{{Key: "PIPELINE", Value: "present"}}; !reflect.DeepEqual(secondStageVars, want) { - t.Fatalf("second stage vars = %#v, want %#v", secondStageVars, want) - } +func TestWithExtraEnvStageLocalVarsOverridePipelineVars(t *testing.T) { + t.Parallel() + + ctx := context.Background() + var firstStageValue string + var secondStageValue string + + p := New(WithEnvVar("STAGE", "pipeline")) + p.Add( + WithExtraEnv( + Function("first", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + var ok bool + firstStageValue, ok = lastEnvValue(ctx, env, "STAGE") + require.True(t, ok) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "stage-local"}}, + ), + Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + var ok bool + secondStageValue, ok = lastEnvValue(ctx, env, "STAGE") + require.True(t, ok) + return nil + }), + ) + + require.NoError(t, p.Run(ctx)) + assert.Equal(t, "stage-local", firstStageValue) + assert.Equal(t, "pipeline", secondStageValue) } func TestWithExtraEnvDoesNotShareVarsBackingArray(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -96,24 +138,20 @@ func TestWithExtraEnvDoesNotShareVarsBackingArray(t *testing.T) { ), ) - if err := p.Run(ctx); err != nil { - t.Fatal(err) - } + require.NoError(t, p.Run(ctx)) wantBase := []EnvVar{ {Key: "PIPELINE1", Value: "present"}, {Key: "PIPELINE2", Value: "present"}, {Key: "PIPELINE3", Value: "present"}, } - if want := append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "first"}); !reflect.DeepEqual(firstStageVars, want) { - t.Fatalf("first stage vars = %#v, want %#v", firstStageVars, want) - } - if want := append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "second"}); !reflect.DeepEqual(secondStageVars, want) { - t.Fatalf("second stage vars = %#v, want %#v", secondStageVars, want) - } + assert.Equal(t, append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "first"}), firstStageVars) + assert.Equal(t, append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "second"}), secondStageVars) } func TestWithExtraEnvAddsCommandEnv(t *testing.T) { + t.Parallel() + ctx := context.Background() stdout := &bytes.Buffer{} @@ -123,46 +161,36 @@ func TestWithExtraEnvAddsCommandEnv(t *testing.T) { []EnvVar{{Key: "STAGE", Value: "command"}}, )) - if err := p.Run(ctx); err != nil { - t.Fatal(err) - } - - if got, want := stdout.String(), "command"; got != want { - t.Fatalf("stdout = %q, want %q", got, want) - } + require.NoError(t, p.Run(ctx)) + assert.Equal(t, "command", stdout.String()) } func TestWithExtraEnvPreservesProcessHooks(t *testing.T) { - stage := WithExtraEnv(Command("true"), nil) + t.Parallel() - if _, ok := stage.(processKiller); !ok { - t.Fatal("WithExtraEnv(Command(...)) does not implement processKiller") - } + stage := WithExtraEnv(Command("true"), nil) + assert.Implements(t, (*processKiller)(nil), stage) } func TestWithExtraEnvDoesNotAddProcessHooks(t *testing.T) { + t.Parallel() + inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { return nil }) stage := WithExtraEnv(inner, nil) - - if _, ok := stage.(processKiller); ok { - t.Fatal("WithExtraEnv(Function(...)) unexpectedly implements processKiller") - } + assert.NotImplements(t, (*processKiller)(nil), stage) } func TestWithExtraEnvPreservesStageMetadata(t *testing.T) { + t.Parallel() + inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { return nil }, ForbidStdin(), ForbidStdout()) stage := WithExtraEnv(inner, nil) - - if got, want := stage.Name(), "inner (with extra env vars)"; got != want { - t.Fatalf("Name() = %q, want %q", got, want) - } - if got, want := stage.Requirements(), inner.Requirements(); got != want { - t.Fatalf("Requirements() = %#v, want %#v", got, want) - } + assert.Equal(t, "inner (with extra env vars)", stage.Name()) + assert.Equal(t, inner.Requirements(), stage.Requirements()) }