diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index 8acc1a9ea..69b9dbce6 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -808,26 +808,59 @@ async def handle_incoming(self) -> None: self.stream_buffers[stream_id] = bytearray() self.stream_events[stream_id] = trio.Event() - # Read any data that came with the SYN frame - if length > 0: - try: - data = await read_exactly(self.secured_conn, length) - self.stream_buffers[stream_id].extend(data) - self.stream_events[stream_id].set() + if typ == TYPE_DATA: + # Read any data that came with the SYN frame + if length > 0: + try: + data = await read_exactly(self.secured_conn, length) + self.stream_buffers[stream_id].extend(data) + self.stream_events[stream_id].set() + logger.debug( + f"Read {length} bytes with SYN " + f"for stream {stream_id}" + ) + except IncompleteReadError as e: + logger.error( + "Incomplete read for SYN data on stream " + f"{stream_id}: {e}" + ) + # Mark stream as closed + stream.recv_closed = True + stream.closed = True + if stream_id in self.stream_events: + self.stream_events[stream_id].set() + elif typ == TYPE_WINDOW_UPDATE: + # For WINDOW_UPDATE, length is window increment, not payload + increment = length + async with stream.window_lock: logger.debug( - f"Read {length} bytes with SYN " - f"for stream {stream_id}" + f"Received window update with SYN for stream" + f"{self.peer_id}:{stream_id}," + f" increment: {increment}" ) - except IncompleteReadError as e: - logger.error( - "Incomplete read for SYN data on stream " - f"{stream_id}: {e}" - ) - # Mark stream as closed - stream.recv_closed = True + stream.send_window += increment + + # FIN and RST flags may be sent with SYN frames + if flags & FLAG_FIN: + logger.debug( + f"Received FIN for stream {self.peer_id}:" + f"{stream_id} with SYN, marking recv_closed" + ) + stream.recv_closed = True + if stream.send_closed: stream.closed = True - if stream_id in self.stream_events: - self.stream_events[stream_id].set() + # Wake up reader + self.stream_events[stream_id].set() + + if flags & FLAG_RST: + logger.debug( + f"Resetting stream {stream_id} for peer" + f"{self.peer_id} with SYN" + ) + stream.closed = True + stream.reset_received = True + # Wake up reader + self.stream_events[stream_id].set() ack_header = struct.pack( YAMUX_HEADER_FORMAT, diff --git a/tests/core/stream_muxer/test_yamux.py b/tests/core/stream_muxer/test_yamux.py index 8e0befc89..d17b3c569 100644 --- a/tests/core/stream_muxer/test_yamux.py +++ b/tests/core/stream_muxer/test_yamux.py @@ -847,3 +847,111 @@ async def test_incomplete_read_error_clean_close_detection(): assert legacy_error.received_bytes == 0 logging.debug("test_incomplete_read_error_clean_close_detection complete") + + +@pytest.mark.trio +async def test_yamux_syn_with_window_update(yamux_pair): + """ + Test that WINDOW_UPDATE|SYN frame is properly handled without reading payload. + This regression test ensures that WINDOW_UPDATE|SYN frames are NOT treated + as carrying payload bytes, fixing interop issues. + """ + logging.debug("Starting test_yamux_syn_with_window_update") + client_yamux, server_yamux = yamux_pair + + # Manually construct a WINDOW_UPDATE|SYN frame + window_increment = 1024 + stream_id = 11 # Client stream ID (odd number) + + # Create WINDOW_UPDATE header with SYN flag + # length is window increment, NOT payload length + header = struct.pack( + YAMUX_HEADER_FORMAT, + 0, # version + TYPE_WINDOW_UPDATE, + FLAG_SYN, + stream_id, + window_increment, + ) + + # Send header directly + await client_yamux.secured_conn.write(header) + logging.debug(f"Sent WINDOW_UPDATE|SYN with increment {window_increment}") + + # Server should accept the stream and NOT hang trying to read payload + with trio.move_on_after(2) as cancel_scope: + server_stream = await server_yamux.accept_stream() + + assert ( + not cancel_scope.cancelled_caught + ), "Server should have accepted the stream without hanging" + assert server_stream.stream_id == stream_id + + # Check if window increment was applied + # Initial window is 256KB by default + assert server_stream.send_window == 256 * 1024 + window_increment + + logging.debug("test_yamux_syn_with_window_update complete") + + +@pytest.mark.trio +async def test_yamux_syn_with_fin(yamux_pair): + """ + Test that DATA|SYN|FIN frame is properly handled (opens and half-closes). + """ + logging.debug("Starting test_yamux_syn_with_fin") + client_yamux, server_yamux = yamux_pair + + stream_id = 13 + header = struct.pack( + YAMUX_HEADER_FORMAT, + 0, + TYPE_DATA, + FLAG_SYN | FLAG_FIN, + stream_id, + 0, + ) + + await client_yamux.secured_conn.write(header) + + server_stream = await server_yamux.accept_stream() + assert server_stream.stream_id == stream_id + + # Should be recv_closed because of FIN + assert server_stream.recv_closed + + # Should be able to read 0 bytes (EOF) + with pytest.raises(MuxedStreamEOF): + await server_stream.read(1) + + logging.debug("test_yamux_syn_with_fin complete") + + +@pytest.mark.trio +async def test_yamux_syn_with_rst(yamux_pair): + """ + Test that DATA|SYN|RST frame is properly handled (opens and resets). + """ + logging.debug("Starting test_yamux_syn_with_rst") + client_yamux, server_yamux = yamux_pair + + stream_id = 15 + header = struct.pack( + YAMUX_HEADER_FORMAT, + 0, + TYPE_DATA, + FLAG_SYN | FLAG_RST, + stream_id, + 0, + ) + + await client_yamux.secured_conn.write(header) + + server_stream = await server_yamux.accept_stream() + assert server_stream.stream_id == stream_id + + # Should be closed and reset_received set + assert server_stream.closed + assert server_stream.reset_received + + logging.debug("test_yamux_syn_with_rst complete")