Skip to content

Commit d117b79

Browse files
authored
Merge pull request #643 from ably/AIT-96/realtime-publish
[AIT-96] feat: RealtimeChannel publish over WebSocket implementation
2 parents 70dd615 + b9ab475 commit d117b79

File tree

5 files changed

+1390
-26
lines changed

5 files changed

+1390
-26
lines changed

ably/realtime/connectionmanager.py

Lines changed: 229 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import asyncio
44
import logging
5+
from collections import deque
56
from datetime import datetime
6-
from queue import Queue
77
from typing import TYPE_CHECKING
88

99
import httpx
@@ -24,6 +24,88 @@
2424
log = logging.getLogger(__name__)
2525

2626

27+
class PendingMessage:
28+
"""Represents a message awaiting acknowledgment from the server"""
29+
30+
def __init__(self, message: dict):
31+
self.message = message
32+
self.future: asyncio.Future | None = None
33+
action = message.get('action')
34+
35+
# Messages that require acknowledgment: MESSAGE, PRESENCE, ANNOTATION, OBJECT
36+
self.ack_required = action in (
37+
ProtocolMessageAction.MESSAGE,
38+
ProtocolMessageAction.PRESENCE,
39+
ProtocolMessageAction.ANNOTATION,
40+
ProtocolMessageAction.OBJECT,
41+
)
42+
43+
if self.ack_required:
44+
self.future = asyncio.Future()
45+
46+
47+
class PendingMessageQueue:
48+
"""Queue for tracking messages awaiting acknowledgment"""
49+
50+
def __init__(self):
51+
self.messages: list[PendingMessage] = []
52+
53+
def push(self, pending_message: PendingMessage) -> None:
54+
"""Add a message to the queue"""
55+
self.messages.append(pending_message)
56+
57+
def count(self) -> int:
58+
"""Return the number of pending messages"""
59+
return len(self.messages)
60+
61+
def complete_messages(self, serial: int, count: int, err: AblyException | None = None) -> None:
62+
"""Complete messages based on serial and count from ACK/NACK
63+
64+
Args:
65+
serial: The msgSerial of the first message being acknowledged
66+
count: The number of messages being acknowledged
67+
err: Error from NACK, or None for successful ACK
68+
"""
69+
log.debug(f'MessageQueue.complete_messages(): serial={serial}, count={count}, err={err}')
70+
71+
if not self.messages:
72+
log.warning('MessageQueue.complete_messages(): called on empty queue')
73+
return
74+
75+
first = self.messages[0]
76+
if first:
77+
start_serial = first.message.get('msgSerial')
78+
if start_serial is None:
79+
log.warning('MessageQueue.complete_messages(): first message has no msgSerial')
80+
return
81+
82+
end_serial = serial + count
83+
84+
if end_serial > start_serial:
85+
# Remove and complete the acknowledged messages
86+
num_to_complete = min(end_serial - start_serial, len(self.messages))
87+
completed_messages = self.messages[:num_to_complete]
88+
self.messages = self.messages[num_to_complete:]
89+
90+
for msg in completed_messages:
91+
if msg.future and not msg.future.done():
92+
if err:
93+
msg.future.set_exception(err)
94+
else:
95+
msg.future.set_result(None)
96+
97+
def complete_all_messages(self, err: AblyException) -> None:
98+
"""Complete all pending messages with an error"""
99+
while self.messages:
100+
msg = self.messages.pop(0)
101+
if msg.future and not msg.future.done():
102+
msg.future.set_exception(err)
103+
104+
def clear(self) -> None:
105+
"""Clear all messages from the queue"""
106+
self.messages.clear()
107+
108+
27109
class ConnectionManager(EventEmitter):
28110
def __init__(self, realtime: AblyRealtime, initial_state):
29111
self.options = realtime.options
@@ -41,8 +123,10 @@ def __init__(self, realtime: AblyRealtime, initial_state):
41123
self.connect_base_task: asyncio.Task | None = None
42124
self.disconnect_transport_task: asyncio.Task | None = None
43125
self.__fallback_hosts: list[str] = self.options.get_fallback_realtime_hosts()
44-
self.queued_messages: Queue = Queue()
126+
self.queued_messages: deque[PendingMessage] = deque()
45127
self.__error_reason: AblyException | None = None
128+
self.msg_serial: int = 0
129+
self.pending_message_queue: PendingMessageQueue = PendingMessageQueue()
46130
super().__init__()
47131

48132
def enact_state_change(self, state: ConnectionState, reason: AblyException | None = None) -> None:
@@ -88,37 +172,109 @@ async def close_impl(self) -> None:
88172
self.notify_state(ConnectionState.CLOSED)
89173

