@@ -45,17 +45,25 @@ class SubscribeRequest:
4545
4646
4747class Subscription :
48- def __init__ (self , subscription_id : int , session : AsyncSession ):
48+ def __init__ (
49+ self , subscription_id : int , session : AsyncSession , event_handler : Callable [[types .Event ], Awaitable [None ]]
50+ ):
4951 self .subscription_id = subscription_id
5052 self ._session = session
53+ self ._event_handler = event_handler
5154
5255 async def unsubscribe (self ) -> None :
5356 if not await self ._session ._base_session .transport .is_connected ():
5457 raise Exception ("cannot unsubscribe topic: session not established" )
5558
56- unsubscribe = messages .Unsubscribe (
57- messages .UnsubscribeFields (self ._session ._idgen .next (), self .subscription_id )
58- )
59+ subscriptions = self ._session .subscriptions .get (self .subscription_id , None )
60+ if subscriptions is not None :
61+ subscriptions .pop (self , None )
62+ if len (subscriptions ) != 0 :
63+ self ._session ._subscriptions [self .subscription_id ] = subscriptions
64+ return None
65+
66+ unsubscribe = messages .Unsubscribe (messages .UnsubscribeFields (self ._session .idgen .next (), self .subscription_id ))
5967 data = self ._session ._session .send_message (unsubscribe )
6068
6169 f : Future = Future ()
@@ -79,7 +87,7 @@ def __init__(self, base_session: types.IAsyncBaseSession):
7987 # PubSub data structures
8088 self ._publish_requests : dict [int , Future [None ]] = {}
8189 self ._subscribe_requests : dict [int , SubscribeRequest ] = {}
82- self ._subscriptions : dict [int , Callable [[ types . Event ], Awaitable [ None ] ]] = {}
90+ self ._subscriptions : dict [int , dict [ Subscription , Subscription ]] = {}
8391 self ._unsubscribe_requests : dict [int , types .UnsubscribeRequest ] = {}
8492
8593 self ._goodbye_request = Future ()
@@ -155,8 +163,14 @@ async def _process_incoming_message(self, msg: messages.Message):
155163 await self ._base_session .send (data )
156164 elif isinstance (msg , messages .Subscribed ):
157165 request = self ._subscribe_requests .pop (msg .request_id )
158- self ._subscriptions [msg .subscription_id ] = request .endpoint
159- request .future .set_result (Subscription (msg .subscription_id , self ))
166+ sub = Subscription (msg .subscription_id , self , request .endpoint )
167+ subscriptions = self ._subscriptions .get (msg .subscription_id , None )
168+ if subscriptions is None :
169+ self ._subscriptions [msg .subscription_id ] = {sub : sub }
170+ else :
171+ subscriptions [sub ] = sub
172+
173+ request .future .set_result (sub )
160174 elif isinstance (msg , messages .Unsubscribed ):
161175 request = self ._unsubscribe_requests .pop (msg .request_id )
162176 del self ._subscriptions [request .subscription_id ]
@@ -165,9 +179,11 @@ async def _process_incoming_message(self, msg: messages.Message):
165179 request = self ._publish_requests .pop (msg .request_id )
166180 request .set_result (None )
167181 elif isinstance (msg , messages .Event ):
168- endpoint = self ._subscriptions [msg .subscription_id ]
169182 try :
170- await endpoint (types .Event (msg .args , msg .kwargs , msg .details ))
183+ subscriptions = self ._subscriptions [msg .subscription_id ]
184+ event = types .Event (msg .args , msg .kwargs , msg .details )
185+ for subscription in subscriptions .keys ():
186+ await subscription ._event_handler (event )
171187 except Exception as e :
172188 print (e )
173189 elif isinstance (msg , messages .Error ):
0 commit comments