22from base64 import b64encode
33from contextlib import asynccontextmanager
44from dataclasses import dataclass , field
5+ from functools import wraps
56from io import BytesIO
67
78import anyio
8- from aiohttp import ClientSession
9+ from aiohttp import ClientResponseError , ClientSession
910from jumpstarter_driver_composite .driver import Composite
1011from jumpstarter_driver_pyserial .driver import PySerial
1112from nanokvm .client import NanoKVMClient as NanoKVMAPIClient
1213
1314from jumpstarter .driver import Driver , export , exportstream
1415
1516
17+ def _is_unauthorized_error (error : Exception ) -> bool :
18+ """Check if an error is a 401 Unauthorized error"""
19+ if isinstance (error , ClientResponseError ):
20+ return error .status == 401
21+ # Also check for string representation in case error is wrapped
22+ error_str = str (error )
23+ return "401" in error_str and ("Unauthorized" in error_str or "unauthorized" in error_str .lower ())
24+
25+
26+ def with_reauth (func ):
27+ """Decorator to automatically re-authenticate on 401 errors"""
28+ @wraps (func )
29+ async def wrapper (self , * args , ** kwargs ):
30+ try :
31+ return await func (self , * args , ** kwargs )
32+ except Exception as e :
33+ if _is_unauthorized_error (e ):
34+ self .logger .warning ("Received 401 Unauthorized, re-authenticating..." )
35+ await self ._reset_client ()
36+ # Retry once after re-authentication
37+ return await func (self , * args , ** kwargs )
38+ raise
39+ return wrapper
40+
41+
1642@dataclass (kw_only = True )
1743class NanoKVMVideo (Driver ):
1844 """NanoKVM Video Streaming driver"""
@@ -32,6 +58,16 @@ def __post_init__(self):
3258 def client (cls ) -> str :
3359 return "jumpstarter_driver_nanokvm.client.NanoKVMVideoClient"
3460
61+ async def _reset_client (self ):
62+ """Reset the client and session, forcing re-authentication"""
63+ if self ._session is not None and not self ._session .closed :
64+ try :
65+ await self ._session .close ()
66+ except Exception as e :
67+ self .logger .debug (f"Error closing session during reset: { e } " )
68+ self ._client = None
69+ self ._session = None
70+
3571 async def _get_client (self ) -> NanoKVMAPIClient :
3672 """Get or create the NanoKVM API client"""
3773 if self ._client is None :
@@ -55,6 +91,7 @@ def close(self):
5591 self .logger .debug (f"Error closing session: { e } " )
5692
5793 @export
94+ @with_reauth
5895 async def snapshot (self ) -> str :
5996 """
6097 Take a snapshot from the video stream
@@ -96,7 +133,19 @@ async def stream_video():
96133 # TODO(mangelajo): this needs to be tested
97134 await send_stream .send (data )
98135 except Exception as e :
99- self .logger .error (f"Error streaming video: { e } " )
136+ if _is_unauthorized_error (e ):
137+ self .logger .warning ("Received 401 Unauthorized during stream, re-authenticating..." )
138+ await self ._reset_client ()
139+ # Retry with new client
140+ new_client = await self ._get_client ()
141+ async for frame in new_client .mjpeg_stream ():
142+ buffer = BytesIO ()
143+ frame .save (buffer , format = "JPEG" )
144+ data = buffer .getvalue ()
145+ await send_stream .send (data )
146+ else :
147+ self .logger .error (f"Error streaming video: { e } " )
148+ raise
100149
101150 # Start the video streaming task
102151 task = asyncio .create_task (stream_video ())
@@ -131,6 +180,22 @@ def __post_init__(self):
131180 def client (cls ) -> str :
132181 return "jumpstarter_driver_nanokvm.client.NanoKVMHIDClient"
133182
183+ async def _reset_client (self ):
184+ """Reset the client, session, and websocket, forcing re-authentication"""
185+ if self ._ws is not None :
186+ try :
187+ await self ._ws .close ()
188+ except Exception as e :
189+ self .logger .debug (f"Error closing websocket during reset: { e } " )
190+ if self ._session is not None and not self ._session .closed :
191+ try :
192+ await self ._session .close ()
193+ except Exception as e :
194+ self .logger .debug (f"Error closing session during reset: { e } " )
195+ self ._client = None
196+ self ._session = None
197+ self ._ws = None
198+
134199 async def _get_client (self ) -> NanoKVMAPIClient :
135200 """Get or create the NanoKVM API client"""
136201 if self ._client is None :
@@ -151,6 +216,7 @@ async def _get_ws(self):
151216 )
152217 return self ._ws
153218
219+ @with_reauth
154220 async def _send_mouse_event (self , event_type : int , x : int , y : int ):
155221 """
156222 Send a mouse event via WebSocket
@@ -189,6 +255,7 @@ def close(self):
189255 self .logger .debug (f"Error closing session: { e } " )
190256
191257 @export
258+ @with_reauth
192259 async def paste_text (self , text : str ):
193260 """
194261 Paste text via keyboard HID simulation
@@ -201,6 +268,7 @@ async def paste_text(self, text: str):
201268 self .logger .info (f"Pasted text: { text } " )
202269
203270 @export
271+ @with_reauth
204272 async def press_key (self , key : str ):
205273 """
206274 Press a key by pasting a single character
@@ -220,6 +288,7 @@ async def press_key(self, key: str):
220288 self .logger .debug (f"Pressed key: { repr (key )} " )
221289
222290 @export
291+ @with_reauth
223292 async def reset_hid (self ):
224293 """Reset the HID subsystem"""
225294 client = await self ._get_client ()
@@ -372,15 +441,20 @@ async def get_info(self):
372441 """Get device information"""
373442 # Get info from the video driver's client
374443 video_driver = self .children ["video" ]
375- client = await video_driver ._get_client ()
376- info = await client .get_info ()
377- return {
378- "ips" : [{"name" : ip .name , "addr" : ip .addr , "version" : ip .version , "type" : ip .type } for ip in info .ips ],
379- "mdns" : info .mdns ,
380- "image" : info .image ,
381- "application" : info .application ,
382- "device_key" : info .device_key ,
383- }
444+
445+ @with_reauth
446+ async def _get_info_impl (driver ):
447+ client = await driver ._get_client ()
448+ info = await client .get_info ()
449+ return {
450+ "ips" : [{"name" : ip .name , "addr" : ip .addr , "version" : ip .version , "type" : ip .type } for ip in info .ips ],
451+ "mdns" : info .mdns ,
452+ "image" : info .image ,
453+ "application" : info .application ,
454+ "device_key" : info .device_key ,
455+ }
456+
457+ return await _get_info_impl (video_driver )
384458
385459 @export
386460 async def reboot (self ):
0 commit comments