@@ -92,21 +92,42 @@ def __init__(self, email: str, password: str, clientsession: aiohttp.ClientSessi
9292 self .reauth_task : asyncio .Task | None = None
9393 self ._client_map : dict [str , aiomqtt .Client ] = {}
9494 self ._mqtt_tasks : dict [str , asyncio .Task ] = {}
95+ self ._mqtt_callbacks : dict [str , tuple [SnooDevice , Callable ]] = {}
9596 self ._client_cond = asyncio .Condition ()
9697
97- async def refresh_tokens (self ):
98- # TODO: Figure out hwo to get this to work and not do a serializaiton exception
98+ async def refresh_tokens (self ) -> int :
99+ """Refreshes AWS Cognito tokens and returns the new expiration time in seconds."""
100+ _LOGGER .info ("AWS Cognito tokens expired, refreshing..." )
101+ if not self .tokens or not self .tokens .aws_refresh :
102+ _LOGGER .error ("No refresh token available. A full re-authentication is required." )
103+ raise SnooAuthException ("Missing refresh token." )
104+
99105 data = {
100106 "AuthParameters" : {"REFRESH_TOKEN" : self .tokens .aws_refresh },
101107 "AuthFlow" : "REFRESH_TOKEN_AUTH" ,
102108 "ClientId" : "6kqofhc8hm394ielqdkvli0oea" ,
103109 }
104- r = await self .session .post (self .aws_auth_url , data = data , headers = self .aws_auth_hdr )
110+ r = await self .session .post (self .aws_auth_url , json = data , headers = self .aws_auth_hdr )
105111 resp = await r .json (content_type = None )
106- if "__type" in resp and resp ["__type" ] == "NotAuthorizedException" :
107- raise InvalidSnooAuth ()
108- result = resp ["AuthenticationResult" ]
109- self .tokens = AuthorizationInfo (aws_id = result ["" ])
112+
113+ if r .status >= 400 :
114+ _LOGGER .error (f"Failed to refresh tokens. Status: { r .status } , Response: { resp } " )
115+ raise InvalidSnooAuth (f"Token refresh failed: { resp .get ('message' , 'Unknown error' )} " )
116+
117+ result = resp .get ("AuthenticationResult" )
118+ if not result :
119+ _LOGGER .error (f"Invalid response during token refresh: { resp } " )
120+ raise SnooAuthException ("Token refresh response missing 'AuthenticationResult'." )
121+
122+ # Update tokens with the new ones from the response
123+ self .tokens = AuthorizationInfo (
124+ snoo = self .tokens .snoo ,
125+ aws_access = result ["AccessToken" ],
126+ aws_id = result ["IdToken" ],
127+ aws_refresh = result .get ("RefreshToken" , self .tokens .aws_refresh ),
128+ )
129+ _LOGGER .info ("✅ Successfully refreshed AWS Cognito tokens." )
130+ return result .get ("ExpiresIn" , 3600 )
110131
111132 def check_tokens (self ):
112133 if self .tokens is None :
@@ -165,6 +186,7 @@ async def disconnect(self):
165186 task .cancel ()
166187 await asyncio .gather (* self ._mqtt_tasks .values (), return_exceptions = True )
167188 self ._mqtt_tasks = {}
189+ self ._mqtt_callbacks = {}
168190
169191 if self .reauth_task :
170192 self .reauth_task .cancel ()
@@ -181,8 +203,6 @@ async def send_command(self, command: str, device: SnooDevice, **kwargs):
181203 async with self ._client_cond :
182204 try :
183205 # Wait up to 30 seconds for the client to connect.
184- # The wait_for() method will wait until the lambda returns True.
185- # It releases the lock while waiting and re-acquires it before returning.
186206 await asyncio .wait_for (
187207 self ._client_cond .wait_for (lambda : device .serialNumber in self ._client_map ), timeout = 30.0
188208 )
@@ -241,28 +261,60 @@ async def authorize(self) -> AuthorizationInfo:
241261 access = amz ["AccessToken" ]
242262 _id = amz ["IdToken" ]
243263 ref = amz ["RefreshToken" ]
244- snoo_token = await self .auth_snoo (_id )
245- snoo_expiry = snoo_token ["expiresIn" ] / 1.5
246- snoo_token = snoo_token ["snoo" ]["token" ]
264+ expires_in = amz ["ExpiresIn" ]
265+
266+ snoo_token_data = await self .auth_snoo (_id )
267+ snoo_token = snoo_token_data ["snoo" ]["token" ]
268+
247269 self .tokens = AuthorizationInfo (snoo = snoo_token , aws_access = access , aws_id = _id , aws_refresh = ref )
248- self .reauth_task = asyncio .create_task (self .schedule_reauthorization (snoo_expiry ))
270+
271+ if self .reauth_task :
272+ self .reauth_task .cancel ()
273+
274+ # Schedule reauthorization with a 5-minute buffer before expiry
275+ reauth_delay = max (expires_in - 300 , 0 )
276+ self .reauth_task = asyncio .create_task (self .schedule_reauthorization (reauth_delay ))
277+ _LOGGER .info (f"Authorization successful. Next token refresh scheduled in { reauth_delay } seconds." )
278+
249279 except InvalidSnooAuth as ex :
250280 raise ex
251281 except Exception as ex :
252282 raise SnooAuthException from ex
253283 return self .tokens
254284
255- async def schedule_reauthorization (self , snoo_expiry : float ):
256- _LOGGER .info ("Snoo token has expired - reauthorizing..." )
285+ async def schedule_reauthorization (self , expiry_seconds : float ):
257286 try :
258- await asyncio .sleep (snoo_expiry )
259- await self .authorize ()
260- for instance in self .pubnub_instances .values ():
261- instance .update_token (self .tokens .snoo )
262- self .pubnub .config .auth_token = self .tokens .snoo
287+ await asyncio .sleep (expiry_seconds )
288+ _LOGGER .info ("Executing scheduled token refresh..." )
263289
264- except Exception as ex :
265- _LOGGER .exception (f"Error during reauthorization: { ex } " )
290+ new_expires_in = await self .refresh_tokens ()
291+
292+ _LOGGER .info ("Restarting MQTT subscriptions with new token..." )
293+ # Cancel all existing MQTT tasks
294+ for task in self ._mqtt_tasks .values ():
295+ task .cancel ()
296+ await asyncio .gather (* self ._mqtt_tasks .values (), return_exceptions = True )
297+ self ._mqtt_tasks .clear ()
298+
299+ # The `finally` block in `subscribe_mqtt` should clear the `_client_map`
300+ # as connections close.
301+
302+ # Re-subscribe for all previously active subscriptions
303+ for device_sn , (device , function ) in self ._mqtt_callbacks .items ():
304+ _LOGGER .info (f"Re-establishing MQTT subscription for device { device_sn } " )
305+ self .start_subscribe (device , function )
306+
307+ _LOGGER .info ("✅ MQTT subscriptions restarted successfully." )
308+
309+ # Schedule the *next* reauthorization
310+ reauth_delay = max (new_expires_in - 300 , 0 )
311+ self .reauth_task = asyncio .create_task (self .schedule_reauthorization (reauth_delay ))
312+ _LOGGER .info (f"Next token refresh scheduled in { reauth_delay } seconds." )
313+
314+ except asyncio .CancelledError :
315+ _LOGGER .info ("Reauthorization task was cancelled." )
316+ except Exception :
317+ _LOGGER .exception ("An unexpected error occurred during reauthorization." )
266318
267319 async def get_devices (self ) -> list [SnooDevice ]:
268320 hdrs = self .generate_snoo_auth_headers (self .tokens .aws_id )
@@ -279,6 +331,9 @@ def start_subscribe(self, device: SnooDevice, function: Callable):
279331 _LOGGER .warning (f"Subscription task for device { device .serialNumber } is already running." )
280332 return
281333
334+ # Store the device and callback function for re-subscription after re-auth
335+ self ._mqtt_callbacks [device .serialNumber ] = (device , function )
336+
282337 self ._mqtt_tasks [device .serialNumber ] = asyncio .create_task (self .subscribe_mqtt (device , function ))
283338
284339 async def subscribe_mqtt (self , device : SnooDevice , function : Callable ):
0 commit comments