Skip to content

Commit 7c1c4eb

Browse files
committed
fix: reauth
1 parent 3d9a30e commit 7c1c4eb

1 file changed

Lines changed: 77 additions & 22 deletions

File tree

python_snoo/snoo.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,42 @@ def __init__(self, email: str, password: str, clientsession: aiohttp.ClientSessi
9292
self.reauth_task: asyncio.Task | None = None
9393
self._client_map: dict[str, aiomqtt.Client] = {}
9494
self._mqtt_tasks: dict[str, asyncio.Task] = {}
95+
self._mqtt_callbacks: dict[str, tuple[SnooDevice, Callable]] = {}
9596
self._client_cond = asyncio.Condition()
9697

97-
async def refresh_tokens(self):
98-
# TODO: Figure out hwo to get this to work and not do a serializaiton exception
98+
async def refresh_tokens(self) -> int:
99+
"""Refreshes AWS Cognito tokens and returns the new expiration time in seconds."""
100+
_LOGGER.info("AWS Cognito tokens expired, refreshing...")
101+
if not self.tokens or not self.tokens.aws_refresh:
102+
_LOGGER.error("No refresh token available. A full re-authentication is required.")
103+
raise SnooAuthException("Missing refresh token.")
104+
99105
data = {
100106
"AuthParameters": {"REFRESH_TOKEN": self.tokens.aws_refresh},
101107
"AuthFlow": "REFRESH_TOKEN_AUTH",
102108
"ClientId": "6kqofhc8hm394ielqdkvli0oea",
103109
}
104-
r = await self.session.post(self.aws_auth_url, data=data, headers=self.aws_auth_hdr)
110+
r = await self.session.post(self.aws_auth_url, json=data, headers=self.aws_auth_hdr)
105111
resp = await r.json(content_type=None)
106-
if "__type" in resp and resp["__type"] == "NotAuthorizedException":
107-
raise InvalidSnooAuth()
108-
result = resp["AuthenticationResult"]
109-
self.tokens = AuthorizationInfo(aws_id=result[""])
112+
113+
if r.status >= 400:
114+
_LOGGER.error(f"Failed to refresh tokens. Status: {r.status}, Response: {resp}")
115+
raise InvalidSnooAuth(f"Token refresh failed: {resp.get('message', 'Unknown error')}")
116+
117+
result = resp.get("AuthenticationResult")
118+
if not result:
119+
_LOGGER.error(f"Invalid response during token refresh: {resp}")
120+
raise SnooAuthException("Token refresh response missing 'AuthenticationResult'.")
121+
122+
# Update tokens with the new ones from the response
123+
self.tokens = AuthorizationInfo(
124+
snoo=self.tokens.snoo,
125+
aws_access=result["AccessToken"],
126+
aws_id=result["IdToken"],
127+
aws_refresh=result.get("RefreshToken", self.tokens.aws_refresh),
128+
)
129+
_LOGGER.info("✅ Successfully refreshed AWS Cognito tokens.")
130+
return result.get("ExpiresIn", 3600)
110131

111132
def check_tokens(self):
112133
if self.tokens is None:
@@ -165,6 +186,7 @@ async def disconnect(self):
165186
task.cancel()
166187
await asyncio.gather(*self._mqtt_tasks.values(), return_exceptions=True)
167188
self._mqtt_tasks = {}
189+
self._mqtt_callbacks = {}
168190

169191
if self.reauth_task:
170192
self.reauth_task.cancel()
@@ -181,8 +203,6 @@ async def send_command(self, command: str, device: SnooDevice, **kwargs):
181203
async with self._client_cond:
182204
try:
183205
# Wait up to 30 seconds for the client to connect.
184-
# The wait_for() method will wait until the lambda returns True.
185-
# It releases the lock while waiting and re-acquires it before returning.
186206
await asyncio.wait_for(
187207
self._client_cond.wait_for(lambda: device.serialNumber in self._client_map), timeout=30.0
188208
)
@@ -241,28 +261,60 @@ async def authorize(self) -> AuthorizationInfo:
241261
access = amz["AccessToken"]
242262
_id = amz["IdToken"]
243263
ref = amz["RefreshToken"]
244-
snoo_token = await self.auth_snoo(_id)
245-
snoo_expiry = snoo_token["expiresIn"] / 1.5
246-
snoo_token = snoo_token["snoo"]["token"]
264+
expires_in = amz["ExpiresIn"]
265+
266+
snoo_token_data = await self.auth_snoo(_id)
267+
snoo_token = snoo_token_data["snoo"]["token"]
268+
247269
self.tokens = AuthorizationInfo(snoo=snoo_token, aws_access=access, aws_id=_id, aws_refresh=ref)
248-
self.reauth_task = asyncio.create_task(self.schedule_reauthorization(snoo_expiry))
270+
271+
if self.reauth_task:
272+
self.reauth_task.cancel()
273+
274+
# Schedule reauthorization with a 5-minute buffer before expiry
275+
reauth_delay = max(expires_in - 300, 0)
276+
self.reauth_task = asyncio.create_task(self.schedule_reauthorization(reauth_delay))
277+
_LOGGER.info(f"Authorization successful. Next token refresh scheduled in {reauth_delay} seconds.")
278+
249279
except InvalidSnooAuth as ex:
250280
raise ex
251281
except Exception as ex:
252282
raise SnooAuthException from ex
253283
return self.tokens
254284

255-
async def schedule_reauthorization(self, snoo_expiry: float):
256-
_LOGGER.info("Snoo token has expired - reauthorizing...")
285+
async def schedule_reauthorization(self, expiry_seconds: float):
257286
try:
258-
await asyncio.sleep(snoo_expiry)
259-
await self.authorize()
260-
for instance in self.pubnub_instances.values():
261-
instance.update_token(self.tokens.snoo)
262-
self.pubnub.config.auth_token = self.tokens.snoo
287+
await asyncio.sleep(expiry_seconds)
288+
_LOGGER.info("Executing scheduled token refresh...")
263289

264-
except Exception as ex:
265-
_LOGGER.exception(f"Error during reauthorization: {ex}")
290+
new_expires_in = await self.refresh_tokens()
291+
292+
_LOGGER.info("Restarting MQTT subscriptions with new token...")
293+
# Cancel all existing MQTT tasks
294+
for task in self._mqtt_tasks.values():
295+
task.cancel()
296+
await asyncio.gather(*self._mqtt_tasks.values(), return_exceptions=True)
297+
self._mqtt_tasks.clear()
298+
299+
# The `finally` block in `subscribe_mqtt` should clear the `_client_map`
300+
# as connections close.
301+
302+
# Re-subscribe for all previously active subscriptions
303+
for device_sn, (device, function) in self._mqtt_callbacks.items():
304+
_LOGGER.info(f"Re-establishing MQTT subscription for device {device_sn}")
305+
self.start_subscribe(device, function)
306+
307+
_LOGGER.info("✅ MQTT subscriptions restarted successfully.")
308+
309+
# Schedule the *next* reauthorization
310+
reauth_delay = max(new_expires_in - 300, 0)
311+
self.reauth_task = asyncio.create_task(self.schedule_reauthorization(reauth_delay))
312+
_LOGGER.info(f"Next token refresh scheduled in {reauth_delay} seconds.")
313+
314+
except asyncio.CancelledError:
315+
_LOGGER.info("Reauthorization task was cancelled.")
316+
except Exception:
317+
_LOGGER.exception("An unexpected error occurred during reauthorization.")
266318

267319
async def get_devices(self) -> list[SnooDevice]:
268320
hdrs = self.generate_snoo_auth_headers(self.tokens.aws_id)
@@ -279,6 +331,9 @@ def start_subscribe(self, device: SnooDevice, function: Callable):
279331
_LOGGER.warning(f"Subscription task for device {device.serialNumber} is already running.")
280332
return
281333

334+
# Store the device and callback function for re-subscription after re-auth
335+
self._mqtt_callbacks[device.serialNumber] = (device, function)
336+
282337
self._mqtt_tasks[device.serialNumber] = asyncio.create_task(self.subscribe_mqtt(device, function))
283338

284339
async def subscribe_mqtt(self, device: SnooDevice, function: Callable):

0 commit comments

Comments
 (0)