Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions python/packages/jumpstarter/jumpstarter/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pydantic.dataclasses import dataclass

from .core import AsyncDriverClient
from jumpstarter.common.importlib import _format_missing_driver_message
from jumpstarter.streams.blocking import BlockingStream


Expand Down Expand Up @@ -103,3 +104,54 @@ def close(self):

def __del__(self):
self.close()


@dataclass(kw_only=True, config=ConfigDict(arbitrary_types_allowed=True))
class StubDriverClient(DriverClient):
"""Stub client for drivers that are not installed.

This client is created when a driver client class cannot be imported.
It provides a placeholder that raises a clear error when the driver
is actually used.
"""

def _get_missing_class_path(self) -> str:
"""Get the missing class path from labels."""
return self.labels["jumpstarter.dev/client"]

def _raise_missing_error(self):
"""Raise ImportError with installation instructions."""
class_path = self._get_missing_class_path()
message = _format_missing_driver_message(class_path)
raise ImportError(message)

def call(self, method, *args):
"""Invoke driver call - raises ImportError since driver is not installed."""
self._raise_missing_error()

def streamingcall(self, method, *args):
"""Invoke streaming driver call - raises ImportError since driver is not installed."""
self._raise_missing_error()
# Unreachable yield to make this a generator function for type checking
while False: # noqa: SIM114
yield

@contextmanager
def stream(self, method="connect"):
"""Open a stream - raises ImportError since driver is not installed."""
self._raise_missing_error()
yield

@contextmanager
def log_stream(self):
"""Open a log stream - raises ImportError since driver is not installed."""
self._raise_missing_error()
yield

