Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 62 additions & 41 deletions pkg/graceful/graceful.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,81 +9,102 @@ import (
"syscall"
)

var (
terminationErrChan = make(chan terminationError, 1)
type terminationContextKey string

cancelNotify context.CancelFunc
var (
terminationKey = terminationContextKey("graceful_termination")
)

type termination struct {
ctx context.Context
cancel context.CancelFunc

errChan chan terminationError
}

// doWithError adds termination error and cancels context.
// It is safe for concurrent usage.
func (t *termination) doWithError(termErr terminationError) {
// Unblocking write: write err in channel if channel is empty, otherwise just go next.
select {
case t.errChan <- termErr:
default:
// just go next in non-blocking mode
}
// Cancel context if it is not cancelled yet.
t.cancel()
}

type terminationError struct {
err error
exitCode int
}

// WithTermination returns a copy of parent context that is marked done when SIGINT or SIGTERM received.
// WithTermination returns a termination that is marked done
// when SIGINT or SIGTERM received or Terminate() called.
func WithTermination(ctx context.Context) context.Context {
var notifyCtx context.Context
notifyCtx, cancelNotify = signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
return notifyCtx
notifyCtx, cancelNotify := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
return context.WithValue(notifyCtx, terminationKey, &termination{
ctx: notifyCtx,
cancel: cancelNotify,
errChan: make(chan terminationError, 1),
})
}

// Terminate starts termination if not yet. It should be called after WithTermination().
// Terminate starts termination if not yet. ctx must be the context created WithTermination().
// It is safe for concurrent usage.
func Terminate(err error, exitCode int) {
termErr := terminationError{
err: err,
exitCode: exitCode,
func Terminate(ctx context.Context, err error, exitCode int) {
term, ok := ctx.Value(terminationKey).(*termination)
if !ok {
panic("context is not termination")
}

// Unblocking write: write err in channel if channel is empty, otherwise just go next.
select {
case terminationErrChan <- termErr:
default:
// just go next in non-blocking mode
}
term.doWithError(terminationError{
err: err,
exitCode: exitCode,
})
}

// If WithTermination() isn't called before we will have panic here.
if cancelNotify != nil {
cancelNotify()
}
// IsTerminationContext returns "true" if ctx is termination.
func IsTerminationContext(ctx context.Context) bool {
_, ok := ctx.Value(terminationKey).(*termination)
return ok
}

// IsTerminating returns true if termination is in progress. It is safe for concurrent usage.
// IsTerminating returns true if termination is in progress. ctx must be the context created WithTermination().
// It is safe for concurrent usage.
func IsTerminating(ctx context.Context) bool {
// Unblocking read
select {
case <-ctx.Done():
return true
default:
return false
}
term, ok := ctx.Value(terminationKey).(*termination)
// If Done is not yet closed, Err returns nil. If Done is closed, Err returns a non-nil error explaining why.
return ok && term.ctx.Err() != nil
}

type ShutdownErrorCallback func(err error, exitCode int)

// Shutdown handles termination using terminationCtx. It should be called after WithTermination().
// Shutdown handles termination using terminationCtx. ctx must be the context created WithTermination().
// If termination context is done, it ensures termination err using SIGTERM by default.
// If panic is happened, it translates the panic to termination err.
// If termination err is exists, it calls callback(msg, exitCode).
// Otherwise, it does nothing.
func Shutdown(ctx context.Context, callback ShutdownErrorCallback) {
// Unblocking read
select {
case <-ctx.Done():
// If ctx is done, we have to ensure termination err. We could use SIGTERM by default.
Terminate(errors.New("process terminated"), 143) // SIGTERM exit code
default:
// just go next in non-blocking mode
term, ok := ctx.Value(terminationKey).(*termination)
if !ok {
panic("context is not termination")
}

if IsTerminating(ctx) {
// Ensure termination err. We could use SIGTERM by default.
Terminate(ctx, errors.New("process terminated"), 143) // SIGTERM exit code
}

// If we have panic; we should translate it to termination err.
// Translate panic to termination err if needed.
if r := recover(); r != nil {
Terminate(fmt.Errorf("%v", r), 1)
Terminate(ctx, fmt.Errorf("%v", r), 1)
}

// Unblocking read
select {
case termErr := <-terminationErrChan:
case termErr := <-term.errChan:
// If termErr is exists, it calls the callback.
callback(termErr.err, termErr.exitCode)
default:
Expand Down
113 changes: 59 additions & 54 deletions pkg/graceful/graceful_test.go
Original file line number Diff line number Diff line change
@@ -1,60 +1,50 @@
package graceful

import (
"context"
"errors"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"os"
"os/signal"
"sync"
"syscall"
)

var _ = Describe("graceful core", func() {
Describe("WithTermination()", func() {
It("should return terminationCtx", func(ctx SpecContext) {
It("should return ctx with termination", func(ctx SpecContext) {
terminationCtx := WithTermination(ctx)
expectedCtx, stop := signal.NotifyContext(ctx, os.Interrupt)
Expect(terminationCtx).To(BeAssignableToTypeOf(expectedCtx))
stop()
signal.Reset(os.Interrupt, syscall.SIGTERM)
term, ok := terminationCtx.Value(terminationKey).(*termination)
Expect(ok).To(BeTrue())
term.cancel()
})
})
Describe("Terminate()", func() {
It("should not panic if called before WithTermination()", func() {
cancelNotify = nil
err0 := errors.New("some err")

Terminate(err0, 1)

Expect(terminationErrChan).To(Receive(Equal(terminationError{
err: err0,
exitCode: 1,
})))
Describe("Terminate()", func() {
It("should not panic if ctx has not termination", func(ctx SpecContext) {
Expect(func() {
Terminate(ctx, errors.New("some err"), 1)
}).To(PanicWith(MatchRegexp("context is not termination")))
})
It("should work for single usage", func(ctx SpecContext) {
_ = WithTermination(ctx)
terminationCtx := WithTermination(ctx)
err0 := errors.New("some err")
Terminate(err0, 1)
Expect(terminationErrChan).To(Receive(Equal(terminationError{
Terminate(terminationCtx, err0, 1)
Expect(terminationCtx.Value(terminationKey).(*termination).errChan).To(Receive(Equal(terminationError{
err: err0,
exitCode: 1,
})))
})
It("should do FIFO for sequential double usage", func(ctx SpecContext) {
_ = WithTermination(ctx)
terminationCtx := WithTermination(ctx)
err0 := errors.New("some err")
err1 := errors.New("another err")
Terminate(err0, 1)
Terminate(err1, 2)
Expect(terminationErrChan).To(Receive(Equal(terminationError{
Terminate(terminationCtx, err0, 1)
Terminate(terminationCtx, err1, 2)
Expect(terminationCtx.Value(terminationKey).(*termination).errChan).To(Receive(Equal(terminationError{
err: err0,
exitCode: 1,
})))
})
It("should be safe for concurrent usage", func(ctx SpecContext) {
_ = WithTermination(ctx)
terminationCtx := WithTermination(ctx)

err0 := errors.New("some err")
err1 := errors.New("another err")
Expand All @@ -63,63 +53,75 @@ var _ = Describe("graceful core", func() {
wg.Add(2)
go func() {
defer wg.Done()
Terminate(err0, 1)
Terminate(terminationCtx, err0, 1)
}()
go func() {
defer wg.Done()
Terminate(err1, 2)
Terminate(terminationCtx, err1, 2)
}()
wg.Wait()

Expect(terminationErrChan).To(HaveLen(1))
Expect(terminationCtx.Value(terminationKey).(*termination).errChan).To(HaveLen(1))
})
})
Describe("IsTerminationContext()", func() {
It("should return 'false' if ctx has not termination", func(ctx SpecContext) {
Expect(IsTerminationContext(ctx)).To(BeFalse())
})
It("should return 'true' if ctx has termination", func(ctx SpecContext) {
terminationCtx := WithTermination(ctx)
terminationCtx.Value(terminationKey).(*termination).cancel()
Expect(IsTerminationContext(terminationCtx)).To(BeTrue())
})
})
Describe("IsTerminating()", func() {
It("should return 'false' if ctx is not done", func(ctx SpecContext) {
It("should return 'false' if ctx has termination", func(ctx SpecContext) {
Expect(IsTerminating(ctx)).To(BeFalse())
})
It("should return 'true' if ctx is not done", func(ctx SpecContext) {
ctx0, cancel := context.WithCancel(ctx)
cancel()
Expect(IsTerminating(ctx0)).To(BeTrue())
It("should return 'false' if ctx has termination but it is not terminated yet", func(ctx SpecContext) {
terminationCtx := WithTermination(ctx)
Expect(IsTerminating(terminationCtx)).To(BeFalse())
terminationCtx.Value(terminationKey).(*termination).cancel()
})
It("should return 'true' if ctx has termination and it is terminated", func(ctx SpecContext) {
terminationCtx := WithTermination(ctx)
terminationCtx.Value(terminationKey).(*termination).cancel()
Expect(IsTerminating(terminationCtx)).To(BeTrue())
})
It("should return 'true' if ctx has termination and ctx is wrapped with another one", func(ctx SpecContext) {})
It("should be safe for concurrent usage", func(ctx SpecContext) {
terminationCtx := WithTermination(ctx)
terminationCtx.Value(terminationKey).(*termination).cancel()

wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
Expect(IsTerminating(ctx)).To(BeFalse())
Expect(IsTerminating(terminationCtx)).To(BeTrue())
}()
go func() {
defer wg.Done()
Expect(IsTerminating(ctx)).To(BeFalse())
Expect(IsTerminating(terminationCtx)).To(BeTrue())
}()
wg.Wait()
})
})
Describe("Shutdown()", func() {
var spyCallback *spyCallbackHelper
BeforeEach(func() {
terminationErrChan = make(chan terminationError, 1)
spyCallback = &spyCallbackHelper{}
})
It("should not panic if called before WithTermination()", func(ctx SpecContext) {
cancelNotify = nil

ctx0, cancel := context.WithCancel(ctx)
cancel()

Shutdown(ctx0, spyCallback.Method)
It("should not panic if ctx has not termination", func(ctx SpecContext) {
Expect(func() {
Shutdown(ctx, spyCallback.Method)
}).To(PanicWith(MatchRegexp("context is not termination")))

Expect(spyCallback).To(Equal(&spyCallbackHelper{
callsCount: 1,
err: errors.New("process terminated"),
exitCode: 143,
}))
Expect(spyCallback).To(Equal(&spyCallbackHelper{}))
})
It("should ensure termination err if terminationCtx is done", func(ctx SpecContext) {
It("should ensure termination err if termination is in progress", func(ctx SpecContext) {
terminationCtx := WithTermination(ctx)
cancelNotify()
terminationCtx.Value(terminationKey).(*termination).cancel()

Shutdown(terminationCtx, spyCallback.Method)

Expect(spyCallback).To(Equal(&spyCallbackHelper{
Expand All @@ -142,14 +144,17 @@ var _ = Describe("graceful core", func() {
exitCode: 1,
}))
})
It("should do nothing if terminationCtx is not done and no panic", func(ctx SpecContext) {
Shutdown(WithTermination(ctx), spyCallback.Method)
It("should do nothing if termination is not in progress and no panic", func(ctx SpecContext) {
terminationCtx := WithTermination(ctx)
Shutdown(terminationCtx, spyCallback.Method)

Expect(spyCallback).To(Equal(&spyCallbackHelper{
callsCount: 0,
err: nil,
exitCode: 0,
}))

terminationCtx.Value(terminationKey).(*termination).cancel()
})
})
})
Expand Down