diff --git a/first.go b/first.go index e6133aa..4adff7f 100644 --- a/first.go +++ b/first.go @@ -46,6 +46,10 @@ type First[T any] struct { cancel context.CancelFunc } +// WithContext initializes an instance of First with the given context. +// It also returns the context that the instance of First will use. +// After Wait() completes, the returned context will be canceled. +// The parent context (passed in to this function) will NOT be canceled. func WithContext[T any](ctx context.Context) (*First[T], context.Context) { var f First[T] @@ -104,17 +108,18 @@ func (f *First[T]) Do(fn func() (T, error)) { }() } -// DoContext works like Do, except it accepts and provides a context. -// The FIRST context provided to DoContext will be used. The rest will be ignored. +// DoContext works like Do, except it provides a context. // After the first Do or DoContext call completes, the ctx provided to all DoContext callbacks will be canceled. // This is useful for canceling long-running tasks that should short-circuit when the first operation completes. // You are allowed to mix DoContext and Do with a single call to Wait. // +// If you want to provide a context, use first.WithContext(ctx) +// // Example: // -// var f first.First +// f, ctx := first.WithContext(ctx) // -// f.DoContext(ctx, func(ctx context.Context) (*example, error) { +// f.DoContext(func(ctx context.Context) (*example, error) { // // do some long-running task that requires context // data, err := getFromDatabase(ctx) // if err != nil { @@ -124,8 +129,8 @@ func (f *First[T]) Do(fn func() (T, error)) { // }) // // data, err := f.Wait() -func (f *First[T]) DoContext(ctx context.Context, fn func(context.Context) (T, error)) { - f.init(ctx) +func (f *First[T]) DoContext(fn func(context.Context) (T, error)) { + f.init(context.Background()) go func() { res, err := fn(f.context) diff --git a/first_context_test.go b/first_context_test.go index 25615df..9b25d4f 100644 --- a/first_context_test.go +++ b/first_context_test.go @@ -17,9 +17,9 @@ func TestContextFirstSecond(t *testing.T) { ctx := context.Background() - var f first.First[*example] + f, _ := first.WithContext[*example](ctx) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() time.Sleep(10 * time.Millisecond) @@ -31,7 +31,7 @@ func TestContextFirstSecond(t *testing.T) { return &example{name: "one"}, nil }) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() if ctx.Err() != nil { @@ -64,9 +64,9 @@ func TestContextFirstFirst(t *testing.T) { ctx := context.Background() - var f first.First[*example] + f, _ := first.WithContext[*example](ctx) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() if ctx.Err() != nil { @@ -76,7 +76,7 @@ func TestContextFirstFirst(t *testing.T) { return &example{name: "one"}, nil }) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() time.Sleep(10 * time.Millisecond) @@ -111,9 +111,9 @@ func TestContextFirstError(t *testing.T) { ctx := context.Background() - var f first.First[*example] + f, _ := first.WithContext[*example](ctx) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() time.Sleep(10 * time.Millisecond) @@ -125,7 +125,7 @@ func TestContextFirstError(t *testing.T) { return &example{name: "one"}, nil }) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() if ctx.Err() != nil { @@ -158,9 +158,9 @@ func TestContextErrors(t *testing.T) { ctx := context.Background() - var f first.First[*example] + f, _ := first.WithContext[*example](ctx) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() time.Sleep(10 * time.Millisecond) @@ -172,7 +172,7 @@ func TestContextErrors(t *testing.T) { return nil, errOne }) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() if ctx.Err() != nil { @@ -225,10 +225,10 @@ func TestContextFirstNone(t *testing.T) { func TestContextFirstRoutines(t *testing.T) { ctx := context.Background() - var f first.First[*example] + f, _ := first.WithContext[*example](ctx) // Go ahead and set up one func to avoid the ErrNothingToWaitOn error. - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { if ctx.Err() != nil { t.Fatalf("Unexpected context error: %s", ctx.Err()) } @@ -243,7 +243,7 @@ func TestContextFirstRoutines(t *testing.T) { go func() { defer wg.Done() - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { time.Sleep(10 * time.Millisecond) return &example{name: "two"}, nil @@ -281,9 +281,9 @@ func TestRespectsCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - var f first.First[*example] + f, _ := first.WithContext[*example](ctx) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() wg.Wait() @@ -291,7 +291,7 @@ func TestRespectsCancel(t *testing.T) { return &example{name: "one"}, nil }) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() wg.Wait() @@ -317,9 +317,9 @@ func TestRespectsWithContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - f, ctx := first.WithContext[*example](ctx) + f, _ := first.WithContext[*example](ctx) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() wg.Wait() @@ -327,7 +327,7 @@ func TestRespectsWithContextCancel(t *testing.T) { return &example{name: "one"}, nil }) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() wg.Wait() diff --git a/first_example_test.go b/first_example_test.go index f9deb99..c986744 100644 --- a/first_example_test.go +++ b/first_example_test.go @@ -87,16 +87,16 @@ func ExampleFirst_DoContext() { wg.Add(2) - var f first.First[*example] + f, _ := first.WithContext[*example](ctx) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() time.Sleep(1 * time.Millisecond) fmt.Printf("2 ctx=%s\n", ctx.Err()) return nil, errors.New("oops 2") }) - f.DoContext(ctx, func(ctx context.Context) (*example, error) { + f.DoContext(func(ctx context.Context) (*example, error) { defer wg.Done() fmt.Printf("1 ctx=%s\n", ctx.Err()) return &example{name: "one"}, nil