Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pipe/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ func (s *commandStage) Process() *os.Process {

func (s *commandStage) Requirements() StageRequirements {
return StageRequirements{
StdinNeedsFile: true,
StdoutNeedsFile: true,
Stdin: StreamPreferFile,
Stdout: StreamPreferFile,
}
}

Expand Down
64 changes: 42 additions & 22 deletions pipe/pipe_matching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,10 @@ func writeCloser() io.WriteCloser {
}

func newPipeSniffingStage(
stdinNeedsFile bool, stdinExpectation ioExpectation,
stdoutNeedsFile bool, stdoutExpectation ioExpectation,
req pipe.StageRequirements, stdinExpectation, stdoutExpectation ioExpectation,
) *pipeSniffingStage {
return &pipeSniffingStage{
requirements: pipe.StageRequirements{
StdinNeedsFile: stdinNeedsFile,
StdoutNeedsFile: stdoutNeedsFile,
},
requirements: req,
expect: pipeExpectations{
stdin: stdinExpectation,
stdout: stdoutExpectation,
Expand All @@ -68,17 +64,23 @@ func newPipeSniffingFunc(
stdinExpectation, stdoutExpectation ioExpectation,
) *pipeSniffingStage {
return newPipeSniffingStage(
false, stdinExpectation,
false, stdoutExpectation,
pipe.StageRequirements{
Stdin: pipe.StreamAcceptAny,
Stdout: pipe.StreamAcceptAny,
},
stdinExpectation, stdoutExpectation,
)
}

func newPipeSniffingCmd(
stdinExpectation, stdoutExpectation ioExpectation,
) *pipeSniffingStage {
return newPipeSniffingStage(
true, stdinExpectation,
true, stdoutExpectation,
pipe.StageRequirements{
Stdin: pipe.StreamPreferFile,
Stdout: pipe.StreamPreferFile,
},
stdinExpectation, stdoutExpectation,
)
}

Expand Down Expand Up @@ -325,16 +327,25 @@ func TestPipeTypes(t *testing.T) {
opts: []pipe.Option{},
stages: []pipe.Stage{
newPipeSniffingStage(
false, expectNil,
false, expectOther,
pipe.StageRequirements{
Stdin: pipe.StreamAcceptAny,
Stdout: pipe.StreamAcceptAny,
},
expectNil, expectOther,
),
newPipeSniffingStage(
false, expectOther,
true, expectFile,
pipe.StageRequirements{
Stdin: pipe.StreamAcceptAny,
Stdout: pipe.StreamPreferFile,
},
expectOther, expectFile,
),
newPipeSniffingStage(
false, expectFile,
false, expectNil,
pipe.StageRequirements{
Stdin: pipe.StreamAcceptAny,
Stdout: pipe.StreamAcceptAny,
},
expectFile, expectNil,
),
},
},
Expand All @@ -343,16 +354,25 @@ func TestPipeTypes(t *testing.T) {
opts: []pipe.Option{},
stages: []pipe.Stage{
newPipeSniffingStage(
false, expectNil,
false, expectFile,
pipe.StageRequirements{
Stdin: pipe.StreamAcceptAny,
Stdout: pipe.StreamAcceptAny,
},
expectNil, expectFile,
),
newPipeSniffingStage(
true, expectFile,
false, expectOther,
pipe.StageRequirements{
Stdin: pipe.StreamPreferFile,
Stdout: pipe.StreamAcceptAny,
},
expectFile, expectOther,
),
newPipeSniffingStage(
false, expectOther,
false, expectNil,
pipe.StageRequirements{
Stdin: pipe.StreamAcceptAny,
Stdout: pipe.StreamAcceptAny,
},
expectOther, expectNil,
),
},
},
Expand Down
185 changes: 63 additions & 122 deletions pipe/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
"os"
"sync/atomic"
)

Expand Down Expand Up @@ -218,50 +217,6 @@ func (p *Pipeline) AddWithIgnoredError(em ErrorMatcher, stages ...Stage) {
}
}

type stageStarter struct {
requirements StageRequirements
stdin *InputStream
stdout *OutputStream
}

func (requirement StreamRequirement) validate() error {
switch requirement {
case StreamOptional, StreamForbidden:
return nil
default:
return fmt.Errorf("invalid stream requirement %d", requirement)
}
}

func (requirements StageRequirements) validate(s Stage, stdinConnected, stdoutConnected bool) error {
if err := requirements.Stdin.validate(); err != nil {
return fmt.Errorf("stdin: %w", err)
}
if err := requirements.Stdout.validate(); err != nil {
return fmt.Errorf("stdout: %w", err)
}
if requirements.Stdin == StreamForbidden && stdinConnected {
return fmt.Errorf("stage %q forbids stdin, but stdin is connected", s.Name())
}
if requirements.Stdout == StreamForbidden && stdoutConnected {
return fmt.Errorf("stage %q forbids stdout, but stdout is connected", s.Name())
}
return nil
}

func (p *Pipeline) abortBeforeStart(s Stage, err error) error {
_ = p.stdout.Close()
p.cancel()
p.eventHandler(&Event{
Command: s.Name(),
Msg: "failed to start pipeline stage",
Err: err,
})
return fmt.Errorf(
"starting pipeline stage %q: %w", s.Name(), err,
)
}

func (p *Pipeline) stageOptions() StageOptions {
return StageOptions{Env: p.env, PanicHandler: p.panicHandler}
}
Expand Down Expand Up @@ -309,50 +264,67 @@ func (p *Pipeline) Start(ctx context.Context) error {
// We need to decide how to start the stages, especially what
// pipes to use to connect adjacent stages (`os.Pipe()` vs.
// `io.Pipe()`) based on the two stages' requirements.
stageStarters := make([]stageStarter, len(p.stages))
stageJoiners := make([]stageJoiner, len(p.stages)+1)

// Arrange for the input of the 0th stage to come from `p.stdin`:
stageJoiners[0].nextStdin = p.stdin

// Arrange for the output of the last stage to go to `p.stdout`:
stageJoiners[len(p.stages)].prevStdout = p.stdout

// closePipes closes all of the streams that are currently stored
// in the joiners. This should be called if startup fails. As we
// call `Stage.Start()` and pass that method streams, we clear
// them from the corresponding joiners to avoid closing them
// twice.
closePipes := func() {
for _, sj := range stageJoiners {
_ = sj.closePipe()
}
}

// Collect information about each stage's type and requirements:
// Store the stages in the joiners, and verify that the stages'
// requirements are well-formed:
for i, s := range p.stages {
stageStarters[i].requirements = s.Requirements()
stageJoiners[i].nextStage = s
stageJoiners[i+1].prevStage = s

err := stageStarters[i].requirements.validate(
s,
i > 0 || p.stdin != nil,
i < len(p.stages)-1 || p.stdout != nil,
)
if err != nil {
return p.abortBeforeStart(s, err)
// Make sure that the stage's requirements are well-formed:
requirements := s.Requirements()
if err := requirements.Stdin.Validate(); err != nil {
return fmt.Errorf("stdin: %w", err)
}
if err := requirements.Stdout.Validate(); err != nil {
return fmt.Errorf("stdout: %w", err)
}
}

if p.stdin != nil {
// Arrange for the input of the 0th stage to come from
// `p.stdin`:
stageStarters[0].stdin = p.stdin
}

if p.stdout != nil {
// Arrange for the output of the last stage to go to
// `p.stdout`:
stageStarters[len(p.stages)-1].stdout = p.stdout
// Create the "inner" pipes (i.e, all but the first and last
// `stageJoiners`):
for i := 1; i < len(stageJoiners)-1; i++ {
if err := stageJoiners[i].createPipe(); err != nil {
closePipes()
return err
}
}

// Clean up any processes and pipes that have been created. `i` is the
// index of the stage that failed to start. If the stage already received
// its streams, it owns any closing stream.
abort := func(i int, err error, closeFailedStageStdin bool) error {
// If the failing stage never received its stdin, close the pipe that
// the previous stage was writing to. That should cause it to exit
// even if it's not minding its context.
if closeFailedStageStdin {
_ = stageStarters[i].stdin.Close()
// Check that each of the stages' requirements are compatible with
// the pipes that we have created for them:
for i := range stageJoiners {
if err := stageJoiners[i].validate(); err != nil {
closePipes()
return err
}
}

// If stdout was supplied with WithStdoutCloser but the final stage
// was never started, then the pipeline still owns that closer.
if i < len(p.stages)-1 {
_ = p.stdout.Close()
}
// We're about to start up the stages, one by one. If something
// goes wrong during that process, this function should be called
// to kill any stages that have already been started and to close
// any pipes that have not yet been passed to a stage. `i` is the
// index of the stage that failed to start. If the stage already
// received its streams, it is responsible for closing them.
abort := func(i int, err error) error {
closePipes()

// Kill and wait for any stages that have been started
// already to finish:
Expand All @@ -370,51 +342,20 @@ func (p *Pipeline) Start(ctx context.Context) error {
)
}

// Loop over all but the last stage, starting them. By the time we
// get to a stage, its stdin will have already been determined,
// but we still need to figure out its stdout and set the stdin
// that will be used for the subsequent stage.
for i, s := range p.stages[:len(p.stages)-1] {
ss := &stageStarters[i]
nextSS := &stageStarters[i+1]

// We need to generate a pipe pair for this stage to use
// to communicate with its successor:
if ss.requirements.StdoutNeedsFile || nextSS.requirements.StdinNeedsFile {
// Use an OS-level pipe for the communication:
nextStdin, stdout, err := os.Pipe()
if err != nil {
return abort(i, err, true)
}
nextSS.stdin = ClosingInput(nextStdin)
ss.stdout = ClosingOutput(stdout)
} else {
nextStdin, stdout := io.Pipe()
nextSS.stdin = ClosingInput(nextStdin)
ss.stdout = ClosingOutput(stdout)
}
if err := s.Start(
ctx, p.stageOptions(),
ss.stdin, ss.stdout,
); err != nil {
_ = nextSS.stdin.Close()
return abort(i, err, false)
}
}
// Loop over all of the stages, starting them in order.
for i, s := range p.stages {
prevSJ := &stageJoiners[i]
nextSJ := &stageJoiners[i+1]

// The last stage needs special handling, because its stdout
// doesn't need to flow into another stage (it's already set in
// `ss.stdout` if it's needed).
{
i := len(p.stages) - 1
s := p.stages[i]
ss := &stageStarters[i]
err := s.Start(ctx, p.stageOptions(), prevSJ.nextStdin, nextSJ.prevStdout)

if err := s.Start(
ctx, p.stageOptions(),
ss.stdin, ss.stdout,
); err != nil {
return abort(i, err, false)
// Even if that stage failed to start, we are no longer
// responsible for closing its streams:
prevSJ.nextStdin = nil
nextSJ.prevStdout = nil

if err != nil {
return abort(i, err)
}
}

Expand Down
22 changes: 4 additions & 18 deletions pipe/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,25 +155,11 @@ type StageOptions struct {
// StagePanicHandler is a function that handles panics in a pipeline's stages.
type StagePanicHandler func(p any) error

type StreamRequirement int

const (
// StreamOptional means the stream may be connected or nil.
StreamOptional StreamRequirement = iota

// StreamForbidden means the stream must be nil.
StreamForbidden
)

// StageRequirements describes what a Stage needs from the streams connected to
// its stdin and stdout. The zero value is correct for stages that are happy
// with arbitrary io.Reader/io.Writer streams, such as Function stages.
// StageRequirements describes what a Stage needs from the streams
// connected to its stdin and stdout. The zero value is correct for
// stages that are happy with arbitrary io.Reader/io.Writer streams,
// such as Function stages.
type StageRequirements struct {
Stdin StreamRequirement
Stdout StreamRequirement

// {Stdin,Stdout}NeedsFile indicate that, if stdio is connected, the
// stage requires it to be backed by an *os.File (a real file descriptor)
StdinNeedsFile bool
StdoutNeedsFile bool
}
Loading
Loading