Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions first.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ type First[T any] struct {
cancel context.CancelFunc
}

// WithContext initializes an instance of First with the given context.
// It also returns the context that the instance of First will use.
// After Wait() completes, the returned context will be canceled.
// The parent context (passed in to this function) will NOT be canceled.
func WithContext[T any](ctx context.Context) (*First[T], context.Context) {
var f First[T]

Expand Down Expand Up @@ -104,17 +108,18 @@ func (f *First[T]) Do(fn func() (T, error)) {
}()
}

// DoContext works like Do, except it accepts and provides a context.
// The FIRST context provided to DoContext will be used. The rest will be ignored.
// DoContext works like Do, except it provides a context.
// After the first Do or DoContext call completes, the ctx provided to all DoContext callbacks will be canceled.
// This is useful for canceling long-running tasks that should short-circuit when the first operation completes.
// You are allowed to mix DoContext and Do with a single call to Wait.
//
// If you want to provide a context, use first.WithContext(ctx)
//
// Example:
//
// var f first.First
// f, ctx := first.WithContext(ctx)
//
// f.DoContext(ctx, func(ctx context.Context) (*example, error) {
// f.DoContext(func(ctx context.Context) (*example, error) {
// // do some long-running task that requires context
// data, err := getFromDatabase(ctx)
// if err != nil {
Expand All @@ -124,8 +129,8 @@ func (f *First[T]) Do(fn func() (T, error)) {
// })
//
// data, err := f.Wait()
func (f *First[T]) DoContext(ctx context.Context, fn func(context.Context) (T, error)) {
f.init(ctx)
func (f *First[T]) DoContext(fn func(context.Context) (T, error)) {
f.init(context.Background())

go func() {
res, err := fn(f.context)
Expand Down
42 changes: 21 additions & 21 deletions first_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ func TestContextFirstSecond(t *testing.T) {

ctx := context.Background()

var f first.First[*example]
f, _ := first.WithContext[*example](ctx)

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

time.Sleep(10 * time.Millisecond)
Expand All @@ -31,7 +31,7 @@ func TestContextFirstSecond(t *testing.T) {
return &example{name: "one"}, nil
})

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

if ctx.Err() != nil {
Expand Down Expand Up @@ -64,9 +64,9 @@ func TestContextFirstFirst(t *testing.T) {

ctx := context.Background()

var f first.First[*example]
f, _ := first.WithContext[*example](ctx)

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

if ctx.Err() != nil {
Expand All @@ -76,7 +76,7 @@ func TestContextFirstFirst(t *testing.T) {
return &example{name: "one"}, nil
})

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

time.Sleep(10 * time.Millisecond)
Expand Down Expand Up @@ -111,9 +111,9 @@ func TestContextFirstError(t *testing.T) {

ctx := context.Background()

var f first.First[*example]
f, _ := first.WithContext[*example](ctx)

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

time.Sleep(10 * time.Millisecond)
Expand All @@ -125,7 +125,7 @@ func TestContextFirstError(t *testing.T) {
return &example{name: "one"}, nil
})

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

if ctx.Err() != nil {
Expand Down Expand Up @@ -158,9 +158,9 @@ func TestContextErrors(t *testing.T) {

ctx := context.Background()

var f first.First[*example]
f, _ := first.WithContext[*example](ctx)

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

time.Sleep(10 * time.Millisecond)
Expand All @@ -172,7 +172,7 @@ func TestContextErrors(t *testing.T) {
return nil, errOne
})

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

if ctx.Err() != nil {
Expand Down Expand Up @@ -225,10 +225,10 @@ func TestContextFirstNone(t *testing.T) {
func TestContextFirstRoutines(t *testing.T) {
ctx := context.Background()

var f first.First[*example]
f, _ := first.WithContext[*example](ctx)

// Go ahead and set up one func to avoid the ErrNothingToWaitOn error.
f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
if ctx.Err() != nil {
t.Fatalf("Unexpected context error: %s", ctx.Err())
}
Expand All @@ -243,7 +243,7 @@ func TestContextFirstRoutines(t *testing.T) {
go func() {
defer wg.Done()

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
time.Sleep(10 * time.Millisecond)

return &example{name: "two"}, nil
Expand Down Expand Up @@ -281,17 +281,17 @@ func TestRespectsCancel(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())

var f first.First[*example]
f, _ := first.WithContext[*example](ctx)

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

wg.Wait()

return &example{name: "one"}, nil
})

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

wg.Wait()
Expand All @@ -317,17 +317,17 @@ func TestRespectsWithContextCancel(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())

f, ctx := first.WithContext[*example](ctx)
f, _ := first.WithContext[*example](ctx)

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

wg.Wait()

return &example{name: "one"}, nil
})

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()

wg.Wait()
Expand Down
6 changes: 3 additions & 3 deletions first_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ func ExampleFirst_DoContext() {

wg.Add(2)

var f first.First[*example]
f, _ := first.WithContext[*example](ctx)

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()
time.Sleep(1 * time.Millisecond)
fmt.Printf("2 ctx=%s\n", ctx.Err())
return nil, errors.New("oops 2")
})

f.DoContext(ctx, func(ctx context.Context) (*example, error) {
f.DoContext(func(ctx context.Context) (*example, error) {
defer wg.Done()
fmt.Printf("1 ctx=%s\n", ctx.Err())
return &example{name: "one"}, nil
Expand Down