diff --git a/pkg/streamer/streamer.go b/pkg/streamer/streamer.go index 6f67f84..2ec1f8a 100644 --- a/pkg/streamer/streamer.go +++ b/pkg/streamer/streamer.go @@ -234,6 +234,22 @@ func StopTimer(timer *time.Timer) { } } +// flushCh returns all data currently queued in channel +func flushCh(ch <-chan []byte) []byte { + res := []byte{} + for { + select { + case readData, ok := <-ch: + if !ok { + return res + } + res = append(res, readData...) + default: + return res + } + } +} + // GenericReadX reads from readCh till expr matched, exceeded time or read more than size. // Returns error if nothing was read during readTimeout or ctx was Done // readSize - maximum read size @@ -251,6 +267,7 @@ func GenericReadX(ctx context.Context, inBuffer []byte, readCh chan []byte, read select { case <-ctx.Done(): StopTimer(maxDurationTimeout) + buffer = append(buffer, flushCh(readCh)...) return nil, buffer, buffer[len(inBuffer):], multierr.Combine(ctx.Err(), ThrowReadTimeoutException(GetLastBytes(buffer, readSize))) default: } @@ -282,6 +299,7 @@ func GenericReadX(ctx context.Context, inBuffer []byte, readCh chan []byte, read case <-ctx.Done(): StopTimer(readIterTimeout) StopTimer(maxDurationTimeout) + buffer = append(buffer, flushCh(readCh)...) return nil, buffer, buffer[len(inBuffer):], multierr.Combine(ctx.Err(), ThrowReadTimeoutException(GetLastBytes(buffer, readSize))) case readData, ok := <-readCh: StopTimer(readIterTimeout) @@ -312,6 +330,7 @@ func GenericReadX(ctx context.Context, inBuffer []byte, readCh chan []byte, read return NewReadXRes(Timeout, buffer, nil, []byte{}), buffer, buffer[len(inBuffer):], nil case <-readIterTimeout.C: StopTimer(maxDurationTimeout) + buffer = append(buffer, flushCh(readCh)...) return nil, buffer, buffer[len(inBuffer):], ThrowReadTimeoutException(GetLastBytes(buffer, readSize)) } } diff --git a/pkg/streamer/streamer_test.go b/pkg/streamer/streamer_test.go index ba3f32a..87d377e 100644 --- a/pkg/streamer/streamer_test.go +++ b/pkg/streamer/streamer_test.go @@ -2,10 +2,12 @@ package streamer import ( "context" + "errors" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/annetutil/gnetcli/pkg/expr" ) @@ -83,6 +85,23 @@ func setupChan(data []byte) chan []byte { return ch } +func TestGenericReadXCtxDoneFlushChannel(t *testing.T) { + ch := make(chan []byte, 1) + ch <- []byte("pending in channel") + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + pat := expr.NewSimpleExpr().FromPattern("never-matches") + _, _, _, err := GenericReadX(ctx, nil, ch, 4096, time.Second, pat, 0, 0) + require.Error(t, err) + + var rt *ReadTimeoutException + require.True(t, errors.As(err, &rt)) + require.Equal(t, "pending in channel", string(rt.LastRead)) + require.Empty(t, ch, "channel should be drained") +} + func TestGenericSplitBytes(t *testing.T) { a, b := splitBytes([]byte("1234"), 2) assert.Equal(t, []byte("12"), a)