90174
async def send_protocol_message(self, protocol_message: dict) -> None:
91-
if self.state in (
92-
ConnectionState.DISCONNECTED,
93-
ConnectionState.CONNECTING,
94-
):
95-
self.queued_messages.put(protocol_message)
96-
return
97-
98-
if self.state == ConnectionState.CONNECTED:
99-
if self.transport:
100-
await self.transport.send(protocol_message)
101-
else:
102-
log.exception(
103-
"ConnectionManager.send_protocol_message(): can not send message with no active transport"
175+
"""Send a protocol message and optionally track it for acknowledgment
176+
177+
Args:
178+
protocol_message: protocol message dict (new message)
179+
Returns:
180+
None
181+
"""
182+
if self.state not in (ConnectionState.DISCONNECTED, ConnectionState.CONNECTING, ConnectionState.CONNECTED):
183+
raise AblyException(f"ConnectionManager.send_protocol_message(): called in {self.state}", 500, 50000)
184+
185+
pending_message = PendingMessage(protocol_message)
186+
187+
# Assign msgSerial to messages that need acknowledgment
188+
if pending_message.ack_required:
189+
# New message - assign fresh serial
190+
protocol_message['msgSerial'] = self.msg_serial
191+
self.pending_message_queue.push(pending_message)
192+
self.msg_serial += 1
193+
194+
if self.state in (ConnectionState.DISCONNECTED, ConnectionState.CONNECTING):
195+
self.queued_messages.appendleft(pending_message)
196+
if pending_message.ack_required:
197+
await pending_message.future
198+
return None
199+
200+
return await self._send_protocol_message_on_connected_state(pending_message)
201+
202+
async def _send_protocol_message_on_connected_state(self, pending_message: PendingMessage) -> None:
203+
if self.state == ConnectionState.CONNECTED and self.transport:
204+
# Add to pending queue before sending (for messages being resent from queue)
205+
if pending_message.ack_required and pending_message not in self.pending_message_queue.messages:
206+
self.pending_message_queue.push(pending_message)
207+
await self.transport.send(pending_message.message)
208+
else:
209+
log.exception(
210+
"ConnectionManager.send_protocol_message(): can not send message with no active transport"
211+
)
212+
if pending_message.future:
213+
pending_message.future.set_exception(
214+
AblyException("No active transport", 500, 50000)
104215
)
216+
if pending_message.ack_required:
217+
await pending_message.future
218+
return None
219+
220+
def send_queued_messages(self) -> None:
221+
log.info(f'ConnectionManager.send_queued_messages(): sending {len(self.queued_messages)} message(s)')
222+
while len(self.queued_messages) > 0:
223+
pending_message = self.queued_messages.pop()
224+
asyncio.create_task(self._send_protocol_message_on_connected_state(pending_message))
225+
226+
def requeue_pending_messages(self) -> None:
227+
"""RTN19a: Requeue messages awaiting ACK/NACK when transport disconnects
228+
229+
These messages will be resent when connection becomes CONNECTED again.
230+
RTN19a2: msgSerial is preserved for resume, reset for new connection.
231+
"""
232+
pending_count = self.pending_message_queue.count()
233+
if pending_count == 0:
105234
return
106235

107-
raise AblyException(f"ConnectionManager.send_protocol_message(): called in {self.state}", 500, 50000)
236+
log.info(
237+
f'ConnectionManager.requeue_pending_messages(): '
238+
f'requeuing {pending_count} pending message(s) for resend'
239+
)
108240

109-
def send_queued_messages(self) -> None:
110-
log.info(f'ConnectionManager.send_queued_messages(): sending {self.queued_messages.qsize()} message(s)')
111-
while not self.queued_messages.empty():
112-
asyncio.create_task(self.send_protocol_message(self.queued_messages.get()))
241+
# Get all pending messages and add them back to the queue
242+
# They'll be sent again when we reconnect
243+
pending_messages = list(self.pending_message_queue.messages)
244+
245+
# Add back to front of queue (FIFO but priority over new messages)
246+
# Store the entire PendingMessage object to preserve Future
247+
for pending_msg in reversed(pending_messages):
248+
# PendingMessage object retains its Future, msgSerial
249+
self.queued_messages.append(pending_msg)
250+
251+
# Clear the message queue since we're requeueing them all
252+
# When they're resent, the existing Future will be resolved
253+
self.pending_message_queue.clear()
113254

114255
def fail_queued_messages(self, err) -> None:
115256
log.info(
116-
f"ConnectionManager.fail_queued_messages(): discarding {self.queued_messages.qsize()} messages;" +
257+
f"ConnectionManager.fail_queued_messages(): discarding {len(self.queued_messages)} messages;" +
117258
f" reason = {err}"
118259
)
119-
while not self.queued_messages.empty():
120-
msg = self.queued_messages.get()
121-
log.exception(f"ConnectionManager.fail_queued_messages(): Failed to send protocol message: {msg}")
260+
error = err or AblyException("Connection failed", 80000, 500)
261+
while len(self.queued_messages) > 0:
262+
pending_msg = self.queued_messages.pop()
263+
log.exception(
264+
f"ConnectionManager.fail_queued_messages(): Failed to send protocol message: "
265+
f"{pending_msg.message}"
266+
)
267+
# Fail the Future if it exists
268+
if pending_msg.future and not pending_msg.future.done():
269+
pending_msg.future.set_exception(error)
270+
271+
# Also fail all pending messages awaiting acknowledgment
272+
if self.pending_message_queue.count() > 0:
273+
count = self.pending_message_queue.count()
274+
log.info(
275+
f"ConnectionManager.fail_queued_messages(): failing {count} pending messages"
276+
)
277+
self.pending_message_queue.complete_all_messages(error)
122278

123279
async def ping(self) -> float:
124280
if self.__ping_future:
@@ -149,6 +305,16 @@ def on_connected(self, connection_details: ConnectionDetails, connection_id: str
149305
reason: AblyException | None = None) -> None:
150306
self.__fail_state = ConnectionState.DISCONNECTED
151307

308+
# RTN19a2: Reset msgSerial if connectionId changed (new connection)
309+
prev_connection_id = self.connection_id
310+
connection_id_changed = prev_connection_id is not None and prev_connection_id != connection_id
311+
312+
if connection_id_changed:
313+
log.info('ConnectionManager.on_connected(): New connectionId; resetting msgSerial')
314+
self.msg_serial = 0
315+
# Note: In JS they call resetSendAttempted() here, but we don't need it
316+
# because we fail all pending messages on disconnect per RTN7e
317+
152318
self.__connection_details = connection_details
153319
self.connection_id = connection_id
154320

@@ -244,7 +410,36 @@ def on_heartbeat(self, id: str | None) -> None:
244410
self.__ping_future.set_result(None)
245411
self.__ping_future = None
246412

413+
def on_ack(self, serial: int, count: int) -> None:
414+
"""Handle ACK protocol message from server
415+
416+
Args:
417+
serial: The msgSerial of the first message being acknowledged
418+
count: The number of messages being acknowledged
419+
"""
420+
log.debug(f'ConnectionManager.on_ack(): serial={serial}, count={count}')
421+
self.pending_message_queue.complete_messages(serial, count)
422+
423+
def on_nack(self, serial: int, count: int, err: AblyException | None) -> None:
424+
"""Handle NACK protocol message from server
425+
426+
Args:
427+
serial: The msgSerial of the first message being rejected
428+
count: The number of messages being rejected
429+
err: Error information from the server
430+
"""
431+
if not err:
432+
err = AblyException('Unable to send message; channel not responding', 50001, 500)
433+
434+
log.error(f'ConnectionManager.on_nack(): serial={serial}, count={count}, err={err}')
435+
self.pending_message_queue.complete_messages(serial, count, err)
436+
247437
def deactivate_transport(self, reason: AblyException | None = None):
438+
# RTN19a: Before disconnecting, requeue any pending messages
439+
# so they'll be resent on reconnection
440+
if self.transport:
441+
log.info('ConnectionManager.deactivate_transport(): requeuing pending messages')
442+
self.requeue_pending_messages()
248443
self.transport = None
249444
self.notify_state(ConnectionState.DISCONNECTED, reason)
250445

@@ -383,8 +578,16 @@ def notify_state(self, state: ConnectionState, reason: AblyException | None = No
383578
ConnectionState.SUSPENDED,
384579
ConnectionState.FAILED,
385580
):
581+
# RTN7e: Fail pending messages on SUSPENDED, CLOSED, FAILED
386582
self.fail_queued_messages(reason)
387583
self.ably.channels._propagate_connection_interruption(state, reason)
584+
elif state == ConnectionState.DISCONNECTED and not self.options.queue_messages:
585+
# RTN7d: If queueMessages is false, fail pending messages on DISCONNECTED
586+
log.info(
587+
'ConnectionManager.notify_state(): queueMessages is false; '
588+
'failing pending messages on DISCONNECTED'
589+
)
590+
self.fail_queued_messages(reason)
388591

389592
def start_transition_timer(self, state: ConnectionState, fail_state: ConnectionState | None = None) -> None:
390593
log.debug(f'ConnectionManager.start_transition_timer(): transition state = {state}')
@@ -466,6 +669,8 @@ def cancel_retry_timer(self) -> None:
466669
def disconnect_transport(self) -> None:
467670
log.info('ConnectionManager.disconnect_transport()')
468671
if self.transport:
672+
# RTN19a: Requeue pending messages before disposing transport
673+
self.requeue_pending_messages()
469674
self.disconnect_transport_task = asyncio.create_task(self.transport.dispose())
470675

471676
async def on_auth_updated(self, token_details: TokenDetails):

0 commit comments

Comments
 (0)