diff --git a/conn.go b/conn.go index 9562ffd4..03ab5233 100644 --- a/conn.go +++ b/conn.go @@ -1012,18 +1012,16 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { c.reader.Close() c.reader = nil } - c.messageReader = nil c.readLength = 0 - for c.readErr == nil { frameType, err := c.advanceFrame() if err != nil { c.readErr = err break } - if frameType == TextMessage || frameType == BinaryMessage { + c.readErrCount = 0 // resets on success c.messageReader = &messageReader{c} c.reader = c.messageReader if c.readDecompress { @@ -1032,15 +1030,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { return frameType, c.reader, nil } } - // Applications that do handle the error returned from this method spin in // tight loop on connection failure. To help application developers detect // this error, panic on repeated reads to the failed connection. c.readErrCount++ - if c.readErrCount >= 1000 { - panic("repeated read on failed websocket connection") + if c.readErrCount >= 2 { + panic("gorilla/websocket: application did not handle error from NextReader, error: " + + c.readErr.Error()) } - return noFrame, nil, c.readErr } diff --git a/conn_test.go b/conn_test.go index 28f5c4a3..5e247514 100644 --- a/conn_test.go +++ b/conn_test.go @@ -744,3 +744,45 @@ func TestFailedConnectionReadPanic(t *testing.T) { } t.Fatal("should not get here") } + +func TestNextReadErrCountResetOnSuccess(t *testing.T) { + // setup a pipe to write a real websocket frame + r, w := io.Pipe() + conn := newTestConn(r, w, false) // client connection + + // simulate accumulated errors from previous failures + conn.readErrCount = 1 + + // write a valid text frame from another goroutine + go func() { + serverConn := newTestConn(r, w, true) // server connection + serverConn.WriteMessage(TextMessage, []byte("Test")) + }() + + // successful read should reset the counter + _, _, err := conn.NextReader() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn.readErrCount != 0 { + t.Errorf("expected readErrCount to be 0 after success, got %d", conn.readErrCount) + } +} + +func TestNextReaderPanicsOnSecondFailedRead(t *testing.T) { + r, w := io.Pipe() + conn := newTestConn(r, w, false) // client connection + + // force the connection into a permanently failed state + conn.readErr = errors.New("simulated failure") + conn.readErrCount = 1 // already had 1 failed read + + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic on second read of failed connection, got none") + } + }() + + conn.NextReader() +}