diff --git a/internal/ptree/ptree.go b/internal/ptree/ptree.go deleted file mode 100644 index 8ce22da..0000000 --- a/internal/ptree/ptree.go +++ /dev/null @@ -1,303 +0,0 @@ -// Package ptree contains utilities for dealing with Linux process trees. -package ptree - -import ( - "bytes" - "errors" - "io" - "os" - "strconv" - "strings" - "sync" -) - -const ( - // initialReadBufSize is the starting capacity of the buffer used for - // reading /proc files. /proc//status is typically ~1 KiB and the - // per-task "children" files are smaller, so 4 KiB covers the common - // case in a single read. - initialReadBufSize = 4 * 1024 -) - -var errNoRss = errors.New("RssAnon was not found") - -type ProcessTree struct { - path string -} - -func NewProcessTree(path string) ProcessTree { - return ProcessTree{ - path: path, - } -} - -// readBufPool reuses the byte buffer that holds the contents of a /proc file -// across calls to getProcessRSSAnon / walkChildrenFile, so that the -// per-poll work doesn't allocate (and then garbage-collect) a fresh buffer -// for every process in the tree. -var readBufPool = sync.Pool{ - New: func() any { - b := make([]byte, 0, initialReadBufSize) - return &b - }, -} - -// readProcFile reads all of path into a buffer borrowed from readBufPool. -// -// On success, the returned slice is only valid until bufPtr is returned to -// the pool, which the caller MUST do (typically with -// `defer readBufPool.Put(bufPtr)`, placed after the error check). -// -// On error, bufPtr is nil and the buffer has already been released, so the -// caller must not Put it back. -// -// Compared to os.ReadFile, this skips the (useless for /proc) Stat call -// used to pre-size the buffer and reuses the buffer across calls. It also -// returns the underlying *[]byte rather than a closure so the release path -// allocates nothing. -func readProcFile(path string) (data []byte, bufPtr *[]byte, err error) { - f, err := os.Open(path) - if err != nil { - return nil, nil, err - } - defer f.Close() - - bufPtr = readBufPool.Get().(*[]byte) - buf := (*bufPtr)[:0] - for { - if len(buf) == cap(buf) { - // Grow via append; this only allocates if the pooled - // buffer was too small. Subsequent calls will reuse - // the grown buffer because we store it back below. - buf = append(buf, 0)[:len(buf)] - } - n, rerr := f.Read(buf[len(buf):cap(buf)]) - buf = buf[:len(buf)+n] - if rerr == io.EOF { - break - } - if rerr != nil { - *bufPtr = buf - readBufPool.Put(bufPtr) - return nil, nil, rerr - } - if n == 0 { - // Defensive: io.Reader allows (0, nil) returns; treat - // as EOF rather than spinning. Real *os.File on /proc - // shouldn't hit this, but mocks or future runtime - // behavior might. - break - } - } - *bufPtr = buf - return buf, bufPtr, nil -} - -// Return the RSSAnon of a single process `pid`. -func (pt ProcessTree) GetProcessRSSAnon(pid int) (uint64, error) { - status := pt.path + "/" + strconv.Itoa(pid) + "/status" - data, bufPtr, err := readProcFile(status) - if os.IsNotExist(err) { - // process is already gone - return 0, nil - } - if err != nil { - return 0, err - } - defer readBufPool.Put(bufPtr) - - prefix := []byte("RssAnon:") - rest := data - for len(rest) > 0 { - var line []byte - if nl := bytes.IndexByte(rest, '\n'); nl >= 0 { - line, rest = rest[:nl], rest[nl+1:] - } else { - line, rest = rest, nil - } - // Fast prefix check before paying for the string conversion. - if !bytes.HasPrefix(line, prefix) { - continue - } - if rss, ok := ParseRSSAnon(string(line)); ok { - return rss, nil - } - } - return 0, errNoRss -} - -// Return the total RSS of the tree of processes rooted at `pid`. -// -// If the passed root pid is that of a kernel thread, as a special case, we -// return zero and no error. -// -// Errors encountered while walking the children are ignored, since it can -// change while traversing it. -func (pt ProcessTree) GetProcessTreeRSSAnon(pid int) (uint64, error) { - total, err := pt.GetProcessRSSAnon(pid) - if err != nil { - if err == errNoRss { - // these are typically kernel threads, which don't have an address space to measure - return 0, nil - } - return 0, err - } - - pt.WalkChildren(pid, func(pid int) { - mem, err := pt.GetProcessRSSAnon(pid) - if err != nil { - return - } - total += mem - }) - - return total, nil -} - -func (pt ProcessTree) WalkChildren(pid int, walkFn func(int)) { - pt.walkChildPids(pid, walkFn, map[int]bool{pid: true}) -} - -func (pt ProcessTree) walkChildPids(pid int, walkFn func(int), visited map[int]bool) { - // List the per-thread directories under /proc//task and read each - // task's "children" file directly. This avoids filepath.Glob, which - // would Stat every match on top of the readdir we already need. - taskDir := pt.path + "/" + strconv.Itoa(pid) + "/task" - entries, err := os.ReadDir(taskDir) - if err != nil { - return - } - - for _, entry := range entries { - // task/ should only contain numeric TID directories. Skip - // anything else defensively; this mirrors the implicit - // filtering that filepath.Glob("*/children") provided. - // A byte-range check avoids the error allocation that - // strconv.Atoi would incur for non-numeric names. - if !isAllDigits(entry.Name()) { - continue - } - pt.walkChildrenFile(taskDir+"/"+entry.Name()+"/children", walkFn, visited) - } -} - -func (pt ProcessTree) walkChildrenFile(filename string, walkFn func(int), visited map[int]bool) { - data, bufPtr, err := readProcFile(filename) - if err != nil { - return - } - defer readBufPool.Put(bufPtr) - - // children is a whitespace-separated list of decimal PIDs. Parse it in - // place to avoid the string(data) conversion and the []string allocated - // by strings.Fields. - i := 0 - for i < len(data) { - for i < len(data) && isASCIISpace(data[i]) { - i++ - } - if i >= len(data) { - return - } - pid := 0 - start := i - for i < len(data) && data[i] >= '0' && data[i] <= '9' { - pid = pid*10 + int(data[i]-'0') - i++ - } - if i == start { - // Not a digit; skip until next whitespace to stay in sync. - for i < len(data) && !isASCIISpace(data[i]) { - i++ - } - continue - } - if i-start > 10 { - // Realistic Linux PIDs fit in well under 10 digits - // (PID_MAX is 2^22). A longer digit run can't be a - // real PID and would risk silently overflowing the - // int accumulator, so skip it. - continue - } - if visited[pid] { - continue - } - walkFn(pid) - visited[pid] = true - pt.walkChildPids(pid, walkFn, visited) - } -} - -// parseRSSAnon parses an "RssAnon" line from /proc/*/status and returns the size. -// The entire line should be passed in, with or without the line ending. If the -// line looks like "RssAnon: 1234 kB", the byte size will be returned. If the -// line isn't parseable, (0, false) will be returned. -func ParseRSSAnon(s string) (uint64, bool) { - const prefix = "RssAnon:" - if !strings.HasPrefix(s, prefix) { - return 0, false - } - s = s[len(prefix):] - - // Optional whitespace before the number. - i := 0 - for i < len(s) && isASCIISpace(s[i]) { - i++ - } - - // One or more digits. - digitsStart := i - for i < len(s) && s[i] >= '0' && s[i] <= '9' { - i++ - } - if i == digitsStart { - return 0, false - } - kb, err := strconv.ParseUint(s[digitsStart:i], 10, 64) - if err != nil { - return 0, false - } - - // At least one whitespace between the number and "kB". - if i >= len(s) || !isASCIISpace(s[i]) { - return 0, false - } - for i < len(s) && isASCIISpace(s[i]) { - i++ - } - - // Literal "kB", then either end-of-string or whitespace. - const unit = "kB" - if !strings.HasPrefix(s[i:], unit) { - return 0, false - } - i += len(unit) - if i < len(s) && !isASCIISpace(s[i]) { - return 0, false - } - return kb * 1024, true -} - -// isASCIISpace matches the character class that Go's regexp engine uses for -// \s in non-Unicode mode: [\t\n\f\r ]. -func isASCIISpace(b byte) bool { - switch b { - case ' ', '\t', '\n', '\f', '\r': - return true - } - return false -} - -// isAllDigits reports whether s is non-empty and consists entirely of ASCII -// decimal digits. Used as a cheap allocation-free numeric-name filter. -func isAllDigits(s string) bool { - if len(s) == 0 { - return false - } - for i := 0; i < len(s); i++ { - if s[i] < '0' || s[i] > '9' { - return false - } - } - return true -} diff --git a/internal/ptree/ptree_linux.go b/internal/ptree/ptree_linux.go deleted file mode 100644 index 7019693..0000000 --- a/internal/ptree/ptree_linux.go +++ /dev/null @@ -1,21 +0,0 @@ -package ptree - -var DefaultProcessTree = ProcessTree{ - path: "/proc", -} - -// Walk the child processes of the specified root process. walkFn will be called -// for each child found. It will not be called for the root process. Any errors -// will be ignored, since they may be just a consequence of the process tree -// changing during traversal. -func WalkChildren(pid int, walkFn func(int)) { - DefaultProcessTree.WalkChildren(pid, walkFn) -} - -func GetProcessRSSAnon(pid int) (uint64, error) { - return DefaultProcessTree.GetProcessRSSAnon(pid) -} - -func GetProcessTreeRSSAnon(pid int) (uint64, error) { - return DefaultProcessTree.GetProcessTreeRSSAnon(pid) -} diff --git a/internal/ptree/ptree_test.go b/internal/ptree/ptree_test.go deleted file mode 100644 index 5c014c2..0000000 --- a/internal/ptree/ptree_test.go +++ /dev/null @@ -1,260 +0,0 @@ -package ptree_test - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "sort" - "strconv" - "testing" - - "github.com/github/go-pipe/v2/internal/ptree" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// writeStatus creates only what GetProcessRSSAnon reads: //status. -// If rssKB is zero, the RssAnon line is omitted (mimicking kernel threads). -func writeStatus(t testing.TB, root string, pid int, rssKB uint64) { - t.Helper() - pidDir := filepath.Join(root, strconv.Itoa(pid)) - require.NoError(t, os.MkdirAll(pidDir, 0o755)) - var status string - if rssKB > 0 { - status = fmt.Sprintf("Name:\tfake\nRssAnon:\t%d kB\nVmSize:\t1000 kB\n", rssKB) - } else { - status = "Name:\tfake\nVmSize:\t1000 kB\n" - } - require.NoError(t, os.WriteFile(filepath.Join(pidDir, "status"), []byte(status), 0o600)) -} - -// writeChildren creates //task//children containing the -// space-separated child pids. Only call this for processes that actually -// have children; getProcessTreeRSSAnon copes fine with the task/ directory -// being absent for leaves. -func writeChildren(t testing.TB, root string, pid int, children []int) { - t.Helper() - writeThreadChildren(t, root, pid, pid, children) -} - -// writeThreadChildren creates //task//children for a specific -// thread id, so tests can exercise multi-threaded processes. -func writeThreadChildren(t testing.TB, root string, pid, tid int, children []int) { - t.Helper() - taskDir := filepath.Join(root, strconv.Itoa(pid), "task", strconv.Itoa(tid)) - require.NoError(t, os.MkdirAll(taskDir, 0o755)) - var buf bytes.Buffer - for _, c := range children { - fmt.Fprintf(&buf, "%d ", c) - } - require.NoError(t, os.WriteFile(filepath.Join(taskDir, "children"), buf.Bytes(), 0o600)) -} - -func TestGetProcessRSSAnon(t *testing.T) { - const kb = 1024 - root := t.TempDir() - writeStatus(t, root, 100, 15032) - writeStatus(t, root, 101, 0) // kernel-thread-like: no RssAnon line - - pt := ptree.NewProcessTree(root) - - t.Run("reads RssAnon", func(t *testing.T) { - rss, err := pt.GetProcessRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64(15032*kb), rss) - }) - - t.Run("missing RssAnon line returns an error", func(t *testing.T) { - _, err := pt.GetProcessRSSAnon(101) - assert.Error(t, err) - }) - - t.Run("missing pid returns (0, nil)", func(t *testing.T) { - // A process that has already exited disappears from /proc; the function - // treats that as a non-error zero. - rss, err := pt.GetProcessRSSAnon(999) - require.NoError(t, err) - assert.Equal(t, uint64(0), rss) - }) -} - -func TestGetProcessTreeRSSAnon(t *testing.T) { - const kb = 1024 - - t.Run("leaf process returns its own RssAnon", func(t *testing.T) { - root := t.TempDir() - writeStatus(t, root, 100, 1000) - - pt := ptree.NewProcessTree(root) - - total, err := pt.GetProcessTreeRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64(1000*kb), total) - }) - - t.Run("sums root and descendants", func(t *testing.T) { - // 100 -> {101 -> 103, 102} - root := t.TempDir() - writeStatus(t, root, 100, 1000) - writeStatus(t, root, 101, 200) - writeStatus(t, root, 102, 50) - writeStatus(t, root, 103, 7) - writeChildren(t, root, 100, []int{101, 102}) - writeChildren(t, root, 101, []int{103}) - - pt := ptree.NewProcessTree(root) - - total, err := pt.GetProcessTreeRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64((1000+200+50+7)*kb), total) - }) - - t.Run("kernel-thread root returns (0, nil)", func(t *testing.T) { - // Root has no RssAnon line; the function maps errNoRss to (0, nil). - root := t.TempDir() - writeStatus(t, root, 100, 0) - - pt := ptree.NewProcessTree(root) - - total, err := pt.GetProcessTreeRSSAnon(100) - require.NoError(t, err) - assert.Equal(t, uint64(0), total) - }) -} - -func TestWalkChildren(t *testing.T) { - t.Run("walks all descendants", func(t *testing.T) { - // 100 -> {101, 102 -> 103}. Verifies the callback fires for - // every descendant (not just direct children) and is not - // invoked for the root. - root := t.TempDir() - writeChildren(t, root, 100, []int{101, 102}) - writeChildren(t, root, 102, []int{103}) - - pt := ptree.NewProcessTree(root) - - var seen []int - pt.WalkChildren(100, func(pid int) { seen = append(seen, pid) }) - sort.Ints(seen) - assert.Equal(t, []int{101, 102, 103}, seen) - }) - - t.Run("iterates every thread under task/ and dedups", func(t *testing.T) { - // 100 has two threads (100 and 200); each thread reports a - // different set of children, with 102 listed by both threads - // to exercise the visited dedup. - root := t.TempDir() - writeThreadChildren(t, root, 100, 100, []int{101, 102}) - writeThreadChildren(t, root, 100, 200, []int{102, 103}) - - pt := ptree.NewProcessTree(root) - - var seen []int - pt.WalkChildren(100, func(pid int) { seen = append(seen, pid) }) - sort.Ints(seen) - assert.Equal(t, []int{101, 102, 103}, seen) - }) -} - -// BenchmarkGetProcessTreeRSSAnon measures the cost of a single poll over a -// small process tree (a root plus a few direct children). -func BenchmarkGetProcessTreeRSSAnon(b *testing.B) { - const rootPid = 100 - root := b.TempDir() - writeStatus(b, root, rootPid, 1000) - children := []int{101, 102, 103} - writeChildren(b, root, rootPid, children) - for _, c := range children { - writeStatus(b, root, c, 200) - } - - pt := ptree.NewProcessTree(root) - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := pt.GetProcessTreeRSSAnon(rootPid) - if err != nil { - b.Fatal(err) - } - } -} - -func TestParseRss(t *testing.T) { - const kb = 1024 - - okExamples := []struct { - input string - result uint64 - }{ - { - input: "RssAnon:\t 15032 kB", - result: 15032 * kb, - }, - { - input: "RssAnon:\t99915032 kB", - result: 99915032 * kb, - }, - { - input: "RssAnon:\t 1 kB", - result: kb, - }, - // Exactly what the kernel emits via SEQ_PUT_DEC: "RssAnon:\t" + - // 8-wide right-justified decimal + " kB\n". See fs/proc/task_mmu.c - // (task_mem). The trailing newline must be tolerated. - { - input: "RssAnon:\t 15032 kB\n", - result: 15032 * kb, - }, - // A value wider than the 8-char padding (no leading spaces). - { - input: "RssAnon:\t12345678 kB\n", - result: 12345678 * kb, - }, - { - input: "RssAnon:\t 0 kB\n", - result: 0, - }, - } - - for _, example := range okExamples { - rss, ok := ptree.ParseRSSAnon(example.input) - if assert.Truef(t, ok, "should be able to parse %q", example.input) { - assert.Equalf(t, example.result, rss, "value of %q", example.input) - } - } - - badExamples := []string{ - "", - "\n", - "RssAnon:\t 123", - "RssAnonn:\t 123 kB", - "RssAno:\t 123 kB", - "Blah:\t 123 kB", - "Blah:", - "123", - } - - for _, example := range badExamples { - _, ok := ptree.ParseRSSAnon(example) - assert.Falsef(t, ok, "should not be able to parse %q", example) - } -} - -func BenchmarkParseRss(b *testing.B) { - b.Run("match", func(b *testing.B) { - for i := 0; i < b.N; i++ { - rss, ok := ptree.ParseRSSAnon("RssAnon:\t 15032 kB") - require.True(b, ok) - require.EqualValues(b, 15032*1024, rss) - } - }) - - b.Run("no match", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _, ok := ptree.ParseRSSAnon("Other:\t 15032 kB") - require.False(b, ok) - } - }) -} diff --git a/pipe/command.go b/pipe/command.go index 4dbe82b..ed12f9b 100644 --- a/pipe/command.go +++ b/pipe/command.go @@ -15,8 +15,6 @@ import ( "golang.org/x/sync/errgroup" ) -var errProcessInfoMissing = errors.New("cmd.Process is nil") - // commandStage is a pipeline `Stage` based on running an external // command and piping the data through its stdin and stdout. type commandStage struct { @@ -37,9 +35,16 @@ type commandStage struct { } var ( - _ Stage = (*commandStage)(nil) + _ Stage = (*commandStage)(nil) + _ processProvider = (*commandStage)(nil) ) +// processProvider is the hook external memory-watchers use to find the running +// process so they can sample its RSS. +type processProvider interface { + Process() *os.Process +} + // Command returns a pipeline `Stage` based on the specified external // `command`, run with the given command-line `args`. Its stdin and // stdout are handled as usual, and its stderr is collected and @@ -69,6 +74,10 @@ func (s *commandStage) Name() string { return s.name } +func (s *commandStage) Process() *os.Process { + return s.cmd.Process +} + func (s *commandStage) Requirements() StageRequirements { return StageRequirements{ StdinNeedsFile: true, diff --git a/pipe/command_linux.go b/pipe/command_linux.go deleted file mode 100644 index 987bffb..0000000 --- a/pipe/command_linux.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build linux - -package pipe - -import ( - "context" - - "github.com/github/go-pipe/v2/internal/ptree" -) - -// On linux, we can limit or observe memory usage in command stages. -var _ LimitableStage = (*commandStage)(nil) - -func (s *commandStage) GetRSSAnon(_ context.Context) (uint64, error) { - if s.cmd.Process == nil { - return 0, errProcessInfoMissing - } - - return ptree.GetProcessTreeRSSAnon(s.cmd.Process.Pid) -} diff --git a/pipe/env_stage.go b/pipe/env_stage.go new file mode 100644 index 0000000..99f5372 --- /dev/null +++ b/pipe/env_stage.go @@ -0,0 +1,59 @@ +package pipe + +import ( + "context" + "io" +) + +// WithExtraEnv returns a Stage that adds env to the environment seen by inner. +func WithExtraEnv(inner Stage, env []EnvVar) Stage { + stage := &stageWithExtraEnv{ + inner: inner, + env: env, + } + if processKiller, ok := inner.(processKiller); ok { + return &processStageWithExtraEnv{ + stageWithExtraEnv: stage, + processKiller: processKiller, + } + } + return stage +} + +type processKiller interface { + processProvider + Kill(error) +} + +type processStageWithExtraEnv struct { + *stageWithExtraEnv + processKiller +} + +type stageWithExtraEnv struct { + inner Stage + env []EnvVar +} + +func (s *stageWithExtraEnv) Name() string { + return s.inner.Name() + " (with extra env vars)" +} + +func (s *stageWithExtraEnv) Requirements() StageRequirements { + return s.inner.Requirements() +} + +func (s *stageWithExtraEnv) Start( + ctx context.Context, opts StageOptions, + stdin io.Reader, closeStdin bool, + stdout io.Writer, closeStdout bool, +) error { + opts.Vars = append(opts.Vars[:len(opts.Vars):len(opts.Vars)], func(_ context.Context, vars []EnvVar) []EnvVar { + return append(vars, s.env...) + }) + return s.inner.Start(ctx, opts, stdin, closeStdin, stdout, closeStdout) +} + +func (s *stageWithExtraEnv) Wait() error { + return s.inner.Wait() +} diff --git a/pipe/env_stage_test.go b/pipe/env_stage_test.go new file mode 100644 index 0000000..43117ae --- /dev/null +++ b/pipe/env_stage_test.go @@ -0,0 +1,196 @@ +package pipe + +import ( + "bytes" + "context" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func collectEnvVars(ctx context.Context, env Env) []EnvVar { + var vars []EnvVar + for _, fn := range env.Vars { + vars = fn(ctx, vars) + } + return vars +} + +func lastEnvValue(ctx context.Context, env Env, key string) (string, bool) { + var ( + value string + ok bool + ) + for _, envVar := range collectEnvVars(ctx, env) { + if envVar.Key == key { + value = envVar.Value + ok = true + } + } + return value, ok +} + +func TestWithExtraEnvAddsStageLocalVars(t *testing.T) { + t.Parallel() + + ctx := context.Background() + var firstStageVars []EnvVar + var secondStageVars []EnvVar + + p := New(WithEnvVar("PIPELINE", "present")) + p.Add( + WithExtraEnv( + Function("first", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + firstStageVars = collectEnvVars(ctx, env) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "first"}}, + ), + Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + secondStageVars = collectEnvVars(ctx, env) + return nil + }), + ) + + require.NoError(t, p.Run(ctx)) + assert.Equal(t, []EnvVar{{Key: "PIPELINE", Value: "present"}, {Key: "STAGE", Value: "first"}}, firstStageVars) + assert.Equal(t, []EnvVar{{Key: "PIPELINE", Value: "present"}}, secondStageVars) +} + +func TestWithExtraEnvStageLocalVarsOverridePipelineVars(t *testing.T) { + t.Parallel() + + ctx := context.Background() + var firstStageValue string + var secondStageValue string + + p := New(WithEnvVar("STAGE", "pipeline")) + p.Add( + WithExtraEnv( + Function("first", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + var ok bool + firstStageValue, ok = lastEnvValue(ctx, env, "STAGE") + require.True(t, ok) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "stage-local"}}, + ), + Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + var ok bool + secondStageValue, ok = lastEnvValue(ctx, env, "STAGE") + require.True(t, ok) + return nil + }), + ) + + require.NoError(t, p.Run(ctx)) + assert.Equal(t, "stage-local", firstStageValue) + assert.Equal(t, "pipeline", secondStageValue) +} + +func TestWithExtraEnvDoesNotShareVarsBackingArray(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + allowFirstStage := make(chan struct{}) + var firstStageVars []EnvVar + var secondStageVars []EnvVar + + baseVars := make([]AppendVars, 0, 4) + for _, env := range []EnvVar{ + {Key: "PIPELINE1", Value: "present"}, + {Key: "PIPELINE2", Value: "present"}, + {Key: "PIPELINE3", Value: "present"}, + } { + baseVars = append(baseVars, func(_ context.Context, vars []EnvVar) []EnvVar { + return append(vars, env) + }) + } + + p := New(func(p *Pipeline) { + p.env.Vars = baseVars + }) + p.Add( + WithExtraEnv( + Function("first", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + select { + case <-allowFirstStage: + case <-ctx.Done(): + return ctx.Err() + } + firstStageVars = collectEnvVars(ctx, env) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "first"}}, + ), + WithExtraEnv( + Function("second", func(ctx context.Context, env Env, _ io.Reader, _ io.Writer) error { + secondStageVars = collectEnvVars(ctx, env) + close(allowFirstStage) + return nil + }), + []EnvVar{{Key: "STAGE", Value: "second"}}, + ), + ) + + require.NoError(t, p.Run(ctx)) + + wantBase := []EnvVar{ + {Key: "PIPELINE1", Value: "present"}, + {Key: "PIPELINE2", Value: "present"}, + {Key: "PIPELINE3", Value: "present"}, + } + assert.Equal(t, append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "first"}), firstStageVars) + assert.Equal(t, append(append([]EnvVar(nil), wantBase...), EnvVar{Key: "STAGE", Value: "second"}), secondStageVars) +} + +func TestWithExtraEnvAddsCommandEnv(t *testing.T) { + t.Parallel() + + ctx := context.Background() + stdout := &bytes.Buffer{} + + p := New(WithStdout(stdout)) + p.Add(WithExtraEnv( + Command("sh", "-c", "printf %s \"$STAGE\""), + []EnvVar{{Key: "STAGE", Value: "command"}}, + )) + + require.NoError(t, p.Run(ctx)) + assert.Equal(t, "command", stdout.String()) +} + +func TestWithExtraEnvPreservesProcessHooks(t *testing.T) { + t.Parallel() + + stage := WithExtraEnv(Command("true"), nil) + assert.Implements(t, (*processKiller)(nil), stage) +} + +func TestWithExtraEnvDoesNotAddProcessHooks(t *testing.T) { + t.Parallel() + + inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { + return nil + }) + + stage := WithExtraEnv(inner, nil) + assert.NotImplements(t, (*processKiller)(nil), stage) +} + +func TestWithExtraEnvPreservesStageMetadata(t *testing.T) { + t.Parallel() + + inner := Function("inner", func(context.Context, Env, io.Reader, io.Writer) error { + return nil + }, ForbidStdin(), ForbidStdout()) + + stage := WithExtraEnv(inner, nil) + assert.Equal(t, "inner (with extra env vars)", stage.Name()) + assert.Equal(t, inner.Requirements(), stage.Requirements()) +} diff --git a/pipe/memorylimit.go b/pipe/memorylimit.go deleted file mode 100644 index 8a02400..0000000 --- a/pipe/memorylimit.go +++ /dev/null @@ -1,273 +0,0 @@ -package pipe - -import ( - "context" - "errors" - "fmt" - "io" - "sync" - "time" -) - -const memoryPollInterval = time.Second - -// ErrMemoryLimitExceeded is the error that will be used to kill a -// process, if necessary, from a MemoryWatch with WithMemoryLimit. -var ErrMemoryLimitExceeded = errors.New("memory limit exceeded") - -// LimitableStage is the superset of `Stage` that must be implemented -// by stages passed to MemoryWatch. -type LimitableStage interface { - Stage - - GetRSSAnon(context.Context) (uint64, error) - Kill(error) -} - -// MemoryWatchOption configures a MemoryWatch stage. -type MemoryWatchOption func(*memoryWatchStage) - -// WithMemoryLimit makes MemoryWatch kill the stage when its RSS exceeds -// byteLimit. -func WithMemoryLimit(byteLimit uint64) MemoryWatchOption { - return func(m *memoryWatchStage) { - m.limit = &byteLimit - m.nameSuffix = " with memory limit" - } -} - -// WithPeakUsageLogging makes MemoryWatch log the peak RSS when the stage -// exits. -func WithPeakUsageLogging() MemoryWatchOption { - return func(m *memoryWatchStage) { - m.observe = true - } -} - -// MemoryWatch watches the memory usage of the stage and reports via -// eventHandler. With WithMemoryLimit it kills the stage when the limit is -// exceeded; with WithPeakUsageLogging it logs the peak RSS when the stage -// exits. At least one of the two options is required. -// -// If the event handler panics while reporting the over-limit event, the -// stage is still killed. A panic in any other event-handler call (an -// RSS-read error, or the peak-usage report) is recovered via -// StageOptions.PanicHandler and the stage keeps running unmonitored; see -// StageOptions.PanicHandler. -func MemoryWatch(stage Stage, eventHandler func(e *Event), opts ...MemoryWatchOption) Stage { - limitableStage, ok := stage.(LimitableStage) - if !ok { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "invalid pipe.MemoryWatch usage", - Err: fmt.Errorf("invalid pipe.MemoryWatch usage"), - }) - return stage - } - - m := memoryWatchStage{ - stage: limitableStage, - eventHandler: eventHandler, - } - for _, opt := range opts { - opt(&m) - } - - if m.limit == nil && !m.observe { - eventHandler(&Event{ - Command: stage.Name(), - Msg: "invalid pipe.MemoryWatch usage", - Err: fmt.Errorf( - "pipe.MemoryWatch requires WithMemoryLimit and/or WithPeakUsageLogging", - ), - }) - return stage - } - - return &m -} - -type memoryWatchStage struct { - nameSuffix string - stage LimitableStage - eventHandler func(e *Event) - - limit *uint64 // non-nil enables kill-at-limit - observe bool // log peak RSS when the stage exits - - maxRSS uint64 - samples int - errCount int - consecutiveErrors int - - cancel context.CancelFunc - wg sync.WaitGroup - watchErr error -} - -var _ LimitableStage = (*memoryWatchStage)(nil) - -func (m *memoryWatchStage) Name() string { - return m.stage.Name() + m.nameSuffix -} - -func (m *memoryWatchStage) Requirements() StageRequirements { - return m.stage.Requirements() -} - -func (m *memoryWatchStage) Start( - ctx context.Context, opts StageOptions, - stdin io.Reader, closeStdin bool, - stdout io.Writer, closeStdout bool, -) error { - if err := m.stage.Start(ctx, opts, stdin, closeStdin, stdout, closeStdout); err != nil { - return err - } - - m.monitor(ctx, opts.PanicHandler) - - return nil -} - -func (m *memoryWatchStage) Wait() error { - err := m.stage.Wait() - m.stopWatching() - if err == nil { - err = m.watchErr // non-nil if panicHandler() returned anything - } - return err -} - -func (m *memoryWatchStage) GetRSSAnon(ctx context.Context) (uint64, error) { - return m.stage.GetRSSAnon(ctx) -} - -func (m *memoryWatchStage) Kill(err error) { - m.stage.Kill(err) - m.stopWatching() -} - -// monitor starts up a goroutine that monitors the memory of `m`. If -// panicHandler is set, any panic that escapes the user-supplied event handler -// (via m.watch) is recovered. -func (m *memoryWatchStage) monitor(ctx context.Context, panicHandler StagePanicHandler) { - ctx, cancel := context.WithCancel(ctx) - m.cancel = cancel - m.wg.Add(1) - - go func() { - defer m.wg.Done() - - if panicHandler != nil { - defer func() { - if p := recover(); p != nil { - m.watchErr = panicHandler(p) - } - }() - } - - m.watch(ctx) - }() -} - -func (m *memoryWatchStage) stopWatching() { - m.cancel() - m.wg.Wait() -} - -// watch is a `memoryWatchFunc` that watches the memory usage of the -// specified `stage`. -func (m *memoryWatchStage) watch(ctx context.Context) { - t := time.NewTicker(memoryPollInterval) - defer t.Stop() - -watchLoop: - for { - select { - case <-ctx.Done(): - break watchLoop - case <-t.C: - if m.update(ctx) { - // The stage was killed. - break watchLoop - } - } - } - - if m.observe { - <-ctx.Done() - m.reportPeakUsage() - } -} - -// update samples the current memory usage and updates internal stats. -// Return true if the stage was killed for exceeding the memory limit. -func (m *memoryWatchStage) update(ctx context.Context) bool { - rss, err := m.stage.GetRSSAnon(ctx) - if err != nil { - m.handleGetRSSError(err) - return false - } - - m.consecutiveErrors = 0 - m.samples++ - if rss > m.maxRSS { - m.maxRSS = rss - } - - if m.limit != nil && rss >= *m.limit { - m.killStage(rss) - return true - } - - return false -} - -// handleGetRSSError deals with error `err` that happened when trying -// to get `stage`'s memory usage. -func (m *memoryWatchStage) handleGetRSSError(err error) { - if !errors.Is(err, errProcessInfoMissing) { - m.errCount++ - m.consecutiveErrors++ - if m.consecutiveErrors == 2 { - m.eventHandler(&Event{ - Command: m.stage.Name(), - Msg: "error getting RSS", - Err: err, - }) - } - } else { - m.consecutiveErrors = 0 - } -} - -// killStage kills the stage and reports and event saying what it did. -func (m *memoryWatchStage) killStage(rss uint64) { - // Guarantee the over-limit stage is killed even if - // the user's event handler panics. - defer m.stage.Kill(ErrMemoryLimitExceeded) - - m.eventHandler(&Event{ - Command: m.stage.Name(), - Msg: "stage exceeded allowed memory use", - Err: fmt.Errorf("stage exceeded allowed memory use"), - Context: map[string]any{ - "limit": *m.limit, - "used": rss, - }, - }) -} - -// reportPeakUsage sends an event reporting the peak usage that has -// been seen for `stage`. -func (m *memoryWatchStage) reportPeakUsage() { - m.eventHandler(&Event{ - Command: m.stage.Name(), - Msg: "peak memory usage", - Context: map[string]any{ - "max_rss_bytes": m.maxRSS, - "samples": m.samples, - "errors": m.errCount, - }, - }) -} diff --git a/pipe/memorylimit_panic_test.go b/pipe/memorylimit_panic_test.go deleted file mode 100644 index 1c0c32e..0000000 --- a/pipe/memorylimit_panic_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package pipe - -import ( - "context" - "fmt" - "io" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -const memWatchPanicSentinel = "memwatch-panic-sentinel" - -// fakeLimitableStage is a minimal LimitableStage whose `GetRSSAnon()` -// method panics, and whose `Wait()` method returns after that panic -// has been issued. -type fakeLimitableStage struct { - done chan struct{} -} - -func (fakeLimitableStage) Name() string { return "fake" } -func (fakeLimitableStage) Requirements() StageRequirements { return StageRequirements{} } -func (fakeLimitableStage) Start( - context.Context, StageOptions, io.Reader, bool, io.Writer, bool, -) error { - return nil -} -func (stage fakeLimitableStage) Wait() error { - <-stage.done - return nil -} - -func (stage fakeLimitableStage) GetRSSAnon(context.Context) (uint64, error) { - close(stage.done) - panic(memWatchPanicSentinel) -} - -func (fakeLimitableStage) Kill(error) {} - -func panickingWatchStage() Stage { - stage := fakeLimitableStage{ - done: make(chan struct{}), - } - return MemoryWatch(stage, func(*Event) {}, WithMemoryLimit(1)) -} - -// TestMemoryWatchStagePanicWithHandlerSurfaced verifies that a panic -// escaping the memory-watch goroutine (where the user-supplied event -// handler runs) is recovered via the configured panic handler and -// surfaced as the stage's Wait error. -func TestMemoryWatchStagePanicWithHandlerSurfaced(t *testing.T) { - ms := panickingWatchStage() - opts := StageOptions{ - PanicHandler: func(p any) error { return fmt.Errorf("recovered: %v", p) }, - } - - if err := ms.Start(context.Background(), opts, nil, false, nil, false); err != nil { - t.Fatalf("Start returned unexpected error: %v", err) - } - - err := ms.Wait() - if err == nil { - t.Fatal("expected Wait to surface the recovered panic, got nil") - } - if !strings.Contains(err.Error(), memWatchPanicSentinel) { - t.Fatalf("expected error to mention %q, got: %v", memWatchPanicSentinel, err) - } -} - -// TestMemoryWatchStagePanicWithoutHandlerPropagates verifies that the -// memory-watch sampling path does not swallow a panic. The monitor -// goroutine only installs a recover when a handler is present (see -// memoryWatchStage.monitor), so we exercise update() directly and assert -// that it propagates the panic. update() is used rather than watch() so the -// assertion is synchronous and ticker-independent: a regression that stopped -// the panic would fail the test rather than hang on the ticker loop. -func TestMemoryWatchStagePanicWithoutHandlerPropagates(t *testing.T) { - limit := uint64(1) - mw := memoryWatchStage{ - stage: fakeLimitableStage{done: make(chan struct{})}, - eventHandler: func(*Event) {}, - limit: &limit, - } - - assert.PanicsWithValue(t, memWatchPanicSentinel, func() { - mw.update(context.Background()) - }) -} - -// killTrackingStage is a LimitableStage that reports an over-limit RSS -// and blocks in Wait until it is killed, recording that the kill -// happened. It lets a test assert that the memory limit is enforced. -type killTrackingStage struct { - killed chan struct{} - done chan struct{} -} - -func newKillTrackingStage() *killTrackingStage { - return &killTrackingStage{ - killed: make(chan struct{}), - done: make(chan struct{}), - } -} - -func (*killTrackingStage) Name() string { return "kill-tracking" } -func (*killTrackingStage) Requirements() StageRequirements { return StageRequirements{} } -func (*killTrackingStage) Start( - context.Context, StageOptions, io.Reader, bool, io.Writer, bool, -) error { - return nil -} -func (s *killTrackingStage) Wait() error { <-s.done; return ErrMemoryLimitExceeded } -func (*killTrackingStage) GetRSSAnon(context.Context) (uint64, error) { - return 1 << 30, nil -} - -func (s *killTrackingStage) Kill(error) { - select { - case <-s.killed: - // already killed - default: - close(s.killed) - close(s.done) - } -} - -// TestMemoryLimitKillsEvenIfEventHandlerPanics verifies that an over-limit -// stage is still killed (the limit enforced) even when the user's event -// handler panics and that panic is recovered by the configured handler. -// Without the kill being guaranteed during unwinding, the runaway stage -// would never be killed and Wait would hang. -func TestMemoryLimitKillsEvenIfEventHandlerPanics(t *testing.T) { - stage := newKillTrackingStage() - eventHandler := func(*Event) { panic(memWatchPanicSentinel) } - ms := MemoryWatch(stage, eventHandler, WithMemoryLimit(1)) - opts := StageOptions{ - PanicHandler: func(p any) error { return fmt.Errorf("recovered: %v", p) }, - } - - if err := ms.Start(context.Background(), opts, nil, false, nil, false); err != nil { - t.Fatalf("Start returned unexpected error: %v", err) - } - - select { - case <-stage.killed: - // expected: the limit was enforced despite the handler panic. - case <-time.After(5 * time.Second): - t.Fatal("over-limit stage was not killed after the event handler panicked") - } - - if err := ms.Wait(); err != ErrMemoryLimitExceeded { - t.Fatalf("Wait = %v, want %v", err, ErrMemoryLimitExceeded) - } -} diff --git a/pipe/memorylimit_test.go b/pipe/memorylimit_test.go deleted file mode 100644 index 4e46945..0000000 --- a/pipe/memorylimit_test.go +++ /dev/null @@ -1,278 +0,0 @@ -package pipe_test - -import ( - "bytes" - "context" - "fmt" - "io" - "log" - "os" - "strings" - "syscall" - "testing" - "time" - - "github.com/github/go-pipe/v2/pipe" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func LogEventHandler(logger *log.Logger) func(*pipe.Event) { - return func(e *pipe.Event) { - ctx := "" - for k, v := range e.Context { - ctx = fmt.Sprintf("%s,%s=%v", ctx, k, v) - } - - logger.Printf("Command %s failed with message %s and error %s. Context: %s", e.Command, e.Msg, e.Command, ctx) - } -} - -func TestMemoryObserverSimple(t *testing.T) { - t.Parallel() - rss := testMemoryObserver(t, 400, pipe.Command("less")) - require.Greater(t, rss, 400_000_000) -} - -func TestMemoryObserverTreeMem(t *testing.T) { - t.Parallel() - - // create a process tree like: - // /tmp/go-build3166037414/b001/pipe.test -test.paniconexit0 -test.timeout=10m0s -test.count=1 -test.run=Tree -test.v=true - // \_ head -c 3G - // \_ sh -c less; : - // \_ less - // so that MemoryObserver is watching the parent `sh` proc and doesn't detect less's mem usage. - // less should buffer whatever we send it via stdin, giving us some level of control over its - // memory usage. - rss := testMemoryObserver(t, 400, pipe.Command("sh", "-c", "less; :")) - require.Greater(t, rss, 400_000_000) -} - -func testMemoryObserver(t *testing.T, mbs int, stage pipe.Stage) int { - ctx := context.Background() - - stdinReader, stdinWriter := io.Pipe() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithPeakUsageLogging())) - require.NoError(t, p.Start(ctx)) - - // Write some nonsense data to less, but don't close stdin until we want it - // to exit. - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - n, err := stdinWriter.Write(bytes[:]) - require.NoError(t, err) - require.Equal(t, len(bytes), n) - } - - // MemoryObserver polls every one second, so this should make sure we catch - // at least one. - time.Sleep(2 * time.Second) - - // Close stdin and wait for the pipeline to exit. - require.NoError(t, stdinWriter.Close()) - require.NoError(t, p.Wait()) - - return maxBytes(buf.String()) -} - -func maxBytes(s string) int { - idx := strings.Index(s, "max_rss_bytes=") - if idx < 0 { - return idx - } - var maxRSS int - n, err := fmt.Sscanf(s[idx:], "max_rss_bytes=%d", &maxRSS) - if n != 1 || err != nil { - return -1 - } - return maxRSS -} - -func TestMemoryLimitSimple(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimit(t, 400, 10_000_000, pipe.Command("less")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitTreeMem(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimit(t, 400, 10_000_000, pipe.Command("sh", "-c", "less; :")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverSimple(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverTreeMem(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("sh", "-c", "less; :")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "limit=10000000") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverLogsPeakOnKill(t *testing.T) { - t.Parallel() - msg, err := testMemoryLimitWithObserver(t, 400, 10_000_000, pipe.Command("less")) - assert.Contains(t, msg, "exceeded allowed memory") - assert.Contains(t, msg, "peak memory usage") - require.ErrorContains(t, err, "memory limit exceeded") -} - -func TestMemoryLimitWithObserverBelowLimit(t *testing.T) { - t.Parallel() - rss := testMemoryLimitWithObserverBelowLimit(t, 400, pipe.Command("less")) - require.Greater(t, rss, 400_000_000) -} - -func TestMemoryLimitWithObserverBelowLimitTreeMem(t *testing.T) { - t.Parallel() - rss := testMemoryLimitWithObserverBelowLimit(t, 400, pipe.Command("sh", "-c", "less; :")) - require.Greater(t, rss, 400_000_000) -} - -// testMemoryLimitWithObserverBelowLimit exercises the observer half of -// `MemoryLimitWithObserver` when the memory limit is never hit: with a -// 100GiB limit, less should never be killed, but the wrapper should -// still poll RSS and emit a "peak memory usage" event when the stage -// exits normally. Mirrors `testMemoryObserver` in structure — we hold -// stdin open across at least one poll interval so RSS samples are -// guaranteed to be taken before the stage is allowed to exit. -func testMemoryLimitWithObserverBelowLimit(t *testing.T, mbs int, stage pipe.Stage) int { - ctx := context.Background() - - stdinReader, stdinWriter := io.Pipe() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryLimitWithObserverBelowLimit", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdin(stdinReader), pipe.WithStdout(devNull)) - p.Add(pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(100*1024*1024*1024), pipe.WithPeakUsageLogging())) - require.NoError(t, p.Start(ctx)) - - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - n, err := stdinWriter.Write(bytes[:]) - require.NoError(t, err) - require.Equal(t, len(bytes), n) - } - - // Wrapper polls once per second; sleep long enough to guarantee at - // least one sample is taken before the stage is allowed to exit. - time.Sleep(2 * time.Second) - - require.NoError(t, stdinWriter.Close()) - require.NoError(t, p.Wait()) - - output := buf.String() - assert.Contains(t, output, "peak memory usage") - - return maxBytes(output) -} - -func testMemoryLimit(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { - ctx := context.Background() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryObserver", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdoutCloser(devNull)) - p.Add( - pipe.Function( - "write-to-less", - func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { - // Write some nonsense data to less. - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdout.Write(bytes[:]) - if err != nil { - assert.ErrorIs(t, err, syscall.EPIPE) - return nil - } - } - - return nil - }, - ), - pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(limit)), - ) - require.NoError(t, p.Start(ctx)) - - err = p.Wait() - - return buf.String(), err -} - -func testMemoryLimitWithObserver(t *testing.T, mbs int, limit uint64, stage pipe.Stage) (string, error) { - ctx := context.Background() - - devNull, err := os.OpenFile("/dev/null", os.O_WRONLY, 0) - require.NoError(t, err) - - buf := &bytes.Buffer{} - logger := log.New(buf, "testMemoryLimitWithObserver", log.Ldate|log.Ltime) - - p := pipe.New(pipe.WithDir("/"), pipe.WithStdoutCloser(devNull)) - p.Add( - pipe.Function( - "write-to-less", - func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error { - var bytes [1_000_000]byte - for i := 0; i < mbs; i++ { - _, err := stdout.Write(bytes[:]) - if err != nil { - assert.ErrorIs(t, err, syscall.EPIPE) - return nil - } - } - return nil - }, - ), - pipe.MemoryWatch(stage, LogEventHandler(logger), pipe.WithMemoryLimit(limit), pipe.WithPeakUsageLogging()), - ) - require.NoError(t, p.Start(ctx)) - - err = p.Wait() - - return buf.String(), err -} - -// TestMemoryWatchRequiresAnOption verifies that MemoryWatch without -// WithMemoryLimit or WithPeakUsageLogging is rejected: it reports an -// invalid-usage event and returns the stage unwrapped (no watcher). -func TestMemoryWatchRequiresAnOption(t *testing.T) { - stage := pipe.Command("true") - - var events []*pipe.Event - got := pipe.MemoryWatch(stage, func(e *pipe.Event) { - events = append(events, e) - }) - - require.Same(t, stage, got, "expected the input stage returned unwrapped") - require.Len(t, events, 1) - require.Contains(t, events[0].Msg, "invalid pipe.MemoryWatch usage") -}