Skip to content
82 changes: 57 additions & 25 deletions pipe/close_responsibility_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ import (
// readCloseSpy records whether Close was called.
type readCloseSpy struct {
io.Reader
closed atomic.Bool
closeCount atomic.Uint32
}

func (r *readCloseSpy) Close() error {
r.closed.Store(true)
r.closeCount.Add(1)
return nil
}

// writeCloseSpy records whether Close was called.
type writeCloseSpy struct {
io.Writer
closed atomic.Bool
closeCount atomic.Uint32
}

func (w *writeCloseSpy) Close() error {
w.closed.Store(true)
w.closeCount.Add(1)
return nil
}

Expand Down Expand Up @@ -63,28 +63,52 @@ func TestGoStageHonorsStreamOwnership(t *testing.T) {
))
require.NoError(t, s.Wait())

assert.Equal(t, !tc.leaveIn, in.closed.Load(), "closing stdin=%v", !tc.leaveIn)
assert.Equal(t, !tc.leaveOut, out.closed.Load(), "closing stdout=%v", !tc.leaveOut)
if tc.leaveIn {
assert.EqualValues(t, 0, in.closeCount.Load(), "closing stdin=%v", !tc.leaveIn)
} else {
assert.EqualValues(t, 1, in.closeCount.Load(), "closing stdin=%v", !tc.leaveIn)
}
if tc.leaveOut {
assert.EqualValues(t, 0, out.closeCount.Load(), "closing stdout=%v", !tc.leaveOut)
} else {
assert.EqualValues(t, 1, out.closeCount.Load(), "closing stdout=%v", !tc.leaveOut)
}
})
}
}

func TestStreamConstructorsPreserveOwnershipAndDynamicType(t *testing.T) {
borrowedInput := strings.NewReader("borrowed")
assert.Same(t, borrowedInput, Input(borrowedInput).Reader())
assert.Nil(t, Input(borrowedInput).Closer())

ownedInput := &readCloseSpy{Reader: strings.NewReader("owned")}
assert.Same(t, ownedInput, ClosingInput(ownedInput).Reader())
assert.Same(t, ownedInput, ClosingInput(ownedInput).Closer())

borrowedOutput := &strings.Builder{}
assert.Same(t, borrowedOutput, Output(borrowedOutput).Writer())
assert.Nil(t, Output(borrowedOutput).Closer())

ownedOutput := &writeCloseSpy{Writer: io.Discard}
assert.Same(t, ownedOutput, ClosingOutput(ownedOutput).Writer())
assert.Same(t, ownedOutput, ClosingOutput(ownedOutput).Closer())
borrowedReader := &readCloseSpy{Reader: strings.NewReader("borrowed")}
borrowedInput := Input(borrowedReader)
assert.Same(t, borrowedReader, borrowedInput.Reader())
assert.NoError(t, borrowedInput.Close())
assert.EqualValues(t, 0, borrowedReader.closeCount.Load())
assert.NoError(t, borrowedInput.Close())
assert.EqualValues(t, 0, borrowedReader.closeCount.Load())

ownedReader := &readCloseSpy{Reader: strings.NewReader("owned")}
ownedInput := ClosingInput(ownedReader)
assert.Same(t, ownedReader, ownedInput.Reader())
assert.NoError(t, ownedInput.Close())
assert.EqualValues(t, 1, ownedReader.closeCount.Load())
assert.NoError(t, ownedInput.Close())
assert.EqualValues(t, 1, ownedReader.closeCount.Load())

borrowedWriter := &writeCloseSpy{Writer: &strings.Builder{}}
borrowedOutput := Output(borrowedWriter)
assert.Same(t, borrowedWriter, borrowedOutput.Writer())
assert.NoError(t, borrowedOutput.Close())
assert.EqualValues(t, 0, borrowedWriter.closeCount.Load())
assert.NoError(t, borrowedOutput.Close())
assert.EqualValues(t, 0, borrowedWriter.closeCount.Load())

ownedWriter := &writeCloseSpy{Writer: &writeCloseSpy{Writer: io.Discard}}
ownedOutput := ClosingOutput(ownedWriter)
assert.Same(t, ownedWriter, ownedOutput.Writer())
assert.NoError(t, ownedOutput.Close())
assert.EqualValues(t, 1, ownedWriter.closeCount.Load())
assert.NoError(t, ownedOutput.Close())
assert.EqualValues(t, 1, ownedWriter.closeCount.Load())
}

// TestCommandStageHonorsCloseStdin verifies that a command stage closes a
Expand All @@ -109,7 +133,11 @@ func TestCommandStageHonorsCloseStdin(t *testing.T) {
))
require.NoError(t, s.Wait())

assert.Equal(t, !leave, in.closed.Load(), "closing stdin=%v", !leave)
if leave {
assert.EqualValues(t, 0, in.closeCount.Load(), "closing stdin=%v", !leave)
} else {
assert.EqualValues(t, 1, in.closeCount.Load(), "closing stdin=%v", !leave)
}
})
}
}
Expand All @@ -136,19 +164,23 @@ func TestCommandStageHonorsCloseStdout(t *testing.T) {
))
require.NoError(t, s.Wait())

assert.Equal(t, !leave, out.closed.Load(), "closing stdout=%v", !leave)
if leave {
assert.EqualValues(t, 0, out.closeCount.Load(), "closing stdout=%v", !leave)
} else {
assert.EqualValues(t, 1, out.closeCount.Load(), "closing stdout=%v", !leave)
}
})
}
}

