Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions serialx/platforms/serial_rfc2217/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,8 @@ def __init__(
self._rfc2217_waiters: dict[Rfc2217CmdId, asyncio.Future[Rfc2217Command]] = {}
self._tcp_transport: asyncio.Transport | None = None
self._tcp_connection_lost_waiter: asyncio.Future[None] | None = None
self._close_task: asyncio.Task[None] | None = None
self._needs_flush = False

# -- connection lifecycle -----------------------------------------------

Expand Down Expand Up @@ -908,10 +910,13 @@ async def _send_and_wait(self, cmd: Rfc2217Command) -> Rfc2217Command:
def write(self, data: bytes | bytearray | memoryview) -> None:
"""Write data to the serial port, escaping IAC bytes."""
self._check_broken()
if self._closing or self._connection_lost_called:
return
assert self._tcp_transport is not None
escaped = iac_escape(bytes(data))
LOGGER.debug("TX data: %d bytes (%d on wire)", len(data), len(escaped))
self._tcp_transport.write(escaped)
self._needs_flush = True

async def _get_modem_pins(self) -> ModemPins:
"""Return modem pin state from the last NOTIFY-MODEMSTATE."""
Expand Down Expand Up @@ -1007,16 +1012,32 @@ def resume_reading(self) -> None:
self._tcp_transport.resume_reading()

def close(self) -> None:
"""Close the transport."""
if self._connection_lost_called:
"""Close the transport, draining buffered writes to the server first."""
if self._connection_lost_called or self._close_task is not None:
return
self._closing = True
self._mark_user_closed()

if self._tcp_transport is None:
self._tcp_connection_lost(None)
return

if not self._needs_flush:
self._tcp_transport.close()
return

# Closing the TCP socket only flushes our own send buffer; the server may
# still drop bytes it has buffered toward the device on disconnect. Drain
# them first with a req/rsp round-trip, then tear down the connection.
self._close_task = self._loop.create_task(self._flush_and_close())

async def _flush_and_close(self) -> None:
# If the connection drops (or the flush fails) mid-drain, close anyway
with suppress(Exception):
await self._flush()

if self._tcp_transport is not None:
self._tcp_transport.close()
else:
self._tcp_connection_lost(None)

def abort(self) -> None:
"""Abort the transport immediately."""
Expand All @@ -1025,6 +1046,10 @@ def abort(self) -> None:
self._closing = True
self._mark_user_closed()

if self._close_task is not None:
self._close_task.cancel()
self._close_task = None

if self._tcp_transport is not None:
self._tcp_transport.abort()
else:
Expand All @@ -1039,6 +1064,10 @@ async def _flush(self) -> None:
# RFC2217 has no flush. Instead, we "flush" the pipe with a req/rsp sequence.
await self._send_and_wait(SetBaudrateCmd(baudrate=self._serial._baudrate))

# The round-trip confirms the server drained everything we sent, so a
# subsequent close() needs no further flush unless we write again.
self._needs_flush = False

def get_write_buffer_size(self) -> int:
"""Get the number of bytes currently in the write buffer."""
if self._tcp_transport is None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_async_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ async def test_lifecycle_abort_during_drain_escalates(
await sender.wait_closed()
assert sender_proto.state is ProtocolState.LOST
finally:
sender.close()
# `abort()` so we do not have to wait for the 4MB of data to actually be sent
sender.abort()
receiver.close()
await sender.wait_closed()
await receiver.wait_closed()
Expand Down
Loading