diff --git a/extra/multihash-spec b/extra/multihash-spec deleted file mode 160000 index b43ec1026..000000000 --- a/extra/multihash-spec +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b43ec1026a610fa87878e53b3daecf3a14b3ef6f diff --git a/extra/py-multihash b/extra/py-multihash deleted file mode 160000 index dfae0dd7a..000000000 --- a/extra/py-multihash +++ /dev/null @@ -1 +0,0 @@ -Subproject commit dfae0dd7a66e0f5a0346d0297e03582443297b9c diff --git a/extra/pymultihash b/extra/pymultihash deleted file mode 160000 index 215298fa2..000000000 --- a/extra/pymultihash +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 215298fa2faa55027384d1f22519229d0918cfb0 diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py index 18fbbcd5c..d1460ed13 100644 --- a/libp2p/security/noise/io.py +++ b/libp2p/security/noise/io.py @@ -22,6 +22,9 @@ MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1 SIZE_NOISE_MESSAGE_BODY_LEN = 2 MAX_NOISE_MESSAGE_BODY_LEN = MAX_NOISE_MESSAGE_LEN - SIZE_NOISE_MESSAGE_BODY_LEN +# Max plaintext per Noise message: 65535 - 16 bytes Poly1305 MAC overhead. +# Matches go-libp2p's MaxPlaintextLength in p2p/security/noise/rw.go. +MAX_PLAINTEXT_LENGTH = MAX_NOISE_MESSAGE_LEN - 16 BYTE_ORDER = "big" # | Noise packet | @@ -53,14 +56,26 @@ def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None: self.noise_state = noise_state async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None: - logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes") - data_encrypted = self.encrypt(msg) - if prefix_encoded: - # Manually add the prefix if needed - data_encrypted = self.prefix + data_encrypted - logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes") - await self.read_writer.write_msg(data_encrypted) - logger.debug("Noise write_msg: write completed successfully") + # Chunk large messages to stay within the Noise 65535-byte transport + # message limit, matching go-libp2p's noise/rw.go Write() approach. + if len(msg) <= MAX_PLAINTEXT_LENGTH: + # Fast path: single message (covers handshake and small writes) + data_encrypted = self.encrypt(msg) + if prefix_encoded: + data_encrypted = self.prefix + data_encrypted + await self.read_writer.write_msg(data_encrypted) + else: + # Slow path: chunk into multiple Noise messages + total = len(msg) + written = 0 + while written < total: + end = min(written + MAX_PLAINTEXT_LENGTH, total) + chunk = msg[written:end] + data_encrypted = self.encrypt(chunk) + if prefix_encoded and written == 0: + data_encrypted = self.prefix + data_encrypted + await self.read_writer.write_msg(data_encrypted) + written = end async def read_msg(self, prefix_encoded: bool = False) -> bytes: logger.debug("Noise read_msg: reading encrypted message") diff --git a/libp2p/security/secure_session.py b/libp2p/security/secure_session.py index 29a970507..604e5fc6f 100644 --- a/libp2p/security/secure_session.py +++ b/libp2p/security/secure_session.py @@ -94,24 +94,47 @@ async def read(self, n: int | None = None) -> bytes: return b"" data_from_buffer = self._drain(n) - if len(data_from_buffer) > 0: + if n is None and len(data_from_buffer) > 0: return data_from_buffer - msg = await self.conn.read_msg() + if n is None: + msg = await self.conn.read_msg() - # If underlying connection returned empty bytes, treat as closed - # and raise to signal that reads after close are invalid. - if msg == b"": - raise Exception("Connection closed") + # If underlying connection returned empty bytes, treat as closed + # and raise to signal that reads after close are invalid. + if msg == b"": + raise Exception("Connection closed") - if n is None: return msg - if n < len(msg): - self._fill(msg) - return self._drain(n) - else: - return msg + if len(data_from_buffer) == n: + return data_from_buffer + + result = bytearray(data_from_buffer) + while len(result) < n: + needed = n - len(result) + drained = self._drain(needed) + if drained: + result.extend(drained) + continue + + msg = await self.conn.read_msg() + + # If the connection closes after a partial read, return the bytes + # we already assembled. This preserves the stream-read behavior + # expected by higher layers. + if msg == b"": + if result: + return bytes(result) + raise Exception("Connection closed") + + if len(msg) <= needed: + result.extend(msg) + else: + result.extend(msg[:needed]) + self._fill(msg[needed:]) + + return bytes(result) async def write(self, data: bytes) -> None: await self.conn.write_msg(data) diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index 190c46bff..f5e57614e 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -70,6 +70,9 @@ # Network byte order: version (B), type (B), flags (H), stream_id (I), length (I) YAMUX_HEADER_FORMAT = "!BBHII" DEFAULT_WINDOW_SIZE = 256 * 1024 +MAX_WINDOW_SIZE = 16 * 1024 * 1024 # 16 MB max receive window (matches go-yamux) +MAX_MESSAGE_SIZE = 64 * 1024 # 64KB max frame payload, matches go-yamux default +RTT_MEASURE_INTERVAL = 30 # seconds between RTT measurements GO_AWAY_NORMAL = 0x0 GO_AWAY_PROTOCOL_ERROR = 0x1 @@ -77,6 +80,9 @@ class YamuxStream(IMuxedStream): + target_recv_window: int + epoch_start: float + def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None: self.stream_id = stream_id self.conn = conn @@ -89,6 +95,8 @@ def __init__(self, stream_id: int, conn: "Yamux", is_initiator: bool) -> None: self.send_window = DEFAULT_WINDOW_SIZE self.recv_window = DEFAULT_WINDOW_SIZE self.window_lock = trio.Lock() + self.target_recv_window = DEFAULT_WINDOW_SIZE # grows up to MAX_WINDOW_SIZE + self.epoch_start = 0.0 # trio.current_time() of last window update self.rw_lock = ReadWriteLock() self.close_lock = trio.Lock() @@ -106,54 +114,58 @@ async def __aexit__( await self.close() async def write(self, data: bytes) -> None: - async with self.rw_lock.write_lock(): - if self.send_closed: - raise MuxedStreamError("Stream is closed for sending") - - # Flow control: Check if we have enough send window - total_len = len(data) - sent = 0 - logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ") - while sent < total_len: - # Wait for available window with timeout - timeout = False - async with self.window_lock: - if self.send_window == 0: - logger.debug( - f"Stream {self.stream_id}: " - "Window is zero, waiting for update" - ) - # Release lock and wait with timeout - self.window_lock.release() - # To avoid re-acquiring the lock immediately, - with trio.move_on_after(5.0) as cancel_scope: - while self.send_window == 0 and not self.closed: - await trio.sleep(0.01) - # If we timed out, cancel the scope - timeout = cancel_scope.cancelled_caught - # Re-acquire lock - await self.window_lock.acquire() - - # If we timed out waiting for window update, raise an error - if timeout: - raise MuxedStreamError( - "Timed out waiting for window update after 5 seconds." - ) - - if self.closed: - raise MuxedStreamError("Stream is closed") + if self.send_closed: + raise MuxedStreamError("Stream is closed for sending") + + total_len = len(data) + sent = 0 + logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ") + while sent < total_len: + frame: bytes | None = None + while frame is None: + async with self.rw_lock.write_lock(): + if self.send_closed: + raise MuxedStreamError("Stream is closed for sending") + async with self.window_lock: + if self.closed: + raise MuxedStreamError("Stream is closed") + if self.send_window > 0: + to_send = min( + self.send_window, + MAX_MESSAGE_SIZE - HEADER_SIZE, + total_len - sent, + ) + chunk = data[sent : sent + to_send] + self.send_window -= to_send + header = struct.pack( + YAMUX_HEADER_FORMAT, + 0, + TYPE_DATA, + 0, + self.stream_id, + len(chunk), + ) + frame = header + chunk - # Calculate how much we can send now - to_send = min(self.send_window, total_len - sent) - chunk = data[sent : sent + to_send] - self.send_window -= to_send + if frame is not None: + break - # Send the data - header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk) + logger.debug( + f"Stream {self.stream_id}: Window is zero, waiting for update" + ) + with trio.move_on_after(5.0) as cancel_scope: + while True: + async with self.window_lock: + if self.send_window > 0 or self.closed: + break + await trio.sleep(0.01) + if cancel_scope.cancelled_caught: + raise MuxedStreamError( + "Timed out waiting for window update after 5 seconds." ) - await self.conn.secured_conn.write(header + chunk) - sent += to_send + + await self.conn._write_frame(frame) + sent += len(frame) - HEADER_SIZE async def send_window_update(self, increment: int, skip_lock: bool = False) -> None: """ @@ -162,9 +174,10 @@ async def send_window_update(self, increment: int, skip_lock: bool = False) -> N param:increment: The amount to increment the window size by. If None, uses the difference between DEFAULT_WINDOW_SIZE and current receive window. - param:skip_lock (bool): If True, skips acquiring window_lock. - This should only be used when calling from a context - that already holds the lock. + param:skip_lock (bool): Unused (retained for API compatibility). + + Never hold ``window_lock`` across this await — inbound WINDOW_UPDATE handling + also needs ``window_lock`` to adjust ``send_window`` concurrently. Note: This method gracefully handles connection closure errors. If the connection is closed (e.g., peer closed WebSocket immediately @@ -192,7 +205,7 @@ async def _do_window_update() -> None: increment, ) try: - await self.conn.secured_conn.write(header) + await self.conn._write_frame(header) except ConnectionClosedError as e: # Typed exception from transports (e.g., WebSocket) that # properly signal connection closure — handle gracefully. @@ -228,11 +241,60 @@ async def _do_window_update() -> None: ) raise - if skip_lock: - await _do_window_update() - else: - async with self.window_lock: - await _do_window_update() + # Never hold window_lock across _write_frame: inbound WINDOW_UPDATE handlers + # need window_lock concurrently to adjust send_window. + _ = skip_lock + await _do_window_update() + + async def _auto_tune_and_send_window_update(self: "YamuxStream") -> None: + """ + Auto-tune receive window size based on RTT and send window update. + + Ports go-yamux's two-pass GrowTo + sendWindowUpdate logic: + - Pass 1: GrowTo(current_target) — restore window to current target + - Auto-tune: if within 4x RTT of last epoch, double the target + - Pass 2: GrowTo(new_target, force=True) — grow to new target + - Only the final delta is sent to the peer (matches go-yamux behavior) + """ + total_delta: int + async with self.window_lock: + # Match go-yamux GrowTo: currentWindow = cap + len + buffered = len(self.conn.stream_buffers.get(self.stream_id, b"")) + current_window = self.recv_window + buffered + + # Pass 1: GrowTo(target_recv_window) — like go's first GrowTo call + delta = self.target_recv_window - current_window + if delta <= 0: + return + # Hysteresis: skip if delta < 50% of target (matches go-yamux GrowTo) + if delta < self.target_recv_window // 2: + return + # Apply first pass growth to recv_window (like go's cap += delta) + self.recv_window += delta + + # Auto-tune: if within 4x RTT of last epoch, double the target + now = trio.current_time() + rtt = self.conn.rtt() + if rtt > 0 and self.epoch_start > 0 and (now - self.epoch_start) < rtt * 4: + new_target = min(self.target_recv_window * 2, MAX_WINDOW_SIZE) + if new_target > self.target_recv_window: + self.target_recv_window = new_target + # Pass 2: GrowTo(new_target, force=True) — incremental + # Recompute current_window after pass 1 growth + new_current = self.recv_window + buffered + extra_delta = self.target_recv_window - new_current + if extra_delta > 0: + self.recv_window += extra_delta + delta += extra_delta # Send total delta (pass 1 + pass 2) + + self.epoch_start = now + logger.debug( + f"Stream {self.stream_id}: Auto-tune window update " + f"delta={delta}, target={self.target_recv_window}" + ) + total_delta = delta + + await self.send_window_update(total_delta, skip_lock=True) async def read(self, n: int | None = -1) -> bytes: """ @@ -288,11 +350,8 @@ async def read(self, n: int | None = -1) -> bytes: buffer.clear() data += chunk - # Send window update for the chunk we just read - async with self.window_lock: - self.recv_window += len(chunk) - logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}") - await self.send_window_update(len(chunk), skip_lock=True) + # Auto-tune and send window update for the chunk we just read + await self._auto_tune_and_send_window_update() # Check for reset if self.reset_received: @@ -337,13 +396,7 @@ async def read(self, n: int | None = -1) -> bytes: return b"" else: data = await self.conn.read_stream(self.stream_id, n) - async with self.window_lock: - self.recv_window += len(data) - logger.debug( - f"Stream {self.stream_id}: Sending window update after read, " - f"increment={len(data)}" - ) - await self.send_window_update(len(data), skip_lock=True) + await self._auto_tune_and_send_window_update() return data async def close(self) -> None: @@ -352,9 +405,14 @@ async def close(self) -> None: logger.debug(f"Half-closing stream {self.stream_id} (local end)") try: header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0 + YAMUX_HEADER_FORMAT, + 0, + TYPE_WINDOW_UPDATE, + FLAG_FIN, + self.stream_id, + 0, ) - await self.conn.secured_conn.write(header) + await self.conn._write_frame(header) except (RawConnError, ConnectionClosedError) as e: logger.debug(f"Error sending FIN, connection likely closed: {e}") finally: @@ -373,9 +431,14 @@ async def reset(self) -> None: logger.debug(f"Resetting stream {self.stream_id}") try: header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0 + YAMUX_HEADER_FORMAT, + 0, + TYPE_WINDOW_UPDATE, + FLAG_RST, + self.stream_id, + 0, ) - await self.conn.secured_conn.write(header) + await self.conn._write_frame(header) except (RawConnError, ConnectionClosedError) as e: logger.debug(f"Error sending RST, connection likely closed: {e}") finally: @@ -432,8 +495,45 @@ def __init__( self.event_started = trio.Event() self.stream_buffers: dict[int, bytearray] = {} self.stream_events: dict[int, trio.Event] = {} + self._write_lock = trio.Lock() self._nursery: Nursery | None = None self._established: bool = False + self._rtt: float = 0.0 # smoothed RTT in seconds + self._ping_id: int = 0 # incrementing ping nonce + self._ping_sent_time: float = 0.0 # trio.current_time() when ping sent + self._ping_event: trio.Event = trio.Event() + + def rtt(self) -> float: + """Return the current smoothed RTT estimate in seconds.""" + return self._rtt + + async def _measure_rtt_loop(self) -> None: + """Background task that periodically measures RTT via ping/pong.""" + # Initial delay to let the connection establish + await trio.sleep(0.5) + while not self.event_shutting_down.is_set(): + try: + self._ping_id += 1 + self._ping_event = trio.Event() + header = struct.pack( + YAMUX_HEADER_FORMAT, 0, TYPE_PING, FLAG_SYN, 0, self._ping_id + ) + await self._write_frame(header) + # Record time AFTER write completes, matching go-yamux which + # times after dispatch to avoid including write-lock wait time. + self._ping_sent_time = trio.current_time() + # Wait for pong with timeout + with trio.move_on_after(10.0): + await self._ping_event.wait() + except Exception: + # Connection likely closed, exit the loop + break + if self.event_shutting_down.is_set(): + break + # Sleep between measurements, checking shutdown periodically + with trio.move_on_after(RTT_MEASURE_INTERVAL): + while not self.event_shutting_down.is_set(): + await trio.sleep(1.0) @property def is_established(self) -> bool: @@ -462,10 +562,14 @@ async def start(self) -> None: logger.debug( f"Yamux.start() starting handle_incoming task for {self.peer_id}" ) + + nursery.start_soon(self._measure_rtt_loop) # Use nursery.start() to ensure handle_incoming has started # before we set event_started. This prevents race conditions # where streams are opened before the muxer is ready. + # When handle_incoming exits, the finally block cancels the nursery. await nursery.start(self._handle_incoming_with_ready_signal) + logger.debug(f"Yamux.start() setting event_started for {self.peer_id}") self._established = True self.event_started.set() @@ -494,7 +598,7 @@ async def close(self, error_code: int = GO_AWAY_NORMAL) -> None: header = struct.pack( YAMUX_HEADER_FORMAT, 0, TYPE_GO_AWAY, 0, 0, error_code ) - await self.secured_conn.write(header) + await self._write_frame(header) except Exception as e: logger.debug(f"Failed to send GO_AWAY: {e}") self.event_shutting_down.set() @@ -551,6 +655,33 @@ def get_remote_address(self) -> tuple[str, int] | None: """ return self.secured_conn.get_remote_address() + async def _write_frame(self, data: bytes) -> None: + """Write a frame to the connection, serializing all writes.""" + if len(data) >= HEADER_SIZE: + _, typ, flags, sid, length = struct.unpack( + YAMUX_HEADER_FORMAT, data[:HEADER_SIZE] + ) + flag_names = [] + if flags & FLAG_SYN: + flag_names.append("SYN") + if flags & FLAG_ACK: + flag_names.append("ACK") + if flags & FLAG_FIN: + flag_names.append("FIN") + if flags & FLAG_RST: + flag_names.append("RST") + type_names = {0: "DATA", 1: "WINDOW_UPDATE", 2: "PING", 3: "GO_AWAY"} + logger.debug( + f"YAMUX TX: type={type_names.get(typ, typ)} " + f"flags={'+'.join(flag_names) or '0'} " + f"stream={sid} length={length} " + f"is_initiator={self.is_initiator_value} " + f"payload_bytes={len(data) - HEADER_SIZE}" + ) + async with self._write_lock: + await self.secured_conn.write(data) + await trio.lowlevel.checkpoint() + async def open_stream(self) -> YamuxStream: # Wait for backlog slot await self.stream_backlog_semaphore.acquire() @@ -568,10 +699,15 @@ async def open_stream(self) -> YamuxStream: # If stream is rejected or errors, release the semaphore try: header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0 + YAMUX_HEADER_FORMAT, + 0, + TYPE_WINDOW_UPDATE, + FLAG_SYN, + stream_id, + 0, ) logger.debug(f"Sending SYN header for stream {stream_id}") - await self.secured_conn.write(header) + await self._write_frame(header) return stream except Exception as e: self.stream_backlog_semaphore.release() @@ -713,13 +849,18 @@ async def _handle_incoming_with_ready_signal( This method uses trio's task_status to signal that the handle_incoming loop is ready to process frames. This prevents race conditions where streams are opened before the muxer is ready to handle them. + When handle_incoming exits, this cancels the nursery scope. """ logger.debug( f"Yamux _handle_incoming_with_ready_signal() starting for " f"peer {self.peer_id}" ) task_status.started() - await self.handle_incoming() + try: + await self.handle_incoming() + finally: + if self._nursery is not None: + self._nursery.cancel_scope.cancel() async def handle_incoming(self) -> None: logger.debug(f"Yamux handle_incoming() started for peer {self.peer_id}") @@ -787,12 +928,38 @@ async def handle_incoming(self) -> None: version, typ, flags, stream_id, length = struct.unpack( YAMUX_HEADER_FORMAT, header ) + type_names = {0: "DATA", 1: "WINDOW_UPDATE", 2: "PING", 3: "GO_AWAY"} + flag_names = [] + if flags & FLAG_SYN: + flag_names.append("SYN") + if flags & FLAG_ACK: + flag_names.append("ACK") + if flags & FLAG_FIN: + flag_names.append("FIN") + if flags & FLAG_RST: + flag_names.append("RST") logger.debug( - f"Received header for peer {self.peer_id}:" - f"type={typ}, flags={flags}, stream_id={stream_id}," - f"length={length}" + f"YAMUX RX: type={type_names.get(typ, typ)} " + f"flags={'+'.join(flag_names) or '0'} " + f"stream={stream_id} length={length} " + f"is_initiator={self.is_initiator_value}" ) if (typ == TYPE_DATA or typ == TYPE_WINDOW_UPDATE) and flags & FLAG_SYN: + syn_payload: bytes = b"" + syn_payload_err: IncompleteReadError | None = None + if typ == TYPE_DATA and length > 0: + try: + syn_payload = await read_exactly(self.secured_conn, length) + except IncompleteReadError as e: + syn_payload_err = e + logger.error( + "Incomplete read for SYN data on " + f"stream {stream_id}: {e}" + ) + + rst_header: bytes | None = None + ack_header: bytes | None = None + new_stream_notify: YamuxStream | None = None async with self.streams_lock: if stream_id not in self.streams: stream = YamuxStream(stream_id, self, False) @@ -800,79 +967,120 @@ async def handle_incoming(self) -> None: self.stream_buffers[stream_id] = bytearray() self.stream_events[stream_id] = trio.Event() - # Read any data that came with the SYN frame - if length > 0: - try: - data = await read_exactly(self.secured_conn, length) - self.stream_buffers[stream_id].extend(data) - self.stream_events[stream_id].set() - logger.debug( - f"Read {length} bytes with SYN " - f"for stream {stream_id}" - ) - except IncompleteReadError as e: - logger.error( - "Incomplete read for SYN data on stream " - f"{stream_id}: {e}" + if syn_payload_err is not None: + stream.recv_closed = True + stream.closed = True + self.stream_events[stream_id].set() + elif typ == TYPE_WINDOW_UPDATE and length > 0: + # Window update SYN: length is a delta + async with stream.window_lock: + stream.send_window += length + logger.debug( + f"SYN window update for stream " + f"{stream_id}: window={length}" + ) + elif typ == TYPE_DATA and length > 0: + self.stream_buffers[stream_id].extend(syn_payload) + stream.recv_window -= len(syn_payload) + if stream.recv_window < 0: + logger.warning( + f"Stream {stream_id}: peer exceeded " + f"receive window by " + f"{-stream.recv_window} bytes" ) - # Mark stream as closed - stream.recv_closed = True - stream.closed = True - if stream_id in self.stream_events: - self.stream_events[stream_id].set() + stream.recv_window = 0 + self.stream_events[stream_id].set() + logger.debug( + f"Read {length} bytes with SYN " + f"for stream {stream_id}" + ) ack_header = struct.pack( YAMUX_HEADER_FORMAT, 0, - TYPE_DATA, + TYPE_WINDOW_UPDATE, FLAG_ACK, stream_id, 0, ) - await self.secured_conn.write(ack_header) - logger.debug( - f"Sending stream {stream_id}" - f"to channel for peer {self.peer_id}" - ) - await self.new_stream_send_channel.send(stream) + new_stream_notify = stream else: rst_header = struct.pack( YAMUX_HEADER_FORMAT, 0, - TYPE_DATA, + TYPE_WINDOW_UPDATE, FLAG_RST, stream_id, 0, ) - await self.secured_conn.write(rst_header) - elif typ == TYPE_DATA and flags & FLAG_ACK: + + if rst_header is not None: + await self._write_frame(rst_header) + elif ack_header is not None: + await self._write_frame(ack_header) + logger.debug( + f"Sending stream {stream_id}" + f"to channel for peer {self.peer_id}" + ) + if new_stream_notify is not None: + await self.new_stream_send_channel.send(new_stream_notify) + elif ( + typ == TYPE_DATA or typ == TYPE_WINDOW_UPDATE + ) and flags & FLAG_ACK: + ack_payload: bytes = b"" + ack_payload_err: IncompleteReadError | None = None + if typ == TYPE_DATA and length > 0: + try: + ack_payload = await read_exactly(self.secured_conn, length) + except IncompleteReadError as e: + ack_payload_err = e + logger.error( + "Incomplete read for ACK data on " + f"stream {stream_id}: {e}" + ) async with self.streams_lock: if stream_id in self.streams: - # Read any data that came with the ACK - if length > 0: - try: - data = await read_exactly(self.secured_conn, length) - self.stream_buffers[stream_id].extend(data) - self.stream_events[stream_id].set() - logger.debug( - f"Received ACK with {length} bytes for stream " - f"{stream_id} for peer {self.peer_id}" - ) - except IncompleteReadError as e: - logger.error( - "Incomplete read for ACK data on stream " - f"{stream_id}: {e}" - ) - # Mark stream as closed - stream = self.streams[stream_id] + stream = self.streams[stream_id] + if typ == TYPE_WINDOW_UPDATE: + # Window update ACK: length is a delta + # (matches go-yamux incrSendWindow). + if length > 0: + async with stream.window_lock: + stream.send_window += length + logger.debug( + f"Received WINDOW_UPDATE ACK for stream " + f"{stream_id}, send_window={length} " + f"for peer {self.peer_id}" + ) + elif typ == TYPE_DATA and length > 0: + if ack_payload_err is not None: stream.recv_closed = True stream.closed = True if stream_id in self.stream_events: self.stream_events[stream_id].set() + else: + self.stream_buffers[stream_id].extend(ack_payload) + self.streams[stream_id].recv_window -= len( + ack_payload + ) + if self.streams[stream_id].recv_window < 0: + logger.warning( + f"Stream {stream_id}: peer exceeded " + f"receive window by " + f"{-self.streams[stream_id].recv_window}" + f" bytes" + ) + self.streams[stream_id].recv_window = 0 + self.stream_events[stream_id].set() + logger.debug( + f"Received ACK with {length} bytes " + f"for stream {stream_id} " + f"for peer {self.peer_id}" + ) else: logger.debug( - f"Received ACK (no data) for stream {stream_id} " - f"for peer {self.peer_id}" + f"Received ACK (no data) for stream " + f"{stream_id} for peer {self.peer_id}" ) elif typ == TYPE_GO_AWAY: error_code = length @@ -906,11 +1114,19 @@ async def handle_incoming(self) -> None: ping_header = struct.pack( YAMUX_HEADER_FORMAT, 0, TYPE_PING, FLAG_ACK, 0, length ) - await self.secured_conn.write(ping_header) + await self._write_frame(ping_header) elif flags & FLAG_ACK: + # Compute RTT with exponential smoothing + now = trio.current_time() + new_rtt = now - self._ping_sent_time + if self._rtt == 0.0: + self._rtt = new_rtt + else: + self._rtt = (self._rtt + new_rtt) / 2 + self._ping_event.set() logger.debug( f"Received ping response with value" - f"{length} for peer {self.peer_id}" + f"{length} for peer {self.peer_id}, rtt={self._rtt:.4f}s" ) elif typ == TYPE_DATA: try: @@ -954,6 +1170,15 @@ async def handle_incoming(self) -> None: async with self.streams_lock: if stream_id in self.streams: self.stream_buffers[stream_id].extend(data) + self.streams[stream_id].recv_window -= len(data) + if self.streams[stream_id].recv_window < 0: + logger.warning( + f"Stream {stream_id}: peer exceeded " + f"receive window by " + f"{-self.streams[stream_id].recv_window}" + f" bytes" + ) + self.streams[stream_id].recv_window = 0 # Always set event, even if no data # in case FIN/RST is set self.stream_events[stream_id].set() diff --git a/newsfragments/1270.feature.rst b/newsfragments/1270.feature.rst new file mode 100644 index 000000000..7b1cd8137 --- /dev/null +++ b/newsfragments/1270.feature.rst @@ -0,0 +1 @@ +Added yamux receive window auto-tuning: the per-stream receive window starts at 256 KB and doubles each RTT epoch up to 16 MB, matching go-yamux behavior for improved throughput on high-bandwidth connections. diff --git a/newsfragments/1271.bugfix.rst b/newsfragments/1271.bugfix.rst new file mode 100644 index 000000000..b26f965db --- /dev/null +++ b/newsfragments/1271.bugfix.rst @@ -0,0 +1 @@ +Fixed yamux interoperability with go-yamux: SYN/ACK/FIN/RST frames are now sent as TYPE_WINDOW_UPDATE (not TYPE_DATA), writes are serialized with a lock to prevent frame interleaving, and SYN/ACK window values match go-yamux conventions so peers no longer get an inflated send window. diff --git a/tests/core/security/noise/test_buffer_management.py b/tests/core/security/noise/test_buffer_management.py index c76ae18f5..71170761c 100644 --- a/tests/core/security/noise/test_buffer_management.py +++ b/tests/core/security/noise/test_buffer_management.py @@ -85,13 +85,14 @@ async def test_partial_read_handling(self, nursery): # Read in chunks chunk_size = 5 - received_data = b"" + received_buf = bytearray() - while len(received_data) < len(test_data): - chunk = await remote_conn.read(chunk_size) - received_data += chunk + while len(received_buf) < len(test_data): + remain = len(test_data) - len(received_buf) + chunk = await remote_conn.read(min(chunk_size, remain)) + received_buf.extend(chunk) - assert received_data == test_data + assert bytes(received_buf) == test_data @pytest.mark.trio async def test_empty_buffer_handling(self, nursery): diff --git a/tests/core/security/noise/test_large_payloads.py b/tests/core/security/noise/test_large_payloads.py index deb2985c1..107f418a1 100644 --- a/tests/core/security/noise/test_large_payloads.py +++ b/tests/core/security/noise/test_large_payloads.py @@ -16,6 +16,23 @@ class TestLargePayloads: """Test large payload handling in Noise transport.""" + @pytest.mark.trio + async def test_go_large_payload_roundtrip(self, nursery): + """Match go-libp2p's large-payload transport test.""" + async with noise_conn_factory(nursery) as conns: + local_conn, remote_conn = conns + + random.seed(1234) + size = 100000 + test_data = bytes(random.getrandbits(8) for _ in range(size)) + + await local_conn.write(test_data) + + received_data = await remote_conn.read(len(test_data)) + + assert len(received_data) == len(test_data) + assert received_data == test_data + @pytest.mark.trio async def test_large_payload_roundtrip(self, nursery): """Test large payload requiring multiple Noise messages.""" diff --git a/tests/core/security/test_secio.py b/tests/core/security/test_secio.py index 55035bbf1..4bf18c8dd 100644 --- a/tests/core/security/test_secio.py +++ b/tests/core/security/test_secio.py @@ -12,9 +12,6 @@ NONCE_SIZE, create_secure_session, ) -from libp2p.tools.constants import ( - MAX_READ_LEN, -) from tests.utils.factories import ( raw_conn_factory, ) @@ -61,5 +58,7 @@ async def remote_create_secure_session(): msg = b"abc" await local_secure_conn.write(msg) - received_msg = await remote_secure_conn.read(MAX_READ_LEN) + # SecureSession.read(n) aggregates until it has exactly n decrypted bytes; + # reading a large fixed cap would block forever after a tiny write. + received_msg = await remote_secure_conn.read(len(msg)) assert received_msg == msg diff --git a/tests/core/stream_muxer/test_yamux.py b/tests/core/stream_muxer/test_yamux.py index 8e0befc89..fc1cc02af 100644 --- a/tests/core/stream_muxer/test_yamux.py +++ b/tests/core/stream_muxer/test_yamux.py @@ -323,9 +323,10 @@ async def test_yamux_flow_control(yamux_pair): # Send the data await client_stream.write(large_data) - # Check that window was reduced - assert client_stream.send_window < initial_window, ( - "Window should be reduced after sending" + # Window was reduced by the send; ACK may have already restored some, + # but it should differ from the initial value. + assert client_stream.send_window != initial_window, ( + "Window should have changed after sending data and receiving ACK" ) # Read the data on the server side diff --git a/tests/core/stream_muxer/test_yamux_interleaving.py b/tests/core/stream_muxer/test_yamux_interleaving.py index facbee7ae..45f091a85 100644 --- a/tests/core/stream_muxer/test_yamux_interleaving.py +++ b/tests/core/stream_muxer/test_yamux_interleaving.py @@ -24,7 +24,16 @@ class TrioStreamAdapter(IRawConnection): - """Adapter to make trio memory streams work with libp2p.""" + """ + Adapter to make trio memory streams work with libp2p. + + Read/write wrap each syscall with :func:`trio.move_on_after(2)`. Checkpoints + alone are not enough on Trio memory streams under this load: the scope still + installs cancel machinery that improves scheduling fairness versus the peer + ``handle_incoming`` loop. Deadline should not elapse on a passing run (~2–3s). + The race test accumulates reads to match partial-read semantics of + :class:`~libp2p.stream_muxer.yamux.yamux.YamuxStream`. + """ def __init__(self, send_stream, receive_stream, is_initiator=False): self.send_stream = send_stream @@ -124,7 +133,8 @@ async def yamux_pair(secure_conn_pair, peer_id): with trio.move_on_after(5): nursery.start_soon(client_yamux.start) nursery.start_soon(server_yamux.start) - await trio.sleep(0.1) + await client_yamux.event_started.wait() + await server_yamux.event_started.wait() logging.debug("yamux_pair started") yield client_yamux, server_yamux logging.debug("yamux_pair cleanup") @@ -170,8 +180,11 @@ async def writer(stream, msgs, name): async def reader(stream, received, name): """Read messages and store them for verification.""" for i in range(MSG_COUNT): - data = await stream.read(MSG_SIZE) - received.append(data) + buf = bytearray() + while len(buf) < MSG_SIZE: + chunk = await stream.read(MSG_SIZE - len(buf)) + buf.extend(chunk) + received.append(bytes(buf)) if i % 3 == 0: await trio.sleep(0.001) diff --git a/tests/core/stream_muxer/yamux/test_yamux_window_update_error_handling.py b/tests/core/stream_muxer/yamux/test_yamux_window_update_error_handling.py index 92715e6b9..617a91f05 100644 --- a/tests/core/stream_muxer/yamux/test_yamux_window_update_error_handling.py +++ b/tests/core/stream_muxer/yamux/test_yamux_window_update_error_handling.py @@ -41,8 +41,7 @@ async def test_send_window_update_handles_connection_closed_error(): by type — no string matching required. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock( + mock_conn._write_frame = AsyncMock( side_effect=ConnectionClosedError( "WebSocket connection closed by peer during write operation", close_code=1000, @@ -57,7 +56,7 @@ async def test_send_window_update_handles_connection_closed_error(): # Should not raise — ConnectionClosedError is handled gracefully await stream.send_window_update(32) - assert mock_conn.secured_conn.write.called + assert mock_conn._write_frame.called @pytest.mark.trio @@ -75,14 +74,13 @@ async def test_send_window_update_handles_connection_closed_error_any_message(): for msg in unusual_messages: mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock( + mock_conn._write_frame = AsyncMock( side_effect=ConnectionClosedError(msg, close_code=1000) ) stream = YamuxStream(1, mock_conn, is_initiator=True) await stream.send_window_update(32) # Should not raise - assert mock_conn.secured_conn.write.called + assert mock_conn._write_frame.called # --------------------------------------------------------------------------- @@ -97,10 +95,7 @@ async def test_send_window_update_handles_raw_conn_error(): gracefully (string-matching fallback for TCP transport). """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock( - side_effect=RawConnError("Connection closed") - ) + mock_conn._write_frame = AsyncMock(side_effect=RawConnError("Connection closed")) stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) @@ -108,7 +103,7 @@ async def test_send_window_update_handles_raw_conn_error(): # Should not raise — falls through to string-matching fallback await stream.send_window_update(32) - assert mock_conn.secured_conn.write.called + assert mock_conn._write_frame.called @pytest.mark.trio @@ -126,14 +121,13 @@ async def test_send_window_update_handles_various_closure_messages(): for error_msg in closure_messages: mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock(side_effect=IOException(error_msg)) + mock_conn._write_frame = AsyncMock(side_effect=IOException(error_msg)) stream = YamuxStream(1, mock_conn, is_initiator=True) # Should not raise for any of these messages await stream.send_window_update(32) - assert mock_conn.secured_conn.write.called + assert mock_conn._write_frame.called # --------------------------------------------------------------------------- @@ -147,8 +141,7 @@ async def test_send_window_update_raises_unexpected_errors(): Test that unexpected errors (not connection closure) are still raised. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock(side_effect=ValueError("Unexpected error")) + mock_conn._write_frame = AsyncMock(side_effect=ValueError("Unexpected error")) stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) @@ -163,8 +156,7 @@ async def test_send_window_update_raises_non_closure_io_exception(): Test that plain IOException with non-closure message is still raised. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock(side_effect=IOException("Disk full error")) + mock_conn._write_frame = AsyncMock(side_effect=IOException("Disk full error")) stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) @@ -184,16 +176,15 @@ async def test_send_window_update_succeeds_when_connection_open(): Test that send_window_update succeeds normally when connection is open. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() - mock_conn.secured_conn.write = AsyncMock() # No error + mock_conn._write_frame = AsyncMock() # No error stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) await stream.send_window_update(32) - assert mock_conn.secured_conn.write.called - call_args = mock_conn.secured_conn.write.call_args[0][0] + assert mock_conn._write_frame.called + call_args = mock_conn._write_frame.call_args[0][0] assert len(call_args) == 12 # Yamux header is 12 bytes assert call_args[1] == 0x1 # Window update type @@ -204,13 +195,13 @@ async def test_send_window_update_skips_zero_increment(): Test that send_window_update skips sending when increment is zero or negative. """ mock_conn = Mock() - mock_conn.secured_conn = AsyncMock() + mock_conn._write_frame = AsyncMock() stream_id = 1 stream = YamuxStream(stream_id, mock_conn, is_initiator=True) await stream.send_window_update(0) - assert not mock_conn.secured_conn.write.called + assert not mock_conn._write_frame.called await stream.send_window_update(-1) - assert not mock_conn.secured_conn.write.called + assert not mock_conn._write_frame.called