Skip to content

Commit 18ffca9

Browse files
committed
fix: allow multiple subscriptions to the same URI in a single session
1 parent 5c51923 commit 18ffca9

2 files changed

Lines changed: 47 additions & 15 deletions

File tree

xconn/async_session.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,25 @@ class SubscribeRequest:
4545

4646

4747
class Subscription:
48-
def __init__(self, subscription_id: int, session: AsyncSession):
48+
def __init__(
49+
self, subscription_id: int, session: AsyncSession, event_handler: Callable[[types.Event], Awaitable[None]]
50+
):
4951
self.subscription_id = subscription_id
5052
self._session = session
53+
self._event_handler = event_handler
5154

5255
async def unsubscribe(self) -> None:
5356
if not await self._session._base_session.transport.is_connected():
5457
raise Exception("cannot unsubscribe topic: session not established")
5558

56-
unsubscribe = messages.Unsubscribe(
57-
messages.UnsubscribeFields(self._session._idgen.next(), self.subscription_id)
58-
)
59+
subscriptions = self._session.subscriptions.get(self.subscription_id, None)
60+
if subscriptions is not None:
61+
subscriptions.pop(self, None)
62+
if len(subscriptions) != 0:
63+
self._session._subscriptions[self.subscription_id] = subscriptions
64+
return None
65+
66+
unsubscribe = messages.Unsubscribe(messages.UnsubscribeFields(self._session.idgen.next(), self.subscription_id))
5967
data = self._session._session.send_message(unsubscribe)
6068

6169
f: Future = Future()
@@ -79,7 +87,7 @@ def __init__(self, base_session: types.IAsyncBaseSession):
7987
# PubSub data structures
8088
self._publish_requests: dict[int, Future[None]] = {}
8189
self._subscribe_requests: dict[int, SubscribeRequest] = {}
82-
self._subscriptions: dict[int, Callable[[types.Event], Awaitable[None]]] = {}
90+
self._subscriptions: dict[int, dict[Subscription, Subscription]] = {}
8391
self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {}
8492

8593
self._goodbye_request = Future()
@@ -155,8 +163,14 @@ async def _process_incoming_message(self, msg: messages.Message):
155163
await self._base_session.send(data)
156164
elif isinstance(msg, messages.Subscribed):
157165
request = self._subscribe_requests.pop(msg.request_id)
158-
self._subscriptions[msg.subscription_id] = request.endpoint
159-
request.future.set_result(Subscription(msg.subscription_id, self))
166+
sub = Subscription(msg.subscription_id, self, request.endpoint)
167+
subscriptions = self._subscriptions.get(msg.subscription_id, None)
168+
if subscriptions is None:
169+
self._subscriptions[msg.subscription_id] = {sub: sub}
170+
else:
171+
subscriptions[sub] = sub
172+
173+
request.future.set_result(sub)
160174
elif isinstance(msg, messages.Unsubscribed):
161175
request = self._unsubscribe_requests.pop(msg.request_id)
162176
del self._subscriptions[request.subscription_id]
@@ -165,9 +179,11 @@ async def _process_incoming_message(self, msg: messages.Message):
165179
request = self._publish_requests.pop(msg.request_id)
166180
request.set_result(None)
167181
elif isinstance(msg, messages.Event):
168-
endpoint = self._subscriptions[msg.subscription_id]
169182
try:
170-
await endpoint(types.Event(msg.args, msg.kwargs, msg.details))
183+
subscriptions = self._subscriptions[msg.subscription_id]
184+
event = types.Event(msg.args, msg.kwargs, msg.details)
185+
for subscription in subscriptions.keys():
186+
await subscription._event_handler(event)
171187
except Exception as e:
172188
print(e)
173189
elif isinstance(msg, messages.Error):

xconn/session.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,22 @@ class SubscribeRequest:
4444

4545

4646
class Subscription:
47-
def __init__(self, subscription_id: int, session: Session):
47+
def __init__(self, subscription_id: int, session: Session, event_handler: Callable[[types.Event], None]):
4848
self.subscription_id = subscription_id
4949
self._session = session
50+
self._event_handler = event_handler
5051

5152
def unsubscribe(self) -> None:
5253
if not self._session._base_session.transport.is_connected():
5354
raise Exception("cannot unsubscribe topic: session not established")
5455

56+
subscriptions = self._session._subscriptions.get(self.subscription_id, None)
57+
if subscriptions is not None:
58+
subscriptions.pop(self, None)
59+
if len(subscriptions) != 0:
60+
self._session._subscriptions[self.subscription_id] = subscriptions
61+
return None
62+
5563
unsubscribe = messages.Unsubscribe(
5664
messages.UnsubscribeFields(self._session._idgen.next(), self.subscription_id)
5765
)
@@ -75,7 +83,7 @@ def __init__(self, base_session: types.BaseSession):
7583
# PubSub data structures
7684
self._publish_requests: dict[int, Future[None]] = {}
7785
self._subscribe_requests: dict[int, SubscribeRequest] = {}
78-
self._subscriptions: dict[int, Callable[[types.Event], None]] = {}
86+
self._subscriptions: dict[int, dict[Subscription, Subscription]] = {}
7987
self._unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {}
8088

8189
self._goodbye_request = Future()
@@ -150,8 +158,14 @@ def _process_incoming_message(self, msg: messages.Message):
150158
self._base_session.send(data)
151159
elif isinstance(msg, messages.Subscribed):
152160
request = self._subscribe_requests.pop(msg.request_id)
153-
self._subscriptions[msg.subscription_id] = request.endpoint
154-
request.future.set_result(Subscription(msg.subscription_id, self))
161+
sub = Subscription(msg.subscription_id, self, request.endpoint)
162+
subscriptions = self._subscriptions.get(msg.subscription_id, None)
163+
if subscriptions is None:
164+
self._subscriptions[msg.subscription_id] = {sub: sub}
165+
else:
166+
subscriptions[sub] = sub
167+
168+
request.future.set_result(sub)
155169
elif isinstance(msg, messages.Unsubscribed):
156170
request = self._unsubscribe_requests.pop(msg.request_id)
157171
del self._subscriptions[request.subscription_id]
@@ -161,8 +175,10 @@ def _process_incoming_message(self, msg: messages.Message):
161175
request.set_result(None)
162176
elif isinstance(msg, messages.Event):
163177
try:
164-
endpoint = self._subscriptions[msg.subscription_id]
165-
endpoint(types.Event(msg.args, msg.kwargs, msg.details))
178+
subscriptions = self._subscriptions[msg.subscription_id]
179+
event = types.Event(msg.args, msg.kwargs, msg.details)
180+
for subscription in subscriptions.keys():
181+
subscription._event_handler(event)
166182
except Exception as e:
167183
print(e)
168184
elif isinstance(msg, messages.Error):

0 commit comments

Comments
 (0)