def __getattr__(self, name):
"""Catch any attribute access and raise the missing driver error.

This ensures that calls like .on(), .off(), .write() etc. on stub clients
raise a helpful ImportError instead of AttributeError.
"""
self._raise_missing_error()
101 changes: 101 additions & 0 deletions python/packages/jumpstarter/jumpstarter/client/base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Tests for StubDriverClient."""

import logging
from contextlib import ExitStack
from unittest.mock import MagicMock, create_autospec
from uuid import uuid4

import pytest
from anyio.from_thread import BlockingPortal

from .base import StubDriverClient
from jumpstarter.common.utils import serve
from jumpstarter.driver import Driver


class MissingClientDriver(Driver):
"""Test driver that returns a non-existent client class path."""

@classmethod
def client(cls) -> str:
return "nonexistent_driver_package.client.NonExistentClient"


def create_stub_client(class_path: str) -> StubDriverClient:
"""Create a StubDriverClient with minimal mocking for testing."""
return StubDriverClient(
uuid=uuid4(),
labels={"jumpstarter.dev/client": class_path},
stub=MagicMock(),
portal=create_autospec(BlockingPortal, instance=True),
stack=ExitStack(),
)


def test_missing_driver_logs_warning_and_creates_stub(caplog):
"""Test that a missing driver logs a warning and creates a StubDriverClient."""
expected_class_path = "nonexistent_driver_package.client.NonExistentClient"
with caplog.at_level(logging.WARNING):
with serve(MissingClientDriver()) as client:
# Should have logged a warning with the exact class path from MissingDriverError
assert f"Driver client '{expected_class_path}' is not available." in caplog.text

# Should have created a StubDriverClient
assert isinstance(client, StubDriverClient)

# Using the stub should raise an error
with pytest.raises(ImportError):
client.call("some_method")


def test_stub_driver_client_streamingcall_raises():
"""Test that streamingcall() raises ImportError with driver info."""
stub = create_stub_client("missing_driver.client.Client")
with pytest.raises(ImportError) as exc_info:
# Need to consume the generator to trigger the error
list(stub.streamingcall("some_method"))
assert "missing_driver" in str(exc_info.value)


def test_stub_driver_client_stream_raises():
"""Test that stream() raises ImportError with driver info."""
stub = create_stub_client("missing_driver.client.Client")
with pytest.raises(ImportError) as exc_info:
with stub.stream():
pass
assert "missing_driver" in str(exc_info.value)


def test_stub_driver_client_log_stream_raises():
"""Test that log_stream() raises ImportError with driver info."""
stub = create_stub_client("missing_driver.client.Client")
with pytest.raises(ImportError) as exc_info:
with stub.log_stream():
pass
assert "missing_driver" in str(exc_info.value)


def test_stub_driver_client_error_message_jumpstarter_driver():
"""Test that error message mentions version mismatch for Jumpstarter drivers."""
stub = create_stub_client("jumpstarter_driver_xyz.client.XyzClient")
with pytest.raises(ImportError) as exc_info:
stub.call("some_method")
assert "version mismatch" in str(exc_info.value)


def test_stub_driver_client_error_message_third_party():
"""Test that error message includes install instructions for third-party drivers."""
stub = create_stub_client("custom_driver.client.CustomClient")
with pytest.raises(ImportError) as exc_info:
stub.call("some_method")
assert "pip install custom_driver" in str(exc_info.value)


def test_stub_driver_client_arbitrary_method_raises():
"""Test that accessing arbitrary methods like .on() raises ImportError, not AttributeError."""
stub = create_stub_client("jumpstarter_driver_power.client.PowerClient")
# Accessing .on() should raise ImportError with helpful message, not AttributeError
with pytest.raises(ImportError) as exc_info:
stub.on()
assert "jumpstarter_driver_power" in str(exc_info.value)
assert "version mismatch" in str(exc_info.value)
15 changes: 14 additions & 1 deletion python/packages/jumpstarter/jumpstarter/client/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging
import os
from collections import OrderedDict, defaultdict
from contextlib import ExitStack, asynccontextmanager
from graphlib import TopologicalSorter
Expand All @@ -9,8 +11,12 @@

from .grpc import MultipathExporterStub
from jumpstarter.client import DriverClient
from jumpstarter.client.base import StubDriverClient
from jumpstarter.common.exceptions import MissingDriverError
from jumpstarter.common.importlib import import_class

logger = logging.getLogger(__name__)


@asynccontextmanager
async def client_from_path(path: str, portal: BlockingPortal, stack: ExitStack, allow: list[str], unsafe: bool):
Expand Down Expand Up @@ -50,7 +56,14 @@ async def client_from_channel(
for index in TopologicalSorter(topo).static_order():
report = reports[index]

client_class = import_class(report.labels["jumpstarter.dev/client"], allow, unsafe)
try:
client_class = import_class(report.labels["jumpstarter.dev/client"], allow, unsafe)
except MissingDriverError as e:
# Create stub client instead of failing
# Suppress duplicate warnings
if not os.environ.get("_JMP_SUPPRESS_DRIVER_WARNINGS"):
logger.warning("Driver client '%s' is not available.", e.class_path)
client_class = StubDriverClient

client = client_class(
uuid=UUID(report.uuid),
Expand Down
12 changes: 12 additions & 0 deletions python/packages/jumpstarter/jumpstarter/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,15 @@ class EnvironmentVariableNotSetError(JumpstarterException):
"""Raised when a environment variable is not set."""

pass


class MissingDriverError(JumpstarterException):
"""Raised when a driver module is not found but should be handled gracefully.

