diff --git a/serialx/platforms/serial_rfc2217/__init__.py b/serialx/platforms/serial_rfc2217/__init__.py index c05f04d..365c33a 100644 --- a/serialx/platforms/serial_rfc2217/__init__.py +++ b/serialx/platforms/serial_rfc2217/__init__.py @@ -723,6 +723,8 @@ def __init__( self._rfc2217_waiters: dict[Rfc2217CmdId, asyncio.Future[Rfc2217Command]] = {} self._tcp_transport: asyncio.Transport | None = None self._tcp_connection_lost_waiter: asyncio.Future[None] | None = None + self._close_task: asyncio.Task[None] | None = None + self._needs_flush = False # -- connection lifecycle ----------------------------------------------- @@ -908,10 +910,13 @@ async def _send_and_wait(self, cmd: Rfc2217Command) -> Rfc2217Command: def write(self, data: bytes | bytearray | memoryview) -> None: """Write data to the serial port, escaping IAC bytes.""" self._check_broken() + if self._closing or self._connection_lost_called: + return assert self._tcp_transport is not None escaped = iac_escape(bytes(data)) LOGGER.debug("TX data: %d bytes (%d on wire)", len(data), len(escaped)) self._tcp_transport.write(escaped) + self._needs_flush = True async def _get_modem_pins(self) -> ModemPins: """Return modem pin state from the last NOTIFY-MODEMSTATE.""" @@ -1007,16 +1012,32 @@ def resume_reading(self) -> None: self._tcp_transport.resume_reading() def close(self) -> None: - """Close the transport.""" - if self._connection_lost_called: + """Close the transport, draining buffered writes to the server first.""" + if self._connection_lost_called or self._close_task is not None: return self._closing = True self._mark_user_closed() + if self._tcp_transport is None: + self._tcp_connection_lost(None) + return + + if not self._needs_flush: + self._tcp_transport.close() + return + + # Closing the TCP socket only flushes our own send buffer; the server may + # still drop bytes it has buffered toward the device on disconnect. Drain + # them first with a req/rsp round-trip, then tear down the connection. + self._close_task = self._loop.create_task(self._flush_and_close()) + + async def _flush_and_close(self) -> None: + # If the connection drops (or the flush fails) mid-drain, close anyway + with suppress(Exception): + await self._flush() + if self._tcp_transport is not None: self._tcp_transport.close() - else: - self._tcp_connection_lost(None) def abort(self) -> None: """Abort the transport immediately.""" @@ -1025,6 +1046,10 @@ def abort(self) -> None: self._closing = True self._mark_user_closed() + if self._close_task is not None: + self._close_task.cancel() + self._close_task = None + if self._tcp_transport is not None: self._tcp_transport.abort() else: @@ -1039,6 +1064,10 @@ async def _flush(self) -> None: # RFC2217 has no flush. Instead, we "flush" the pipe with a req/rsp sequence. await self._send_and_wait(SetBaudrateCmd(baudrate=self._serial._baudrate)) + # The round-trip confirms the server drained everything we sent, so a + # subsequent close() needs no further flush unless we write again. + self._needs_flush = False + def get_write_buffer_size(self) -> int: """Get the number of bytes currently in the write buffer.""" if self._tcp_transport is None: diff --git a/tests/test_async_lifecycle.py b/tests/test_async_lifecycle.py index de22ee6..71dd3b8 100644 --- a/tests/test_async_lifecycle.py +++ b/tests/test_async_lifecycle.py @@ -313,7 +313,8 @@ async def test_lifecycle_abort_during_drain_escalates( await sender.wait_closed() assert sender_proto.state is ProtocolState.LOST finally: - sender.close() + # `abort()` so we do not have to wait for the 4MB of data to actually be sent + sender.abort() receiver.close() await sender.wait_closed() await receiver.wait_closed()