22
33import asyncio
44import logging
5+ from collections import deque
56from datetime import datetime
6- from queue import Queue
77from typing import TYPE_CHECKING
88
99import httpx
2424log = logging .getLogger (__name__ )
2525
2626
27+ class PendingMessage :
28+ """Represents a message awaiting acknowledgment from the server"""
29+
30+ def __init__ (self , message : dict ):
31+ self .message = message
32+ self .future : asyncio .Future | None = None
33+ action = message .get ('action' )
34+
35+ # Messages that require acknowledgment: MESSAGE, PRESENCE, ANNOTATION, OBJECT
36+ self .ack_required = action in (
37+ ProtocolMessageAction .MESSAGE ,
38+ ProtocolMessageAction .PRESENCE ,
39+ ProtocolMessageAction .ANNOTATION ,
40+ ProtocolMessageAction .OBJECT ,
41+ )
42+
43+ if self .ack_required :
44+ self .future = asyncio .Future ()
45+
46+
47+ class PendingMessageQueue :
48+ """Queue for tracking messages awaiting acknowledgment"""
49+
50+ def __init__ (self ):
51+ self .messages : list [PendingMessage ] = []
52+
53+ def push (self , pending_message : PendingMessage ) -> None :
54+ """Add a message to the queue"""
55+ self .messages .append (pending_message )
56+
57+ def count (self ) -> int :
58+ """Return the number of pending messages"""
59+ return len (self .messages )
60+
61+ def complete_messages (self , serial : int , count : int , err : AblyException | None = None ) -> None :
62+ """Complete messages based on serial and count from ACK/NACK
63+
64+ Args:
65+ serial: The msgSerial of the first message being acknowledged
66+ count: The number of messages being acknowledged
67+ err: Error from NACK, or None for successful ACK
68+ """
69+ log .debug (f'MessageQueue.complete_messages(): serial={ serial } , count={ count } , err={ err } ' )
70+
71+ if not self .messages :
72+ log .warning ('MessageQueue.complete_messages(): called on empty queue' )
73+ return
74+
75+ first = self .messages [0 ]
76+ if first :
77+ start_serial = first .message .get ('msgSerial' )
78+ if start_serial is None :
79+ log .warning ('MessageQueue.complete_messages(): first message has no msgSerial' )
80+ return
81+
82+ end_serial = serial + count
83+
84+ if end_serial > start_serial :
85+ # Remove and complete the acknowledged messages
86+ num_to_complete = min (end_serial - start_serial , len (self .messages ))
87+ completed_messages = self .messages [:num_to_complete ]
88+ self .messages = self .messages [num_to_complete :]
89+
90+ for msg in completed_messages :
91+ if msg .future and not msg .future .done ():
92+ if err :
93+ msg .future .set_exception (err )
94+ else :
95+ msg .future .set_result (None )
96+
97+ def complete_all_messages (self , err : AblyException ) -> None :
98+ """Complete all pending messages with an error"""
99+ while self .messages :
100+ msg = self .messages .pop (0 )
101+ if msg .future and not msg .future .done ():
102+ msg .future .set_exception (err )
103+
104+ def clear (self ) -> None :
105+ """Clear all messages from the queue"""
106+ self .messages .clear ()
107+
108+
27109class ConnectionManager (EventEmitter ):
28110 def __init__ (self , realtime : AblyRealtime , initial_state ):
29111 self .options = realtime .options
@@ -41,8 +123,10 @@ def __init__(self, realtime: AblyRealtime, initial_state):
41123 self .connect_base_task : asyncio .Task | None = None
42124 self .disconnect_transport_task : asyncio .Task | None = None
43125 self .__fallback_hosts : list [str ] = self .options .get_fallback_realtime_hosts ()
44- self .queued_messages : Queue = Queue ()
126+ self .queued_messages : deque [ PendingMessage ] = deque ()
45127 self .__error_reason : AblyException | None = None
128+ self .msg_serial : int = 0
129+ self .pending_message_queue : PendingMessageQueue = PendingMessageQueue ()
46130 super ().__init__ ()
47131
48132 def enact_state_change (self , state : ConnectionState , reason : AblyException | None = None ) -> None :
@@ -88,37 +172,109 @@ async def close_impl(self) -> None:
88172 self .notify_state (ConnectionState .CLOSED )
89173
90174 async def send_protocol_message (self , protocol_message : dict ) -> None :
91- if self .state in (
92- ConnectionState .DISCONNECTED ,
93- ConnectionState .CONNECTING ,
94- ):
95- self .queued_messages .put (protocol_message )
96- return
97-
98- if self .state == ConnectionState .CONNECTED :
99- if self .transport :
100- await self .transport .send (protocol_message )
101- else :
102- log .exception (
103- "ConnectionManager.send_protocol_message(): can not send message with no active transport"
175+ """Send a protocol message and optionally track it for acknowledgment
176+
177+ Args:
178+ protocol_message: protocol message dict (new message)
179+ Returns:
180+ None
181+ """
182+ if self .state not in (ConnectionState .DISCONNECTED , ConnectionState .CONNECTING , ConnectionState .CONNECTED ):
183+ raise AblyException (f"ConnectionManager.send_protocol_message(): called in { self .state } " , 500 , 50000 )
184+
185+ pending_message = PendingMessage (protocol_message )
186+
187+ # Assign msgSerial to messages that need acknowledgment
188+ if pending_message .ack_required :
189+ # New message - assign fresh serial
190+ protocol_message ['msgSerial' ] = self .msg_serial
191+ self .pending_message_queue .push (pending_message )
192+ self .msg_serial += 1
193+
194+ if self .state in (ConnectionState .DISCONNECTED , ConnectionState .CONNECTING ):
195+ self .queued_messages .appendleft (pending_message )
196+ if pending_message .ack_required :
197+ await pending_message .future
198+ return None
199+
200+ return await self ._send_protocol_message_on_connected_state (pending_message )
201+
202+ async def _send_protocol_message_on_connected_state (self , pending_message : PendingMessage ) -> None :
203+ if self .state == ConnectionState .CONNECTED and self .transport :
204+ # Add to pending queue before sending (for messages being resent from queue)
205+ if pending_message .ack_required and pending_message not in self .pending_message_queue .messages :
206+ self .pending_message_queue .push (pending_message )
207+ await self .transport .send (pending_message .message )
208+ else :
209+ log .exception (
210+ "ConnectionManager.send_protocol_message(): can not send message with no active transport"
211+ )
212+ if pending_message .future :
213+ pending_message .future .set_exception (
214+ AblyException ("No active transport" , 500 , 50000 )
104215 )
216+ if pending_message .ack_required :
217+ await pending_message .future
218+ return None
219+
220+ def send_queued_messages (self ) -> None :
221+ log .info (f'ConnectionManager.send_queued_messages(): sending { len (self .queued_messages )} message(s)' )
222+ while len (self .queued_messages ) > 0 :
223+ pending_message = self .queued_messages .pop ()
224+ asyncio .create_task (self ._send_protocol_message_on_connected_state (pending_message ))
225+
226+ def requeue_pending_messages (self ) -> None :
227+ """RTN19a: Requeue messages awaiting ACK/NACK when transport disconnects
228+
229+ These messages will be resent when connection becomes CONNECTED again.
230+ RTN19a2: msgSerial is preserved for resume, reset for new connection.
231+ """
232+ pending_count = self .pending_message_queue .count ()
233+ if pending_count == 0 :
105234 return
106235
107- raise AblyException (f"ConnectionManager.send_protocol_message(): called in { self .state } " , 500 , 50000 )
236+ log .info (
237+ f'ConnectionManager.requeue_pending_messages(): '
238+ f'requeuing { pending_count } pending message(s) for resend'
239+ )
108240
109- def send_queued_messages (self ) -> None :
110- log .info (f'ConnectionManager.send_queued_messages(): sending { self .queued_messages .qsize ()} message(s)' )
111- while not self .queued_messages .empty ():
112- asyncio .create_task (self .send_protocol_message (self .queued_messages .get ()))
241+ # Get all pending messages and add them back to the queue
242+ # They'll be sent again when we reconnect
243+ pending_messages = list (self .pending_message_queue .messages )
244+
245+ # Add back to front of queue (FIFO but priority over new messages)
246+ # Store the entire PendingMessage object to preserve Future
247+ for pending_msg in reversed (pending_messages ):
248+ # PendingMessage object retains its Future, msgSerial
249+ self .queued_messages .append (pending_msg )
250+
251+ # Clear the message queue since we're requeueing them all
252+ # When they're resent, the existing Future will be resolved
253+ self .pending_message_queue .clear ()
113254
114255 def fail_queued_messages (self , err ) -> None :
115256 log .info (
116- f"ConnectionManager.fail_queued_messages(): discarding { self .queued_messages . qsize ( )} messages;" +
257+ f"ConnectionManager.fail_queued_messages(): discarding { len ( self .queued_messages )} messages;" +
117258 f" reason = { err } "
118259 )
119- while not self .queued_messages .empty ():
120- msg = self .queued_messages .get ()
121- log .exception (f"ConnectionManager.fail_queued_messages(): Failed to send protocol message: { msg } " )
260+ error = err or AblyException ("Connection failed" , 80000 , 500 )
261+ while len (self .queued_messages ) > 0 :
262+ pending_msg = self .queued_messages .pop ()
263+ log .exception (
264+ f"ConnectionManager.fail_queued_messages(): Failed to send protocol message: "
265+ f"{ pending_msg .message } "
266+ )
267+ # Fail the Future if it exists
268+ if pending_msg .future and not pending_msg .future .done ():
269+ pending_msg .future .set_exception (error )
270+
271+ # Also fail all pending messages awaiting acknowledgment
272+ if self .pending_message_queue .count () > 0 :
273+ count = self .pending_message_queue .count ()
274+ log .info (
275+ f"ConnectionManager.fail_queued_messages(): failing { count } pending messages"
276+ )
277+ self .pending_message_queue .complete_all_messages (error )
122278
123279 async def ping (self ) -> float :
124280 if self .__ping_future :
@@ -149,6 +305,16 @@ def on_connected(self, connection_details: ConnectionDetails, connection_id: str
149305 reason : AblyException | None = None ) -> None :
150306 self .__fail_state = ConnectionState .DISCONNECTED
151307
308+ # RTN19a2: Reset msgSerial if connectionId changed (new connection)
309+ prev_connection_id = self .connection_id
310+ connection_id_changed = prev_connection_id is not None and prev_connection_id != connection_id
311+
312+ if connection_id_changed :
313+ log .info ('ConnectionManager.on_connected(): New connectionId; resetting msgSerial' )
314+ self .msg_serial = 0
315+ # Note: In JS they call resetSendAttempted() here, but we don't need it
316+ # because we fail all pending messages on disconnect per RTN7e
317+
152318 self .__connection_details = connection_details
153319 self .connection_id = connection_id
154320
@@ -244,7 +410,36 @@ def on_heartbeat(self, id: str | None) -> None:
244410 self .__ping_future .set_result (None )
245411 self .__ping_future = None
246412
413+ def on_ack (self , serial : int , count : int ) -> None :
414+ """Handle ACK protocol message from server
415+
416+ Args:
417+ serial: The msgSerial of the first message being acknowledged
418+ count: The number of messages being acknowledged
419+ """
420+ log .debug (f'ConnectionManager.on_ack(): serial={ serial } , count={ count } ' )
421+ self .pending_message_queue .complete_messages (serial , count )
422+
423+ def on_nack (self , serial : int , count : int , err : AblyException | None ) -> None :
424+ """Handle NACK protocol message from server
425+
426+ Args:
427+ serial: The msgSerial of the first message being rejected
428+ count: The number of messages being rejected
429+ err: Error information from the server
430+ """
431+ if not err :
432+ err = AblyException ('Unable to send message; channel not responding' , 50001 , 500 )
433+
434+ log .error (f'ConnectionManager.on_nack(): serial={ serial } , count={ count } , err={ err } ' )
435+ self .pending_message_queue .complete_messages (serial , count , err )
436+
247437 def deactivate_transport (self , reason : AblyException | None = None ):
438+ # RTN19a: Before disconnecting, requeue any pending messages
439+ # so they'll be resent on reconnection
440+ if self .transport :
441+ log .info ('ConnectionManager.deactivate_transport(): requeuing pending messages' )
442+ self .requeue_pending_messages ()
248443 self .transport = None
249444 self .notify_state (ConnectionState .DISCONNECTED , reason )
250445
@@ -383,8 +578,16 @@ def notify_state(self, state: ConnectionState, reason: AblyException | None = No
383578 ConnectionState .SUSPENDED ,
384579 ConnectionState .FAILED ,
385580 ):
581+ # RTN7e: Fail pending messages on SUSPENDED, CLOSED, FAILED
386582 self .fail_queued_messages (reason )
387583 self .ably .channels ._propagate_connection_interruption (state , reason )
584+ elif state == ConnectionState .DISCONNECTED and not self .options .queue_messages :
585+ # RTN7d: If queueMessages is false, fail pending messages on DISCONNECTED
586+ log .info (
587+ 'ConnectionManager.notify_state(): queueMessages is false; '
588+ 'failing pending messages on DISCONNECTED'
589+ )
590+ self .fail_queued_messages (reason )
388591
389592 def start_transition_timer (self , state : ConnectionState , fail_state : ConnectionState | None = None ) -> None :
390593 log .debug (f'ConnectionManager.start_transition_timer(): transition state = { state } ' )
@@ -466,6 +669,8 @@ def cancel_retry_timer(self) -> None:
466669 def disconnect_transport (self ) -> None :
467670 log .info ('ConnectionManager.disconnect_transport()' )
468671 if self .transport :
672+ # RTN19a: Requeue pending messages before disposing transport
673+ self .requeue_pending_messages ()
469674 self .disconnect_transport_task = asyncio .create_task (self .transport .dispose ())
470675
471676 async def on_auth_updated (self , token_details : TokenDetails ):
0 commit comments