This exception is raised when a driver client class cannot be imported,
but the connection should continue with a stub client instead of failing.
"""

def __init__(self, message: str, class_path: str):
super().__init__(message)
self.class_path = class_path
65 changes: 33 additions & 32 deletions python/packages/jumpstarter/jumpstarter/common/importlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,33 @@
from fnmatch import fnmatchcase
from importlib import import_module

from jumpstarter.common.exceptions import MissingDriverError

logger = logging.getLogger(__name__)


def _format_missing_driver_message(class_path: str) -> str:
"""Format error message depending on whether the class path is a Jumpstarter driver."""
# Extract package name from class path (first component)
package_name = class_path.split(".")[0]

if class_path.startswith("jumpstarter_driver_"):
return (
f"Driver '{class_path}' is not installed.\n\n"
"This usually indicates a version mismatch between your client and the exporter.\n"
"Please try to update your client to the latest version and ensure the exporter "
"has the correct version installed.\n"
)
else:
return (
f"Driver '{class_path}' is not installed.\n\n"
"Please install the missing module:\n"
f" pip install {package_name}\n\n"
"or if using uv:\n"
f" uv pip install {package_name}"
)


def cached_import(module_path, class_name):
# Check whether module is loaded and fully initialized.
if not (
Expand Down Expand Up @@ -40,36 +64,13 @@ def import_class(class_path: str, allow: list[str], unsafe: bool):
try:
return cached_import(module_path, class_name)
except ModuleNotFoundError as e:
module_name = str(e).split("'")[1] if "'" in str(e) else str(e).split()[-1]

is_jumpstarter_driver = unsafe or any(fnmatchcase(class_path, pattern) for pattern in allow)

if is_jumpstarter_driver:
logger.error(
"Missing Jumpstarter driver module '%s' for class '%s'. "
"This usually indicates a version mismatch between your client and the exporter.",
module_name,
class_path,
)
raise ConnectionError(
f"Missing Jumpstarter driver module '{module_name}'.\n\n"
"This usually indicates a version mismatch between your client and the exporter.\n"
"Please try to update your client to the latest version and ensure the exporter "
"has the correct version installed.\n"
) from e
else:
logger.error(
"Missing Python module '%s' while importing '%s'. "
"This module needs to be installed in your environment.",
module_name,
class_path,
)
raise ConnectionError(
f"Missing Python module '{module_name}'.\n\n"
"Please install the missing module:\n"
f" pip install {module_name}\n\n"
"or if using uv:\n"
f" uv pip install {module_name}"
) from e
raise MissingDriverError(
message=_format_missing_driver_message(class_path),
class_path=class_path,
) from e
except AttributeError as e:
raise ImportError(f"{module_path} doesn't have specified class {class_name}") from e
# Module exists but class doesn't - treat in a similar way to missing module
raise MissingDriverError(
message=_format_missing_driver_message(class_path),
class_path=class_path,
) from e
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pytest

from .exceptions import MissingDriverError
from .importlib import import_class


def test_import_class():
import_class("os.open", [], True)

with pytest.raises(ImportError):
with pytest.raises(MissingDriverError):
import_class("os.invalid", [], True)

with pytest.raises(ImportError):
Expand Down
1 change: 1 addition & 0 deletions python/packages/jumpstarter/jumpstarter/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def launch_shell(
common_env = os.environ | {
JUMPSTARTER_HOST: host,
JMP_DRIVERS_ALLOW: "UNSAFE" if unsafe else ",".join(allow),
"_JMP_SUPPRESS_DRIVER_WARNINGS": "1", # Already warned during client initialization
}

if command:
Expand Down
9 changes: 7 additions & 2 deletions python/packages/jumpstarter/jumpstarter/config/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .common import ObjectMeta
from .grpc import call_credentials
from .tls import TLSConfigV1Alpha1
from jumpstarter.common.exceptions import ConfigurationError
from jumpstarter.common.exceptions import ConfigurationError, MissingDriverError
from jumpstarter.common.grpc import aio_secure_channel, ssl_channel_credentials
from jumpstarter.common.importlib import import_class
from jumpstarter.driver import Driver
Expand Down Expand Up @@ -44,7 +44,12 @@ class ExporterConfigV1Alpha1DriverInstance(RootModel):
def instantiate(self) -> Driver:
match self.root:
case ExporterConfigV1Alpha1DriverInstanceBase():
driver_class = import_class(self.root.type, [], True)
try:
driver_class = import_class(self.root.type, [], True)
except MissingDriverError:
raise ConfigurationError(
f"Driver '{self.root.type}' is not installed. Please check exporter configuration."
) from None

children = {name: child.instantiate() for name, child in self.root.children.items()}

Expand Down
Loading