func inputForTest(r io.ReadCloser, closing bool) InputStream {
func inputForTest(r io.ReadCloser, closing bool) *InputStream {
if closing {
return ClosingInput(r)
}
return Input(r)
}

func outputForTest(w io.WriteCloser, closing bool) OutputStream {
func outputForTest(w io.WriteCloser, closing bool) *OutputStream {
if closing {
return ClosingOutput(w)
}
Expand Down
44 changes: 18 additions & 26 deletions pipe/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,10 @@ func (s *commandStage) Requirements() StageRequirements {

func (s *commandStage) Start(
ctx context.Context, opts StageOptions,
ins InputStream, outs OutputStream,
stdin *InputStream, stdout *OutputStream,
) error {
stdin := ins.Reader()
stdinCloser := ins.Closer()
stdout := outs.Writer()
stdoutCloser := outs.Closer()
r := stdin.Reader()
w := stdout.Writer()

if s.cmd.Dir == "" {
s.cmd.Dir = opts.Dir
Expand All @@ -105,18 +103,16 @@ func (s *commandStage) Start(
var earlyClosers []io.Closer

// See the type comment for `Stage` for the explanation of this closing behavior.
if stdin != nil {
s.cmd.Stdin = stdin
if r != nil {
s.cmd.Stdin = r
}

if stdinCloser != nil {
if _, ok := stdin.(*os.File); ok {
// We can close our copy as soon as the command has started
earlyClosers = append(earlyClosers, stdinCloser)
} else {
// We need to close `stdin`, but only after the command has finished
s.lateClosers = append(s.lateClosers, stdinCloser)
}
if _, ok := r.(*os.File); ok {
// We can close our copy as soon as the command has started
earlyClosers = append(earlyClosers, stdin)
} else {
// We need to close `stdin`, but only after the command has finished
s.lateClosers = append(s.lateClosers, stdin)
}

closeEarlyClosers := func() {
Expand All @@ -133,28 +129,24 @@ func (s *commandStage) Start(
_ = s.closeLateClosers()
}

if stdout != nil {
if f, ok := stdout.(*os.File); ok {
if w != nil {
if f, ok := w.(*os.File); ok {
s.cmd.Stdout = f
if stdoutCloser != nil {
earlyClosers = append(earlyClosers, stdoutCloser)
}
earlyClosers = append(earlyClosers, stdout)
} else {
if stdoutCloser != nil {
s.lateClosers = append(s.lateClosers, stdoutCloser)
}
s.lateClosers = append(s.lateClosers, stdout)
// Route the copy through our own pipe so we can use a
// pooled buffer rather than letting exec.Cmd allocate a
// fresh 32KB buffer for its internal io.Copy.
ec, err := s.setupPooledStdout(stdout)
ec, err := s.setupPooledStdout(w)
if err != nil {
cleanupOnStartFailure()
return err
}
earlyClosers = append(earlyClosers, ec)
}
} else if stdoutCloser != nil {
s.lateClosers = append(s.lateClosers, stdoutCloser)
} else {
s.lateClosers = append(s.lateClosers, stdout)
}

// If the caller hasn't arranged otherwise, read the command's
Expand Down
4 changes: 3 additions & 1 deletion pipe/command_stdout_fastpath_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ func TestCommandStageStdoutFastPath(t *testing.T) {
cmd := exec.Command("true")
s := CommandStage("true", cmd).(*commandStage)

stdout := OutputStream{writer: f}
var stdout *OutputStream
if tc.closingStdout {
stdout = ClosingOutput(f)
} else {
stdout = Output(f)
}

require.NoError(t, s.Start(ctx, StageOptions{}, Input(nil), stdout))
Expand Down
2 changes: 1 addition & 1 deletion pipe/env_stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (s *stageWithExtraEnv) Requirements() StageRequirements {

func (s *stageWithExtraEnv) Start(
ctx context.Context, opts StageOptions,
stdin InputStream, stdout OutputStream,
stdin *InputStream, stdout *OutputStream,
) error {
opts.Vars = append(opts.Vars[:len(opts.Vars):len(opts.Vars)], func(_ context.Context, vars []EnvVar) []EnvVar {
return append(vars, s.env...)
Expand Down
16 changes: 5 additions & 11 deletions pipe/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,15 @@ func (s *goStage) Requirements() StageRequirements {

func (s *goStage) Start(
ctx context.Context, opts StageOptions,
stdin InputStream, stdout OutputStream,
stdin *InputStream, stdout *OutputStream,
) error {
r := stdin.Reader()
stdinCloser := stdin.Closer()
if r == nil {
// treat nil as empty input.
r = strings.NewReader("")
}

w := stdout.Writer()
stdoutCloser := stdout.Closer()
if w == nil {
// treat nil output as /dev/null
w = io.Discard
Expand All @@ -110,15 +108,11 @@ func (s *goStage) Start(
s.err = opts.PanicHandler(p)
}
}
if stdoutCloser != nil {
if err := stdoutCloser.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
}
if err := stdout.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
}
if stdinCloser != nil {
if err := stdinCloser.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
}
if err := stdin.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
}
close(s.done)
}()
Expand Down
6 changes: 3 additions & 3 deletions pipe/pipe_matching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ func (s *pipeSniffingStage) Requirements() pipe.StageRequirements {

func (s *pipeSniffingStage) Start(
_ context.Context, _ pipe.StageOptions,
stdin pipe.InputStream, stdout pipe.OutputStream,
stdin *pipe.InputStream, stdout *pipe.OutputStream,
) error {
s.stdin = stdin.Reader()
stdin.Close()
_ = stdin.Close()
s.stdout = stdout.Writer()
stdout.Close()
_ = stdout.Close()
return nil
}

Expand Down
Loading
Loading