Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit 96632cf

Browse files
committed
nanokvm: reauthentication support
1 parent 43b9e13 commit 96632cf

1 file changed

Lines changed: 85 additions & 11 deletions

File tree

  • packages/jumpstarter-driver-nanokvm/jumpstarter_driver_nanokvm

packages/jumpstarter-driver-nanokvm/jumpstarter_driver_nanokvm/driver.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,43 @@
22
from base64 import b64encode
33
from contextlib import asynccontextmanager
44
from dataclasses import dataclass, field
5+
from functools import wraps
56
from io import BytesIO
67

78
import anyio
8-
from aiohttp import ClientSession
9+
from aiohttp import ClientResponseError, ClientSession
910
from jumpstarter_driver_composite.driver import Composite
1011
from jumpstarter_driver_pyserial.driver import PySerial
1112
from nanokvm.client import NanoKVMClient as NanoKVMAPIClient
1213

1314
from jumpstarter.driver import Driver, export, exportstream
1415

1516

17+
def _is_unauthorized_error(error: Exception) -> bool:
18+
"""Check if an error is a 401 Unauthorized error"""
19+
if isinstance(error, ClientResponseError):
20+
return error.status == 401
21+
# Also check for string representation in case error is wrapped
22+
error_str = str(error)
23+
return "401" in error_str and ("Unauthorized" in error_str or "unauthorized" in error_str.lower())
24+
25+
26+
def with_reauth(func):
27+
"""Decorator to automatically re-authenticate on 401 errors"""
28+
@wraps(func)
29+
async def wrapper(self, *args, **kwargs):
30+
try:
31+
return await func(self, *args, **kwargs)
32+
except Exception as e:
33+
if _is_unauthorized_error(e):
34+
self.logger.warning("Received 401 Unauthorized, re-authenticating...")
35+
await self._reset_client()
36+
# Retry once after re-authentication
37+
return await func(self, *args, **kwargs)
38+
raise
39+
return wrapper
40+
41+
1642
@dataclass(kw_only=True)
1743
class NanoKVMVideo(Driver):
1844
"""NanoKVM Video Streaming driver"""
@@ -32,6 +58,16 @@ def __post_init__(self):
3258
def client(cls) -> str:
3359
return "jumpstarter_driver_nanokvm.client.NanoKVMVideoClient"
3460

61+
async def _reset_client(self):
62+
"""Reset the client and session, forcing re-authentication"""
63+
if self._session is not None and not self._session.closed:
64+
try:
65+
await self._session.close()
66+
except Exception as e:
67+
self.logger.debug(f"Error closing session during reset: {e}")
68+
self._client = None
69+
self._session = None
70+
3571
async def _get_client(self) -> NanoKVMAPIClient:
3672
"""Get or create the NanoKVM API client"""
3773
if self._client is None:
@@ -55,6 +91,7 @@ def close(self):
5591
self.logger.debug(f"Error closing session: {e}")
5692

5793
@export
94+
@with_reauth
5895
async def snapshot(self) -> str:
5996
"""
6097
Take a snapshot from the video stream
@@ -96,7 +133,19 @@ async def stream_video():
96133
# TODO(mangelajo): this needs to be tested
97134
await send_stream.send(data)
98135
except Exception as e:
99-
self.logger.error(f"Error streaming video: {e}")
136+
if _is_unauthorized_error(e):
137+
self.logger.warning("Received 401 Unauthorized during stream, re-authenticating...")
138+
await self._reset_client()
139+
# Retry with new client
140+
new_client = await self._get_client()
141+
async for frame in new_client.mjpeg_stream():
142+
buffer = BytesIO()
143+
frame.save(buffer, format="JPEG")
144+
data = buffer.getvalue()
145+
await send_stream.send(data)
146+
else:
147+
self.logger.error(f"Error streaming video: {e}")
148+
raise
100149

101150
# Start the video streaming task
102151
task = asyncio.create_task(stream_video())
@@ -131,6 +180,22 @@ def __post_init__(self):
131180
def client(cls) -> str:
132181
return "jumpstarter_driver_nanokvm.client.NanoKVMHIDClient"
133182

183+
async def _reset_client(self):
184+
"""Reset the client, session, and websocket, forcing re-authentication"""
185+
if self._ws is not None:
186+
try:
187+
await self._ws.close()
188+
except Exception as e:
189+
self.logger.debug(f"Error closing websocket during reset: {e}")
190+
if self._session is not None and not self._session.closed:
191+
try:
192+
await self._session.close()
193+
except Exception as e:
194+
self.logger.debug(f"Error closing session during reset: {e}")
195+
self._client = None
196+
self._session = None
197+
self._ws = None
198+
134199
async def _get_client(self) -> NanoKVMAPIClient:
135200
"""Get or create the NanoKVM API client"""
136201
if self._client is None:
@@ -151,6 +216,7 @@ async def _get_ws(self):
151216
)
152217
return self._ws
153218

219+
@with_reauth
154220
async def _send_mouse_event(self, event_type: int, x: int, y: int):
155221
"""
156222
Send a mouse event via WebSocket
@@ -189,6 +255,7 @@ def close(self):
189255
self.logger.debug(f"Error closing session: {e}")
190256

191257
@export
258+
@with_reauth
192259
async def paste_text(self, text: str):
193260
"""
194261
Paste text via keyboard HID simulation
@@ -201,6 +268,7 @@ async def paste_text(self, text: str):
201268
self.logger.info(f"Pasted text: {text}")
202269

203270
@export
271+
@with_reauth
204272
async def press_key(self, key: str):
205273
"""
206274
Press a key by pasting a single character
@@ -220,6 +288,7 @@ async def press_key(self, key: str):
220288
self.logger.debug(f"Pressed key: {repr(key)}")
221289

222290
@export
291+
@with_reauth
223292
async def reset_hid(self):
224293
"""Reset the HID subsystem"""
225294
client = await self._get_client()
@@ -372,15 +441,20 @@ async def get_info(self):
372441
"""Get device information"""
373442
# Get info from the video driver's client
374443
video_driver = self.children["video"]
375-
client = await video_driver._get_client()
376-
info = await client.get_info()
377-
return {
378-
"ips": [{"name": ip.name, "addr": ip.addr, "version": ip.version, "type": ip.type} for ip in info.ips],
379-
"mdns": info.mdns,
380-
"image": info.image,
381-
"application": info.application,
382-
"device_key": info.device_key,
383-
}
444+
445+
@with_reauth
446+
async def _get_info_impl(driver):
447+
client = await driver._get_client()
448+
info = await client.get_info()
449+
return {
450+
"ips": [{"name": ip.name, "addr": ip.addr, "version": ip.version, "type": ip.type} for ip in info.ips],
451+
"mdns": info.mdns,
452+
"image": info.image,
453+
"application": info.application,
454+
"device_key": info.device_key,
455+
}
456+
457+
return await _get_info_impl(video_driver)
384458

385459
@export
386460
async def reboot(self):

0 commit comments

Comments
 (0)