From 2c9878d77d8d18e26353efb9f9184fb12cec3ba1 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 29 Mar 2026 22:23:18 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20pyisolate=200.10.0=20=E2=80=94=20conda?= =?UTF-8?q?=20backend,=20sealed=20workers,=20CUDA=20wheels,=20event=20chan?= =?UTF-8?q?nel,=20sandbox=20hardening?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: pyisolate 0.10.0 β€” conda backend, sealed workers, CUDA wheels, event channel, sandbox hardening Squash of 93 commits from conda_support since v0.9.2. Major features: - Conda/pixi backend: environment creation, config, routing, PEP 508 markers - CUDA wheel resolver: target_python, multi-index, find_links, dotted torch versions - Sealed worker execution model: uv and conda backends, proxy handle round-trip - RPC serialization: dynamic serializer resolution, perf tracing - Generic event channel: child-to-host dispatch, setup_child_event_hooks - Bwrap sandbox hardening: adapter paths, sealed workers, sys.base_prefix - Remove ComfyUI-specific references from pyisolate core - Platform fixes: Windows uv_exe fallback, test expectations, TMPDIR cleanup - Formatting fixes: ruff format compliance for CI * chore: bump version to 0.10.0 * fix: use tomli fallback for Python 3.10 in integration tests * fix: inline fixture config in integration tests, remove tomllib dependency * fix: skip network-dependent and bwrap-dependent tests in CI * fix: skip uv sealed integration test when bwrap sandbox rejects at runtime --- .gitignore | 5 +- README.md | 63 +- pyisolate/__init__.py | 4 +- pyisolate/_internal/bootstrap.py | 84 ++- pyisolate/_internal/client.py | 4 + pyisolate/_internal/cuda_wheels.py | 140 +++-- pyisolate/_internal/environment.py | 82 ++- pyisolate/_internal/environment_conda.py | 464 ++++++++++++++ pyisolate/_internal/event_bridge.py | 30 + pyisolate/_internal/host.py | 186 ++++-- pyisolate/_internal/model_serialization.py | 96 +-- 
pyisolate/_internal/perf_trace.py | 53 ++ pyisolate/_internal/rpc_serialization.py | 132 ++-- pyisolate/_internal/rpc_transports.py | 70 ++- pyisolate/_internal/sandbox.py | 241 +++++--- pyisolate/_internal/sandbox_detect.py | 83 +-- pyisolate/_internal/tensor_serializer.py | 119 +++- pyisolate/_internal/uds_client.py | 63 +- pyisolate/config.py | 26 +- pyisolate/host.py | 10 + pyisolate/interfaces.py | 11 + pyisolate/path_helpers.py | 19 +- pyisolate/sealed.py | 195 ++++++ pyisolate/shared.py | 31 + pyproject.toml | 7 +- tests/conftest.py | 11 - tests/fixtures/conda_sealed_node/__init__.py | 81 +++ .../fixtures/conda_sealed_node/pyproject.toml | 11 + tests/fixtures/uv_sealed_worker/__init__.py | 161 +++++ .../fixtures/uv_sealed_worker/pyproject.toml | 10 + tests/harness/host.py | 15 +- tests/harness/test_package/__init__.py | 10 + tests/integration_v2/test_isolation.py | 24 + tests/integration_v2/test_tensors.py | 11 +- tests/path_unification/test_path_helpers.py | 84 +-- tests/test_bootstrap.py | 146 +++++ tests/test_bwrap_command.py | 232 +++++++ tests/test_client_entrypoint_extra.py | 25 +- tests/test_conda_integration.py | 131 ++++ tests/test_conda_sealed_worker_contract.py | 226 +++++++ tests/test_config_conda.py | 117 ++++ tests/test_config_sealed_worker.py | 134 ++++ tests/test_cuda_wheels.py | 264 +++++++- tests/test_environment_conda.py | 582 ++++++++++++++++++ tests/test_environment_sealed_worker.py | 72 +++ tests/test_event_channel.py | 99 +++ tests/test_exact_proxy_bootstrap.py | 91 +++ tests/test_harness_host_env.py | 25 + tests/test_host_conda_dispatch.py | 423 +++++++++++++ tests/test_host_internal_ext.py | 3 + tests/test_host_sealed_worker_dispatch.py | 163 +++++ tests/test_model_serialization.py | 31 + tests/test_rpc_contract.py | 88 +-- tests/test_rpc_transports.py | 133 ++-- tests/test_sandbox_detect.py | 24 + tests/test_sealed_proxy_handle.py | 71 +++ tests/test_security_conda.py | 198 ++++++ tests/test_security_sealed_worker.py | 157 +++++ 
tests/test_shared_additional.py | 40 +- .../test_tensor_serializer_signal_cleanup.py | 22 + tests/test_uv_sealed_integration.py | 175 ++++++ 61 files changed, 5735 insertions(+), 573 deletions(-) create mode 100644 pyisolate/_internal/environment_conda.py create mode 100644 pyisolate/_internal/event_bridge.py create mode 100644 pyisolate/_internal/perf_trace.py create mode 100644 pyisolate/sealed.py create mode 100644 tests/fixtures/conda_sealed_node/__init__.py create mode 100644 tests/fixtures/conda_sealed_node/pyproject.toml create mode 100644 tests/fixtures/uv_sealed_worker/__init__.py create mode 100644 tests/fixtures/uv_sealed_worker/pyproject.toml create mode 100644 tests/test_conda_integration.py create mode 100644 tests/test_conda_sealed_worker_contract.py create mode 100644 tests/test_config_conda.py create mode 100644 tests/test_config_sealed_worker.py create mode 100644 tests/test_environment_conda.py create mode 100644 tests/test_environment_sealed_worker.py create mode 100644 tests/test_event_channel.py create mode 100644 tests/test_exact_proxy_bootstrap.py create mode 100644 tests/test_harness_host_env.py create mode 100644 tests/test_host_conda_dispatch.py create mode 100644 tests/test_host_sealed_worker_dispatch.py create mode 100644 tests/test_sealed_proxy_handle.py create mode 100644 tests/test_security_conda.py create mode 100644 tests/test_security_sealed_worker.py create mode 100644 tests/test_tensor_serializer_signal_cleanup.py create mode 100644 tests/test_uv_sealed_integration.py diff --git a/.gitignore b/.gitignore index 9b58e30..eb02e64 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +.pytest_artifacts/ cover/ # Translations @@ -155,5 +156,5 @@ cython_debug/ # UV cache directory (for hardlinking optimization) .uv_cache/ -# Generated demo venvs -comfy_hello_world/node-venvs/ +# Generated test venvs +.smoke_venv/ diff --git a/README.md b/README.md index 2c2b60f..8bbd39b 100644 
--- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ > 🚨 **Fail Loud Policy**: pyisolate assumes the rest of ComfyUI core is correct. Missing prerequisites or runtime failures immediately raise descriptive exceptions instead of being silently ignored. -pyisolate enables you to run Python extensions with conflicting dependencies in the same application by automatically creating isolated virtual environments for each extension using `uv`. Extensions communicate with the host process through a transparent RPC system, making the isolation invisible to your code while keeping the host environment dependency-free. +pyisolate enables you to run Python extensions with conflicting dependencies in the same application by automatically creating isolated environments for each extension. The default provisioner uses `uv`, and ComfyUI integrations can also provision a conda environment through `pixi` when an extension needs conda-first packages. Extensions communicate with the host process through a transparent RPC system, making the isolation invisible to your code while keeping the host environment dependency-free. 
## Requirements @@ -86,6 +86,7 @@ The script installs `uv`, creates the dev venv, installs pyisolate in editable m - πŸ”’ **Dependency Isolation**: Run extensions with incompatible dependencies (e.g., numpy 1.x and 2.x) in the same application - πŸš€ **Zero-Copy PyTorch Tensor Sharing**: Share PyTorch tensors between processes without serialization overhead +- πŸ“¦ **Multiple Environment Backends**: Use `uv` by default or a conda/pixi environment when the extension needs conda-native dependencies - πŸ”„ **Transparent Communication**: Call async methods across process boundaries as if they were local - 🎯 **Simple API**: Clean, intuitive interface with minimal boilerplate - ⚑ **Fast**: Uses `uv` for blazing-fast virtual environment creation @@ -185,6 +186,66 @@ large_tensor = torch.randn(1000, 1000) mean = await extension.process_tensor(large_tensor) ``` +### Execution Model Axis + +ComfyUI integrations now treat environment provisioning and runtime boundary as separate choices: + +- `package_manager = "uv"` or `package_manager = "conda"` chooses how the child environment is built +- `execution_model = "host-coupled"` or `execution_model = "sealed_worker"` chooses how much host runtime state the child may inherit + +`host-coupled` remains the default for the classic `uv` path. `sealed_worker` is the foreign-interpreter path: no host `sys.path` reconstruction, no host framework runtime imports as a crutch, JSON-RPC tensor transport, and no sandbox in this phase. 
+ +### UV Backend for Sealed Workers + +ComfyUI extensions can also request a sealed `uv` worker explicitly: + +```toml +[project] +name = "uv-sealed-node" +version = "0.1.0" +dependencies = ["boltons"] + +[tool.comfy.isolation] +can_isolate = true +package_manager = "uv" +execution_model = "sealed_worker" +share_torch = false +``` + +Trade-offs for `package_manager = "uv"` with `execution_model = "sealed_worker"`: + +- `share_torch` must be `False` +- tensors cross the boundary through JSON-compatible RPC values instead of shared-memory tensor handles +- host `sys.path` reconstruction is disabled +- host framework runtime imports such as `comfy.isolation.extension_wrapper` must not be required in the child +- `bwrap` sandboxing is intentionally disabled in this phase + +### Conda Backend for Sealed Workers + +ComfyUI extensions can declare a conda-backed isolated environment in `pyproject.toml`: + +```toml +[project] +name = "weather-node" +version = "0.1.0" +dependencies = ["xarray", "cfgrib"] + +[tool.comfy.isolation] +can_isolate = true +package_manager = "conda" +share_torch = false +conda_channels = ["conda-forge"] +conda_dependencies = ["eccodes", "cfgrib"] +``` + +Trade-offs for `package_manager = "conda"`: + +- `share_torch` is forced `False` +- `bwrap` sandboxing is skipped +- the child uses its own interpreter instead of the host Python +- the child is treated as a sealed foreign runtime and must not import host framework runtime code through leaked `sys.path` +- tensor transfer crosses the RPC boundary as JSON-compatible values instead of shared-memory tensor handles + ### Shared State with Singletons Share state across all extensions using ProxiedSingleton: diff --git a/pyisolate/__init__.py b/pyisolate/__init__.py index 819e4d7..288299c 100644 --- a/pyisolate/__init__.py +++ b/pyisolate/__init__.py @@ -39,11 +39,12 @@ from ._internal.tensor_serializer import flush_tensor_keeper, purge_orphan_sender_shm_files from .config import ExtensionConfig, 
ExtensionManagerConfig, SandboxMode from .host import ExtensionBase, ExtensionManager +from .sealed import SealedNodeExtension if TYPE_CHECKING: from .interfaces import IsolationAdapter -__version__ = "0.9.1" +__version__ = "0.10.0" __all__ = [ "ExtensionBase", @@ -51,6 +52,7 @@ "ExtensionManagerConfig", "ExtensionConfig", "SandboxMode", + "SealedNodeExtension", "ProxiedSingleton", "local_execution", "singleton_scope", diff --git a/pyisolate/_internal/bootstrap.py b/pyisolate/_internal/bootstrap.py index a215148..cda4922 100644 --- a/pyisolate/_internal/bootstrap.py +++ b/pyisolate/_internal/bootstrap.py @@ -21,24 +21,12 @@ logger = logging.getLogger(__name__) -def _apply_sys_path(snapshot: dict[str, Any]) -> None: - host_paths = snapshot.get("sys_path", []) - extra_paths = snapshot.get("additional_paths", []) +def _should_apply_host_sys_path(snapshot: dict[str, Any]) -> bool: + return bool(snapshot.get("apply_host_sys_path", True)) - preferred_root: str | None = snapshot.get("preferred_root") - if not preferred_root: - context_data = snapshot.get("context_data", {}) - module_path = context_data.get("module_path") or os.environ.get("PYISOLATE_MODULE_PATH") - if module_path: - preferred_root = str(Path(module_path).parent.parent) - child_paths = build_child_sys_path(host_paths, extra_paths, preferred_root) - - if not child_paths: - return - - # Rebuild sys.path with child paths first while preserving any existing entries - # that are not already in the computed set. 
+def _merge_sys_path_front(paths: list[str]) -> None: + """Prepend paths to sys.path while preserving order and removing duplicates.""" seen = set() merged: list[str] = [] @@ -49,13 +37,62 @@ def add_path(p: str) -> None: seen.add(norm) merged.append(p) - for p in child_paths: + for p in paths: add_path(p) for p in sys.path: add_path(p) sys.path[:] = merged + + +def _apply_sealed_opt_in_paths(snapshot: dict[str, Any]) -> None: + raw_paths = snapshot.get("sealed_host_ro_paths", []) + if not isinstance(raw_paths, list): + return + + opt_in_paths: list[str] = [] + for path in raw_paths: + if not isinstance(path, str) or not path.strip(): + continue + if not os.path.isabs(path): + continue + if not os.path.exists(path): + continue + opt_in_paths.append(path) + + if not opt_in_paths: + return + + _merge_sys_path_front(opt_in_paths) + logger.debug("Applied %d sealed opt-in import paths", len(opt_in_paths)) + + +def _apply_sys_path(snapshot: dict[str, Any]) -> None: + if not _should_apply_host_sys_path(snapshot): + _apply_sealed_opt_in_paths(snapshot) + logger.debug("Skipping host sys.path reconstruction for sealed child") + return + + host_paths = snapshot.get("sys_path", []) + extra_paths = snapshot.get("additional_paths", []) + + preferred_root: str | None = snapshot.get("preferred_root") + if not preferred_root: + context_data = snapshot.get("context_data", {}) + module_path = context_data.get("module_path") or os.environ.get("PYISOLATE_MODULE_PATH") + if module_path: + preferred_root = str(Path(module_path).parent.parent) + + filtered_subdirs = snapshot.get("filtered_subdirs") + child_paths = build_child_sys_path(host_paths, extra_paths, preferred_root, filtered_subdirs) + + if not child_paths: + return + + # Rebuild sys.path with child paths first while preserving any existing entries + # that are not already in the computed set. 
+ _merge_sys_path_front(child_paths) logger.debug("Applied %d paths from snapshot (preferred_root=%s)", len(child_paths), preferred_root) @@ -125,20 +162,25 @@ def bootstrap_child() -> IsolationAdapter | None: _apply_sys_path(snapshot) adapter: IsolationAdapter | None = None + is_sealed = not _should_apply_host_sys_path(snapshot) adapter_ref = snapshot.get("adapter_ref") if adapter_ref: try: adapter = _rehydrate_adapter(adapter_ref) except Exception as exc: - logger.warning("Failed to rehydrate adapter from ref %s: %s", adapter_ref, exc) + logger.warning( + "Failed to rehydrate adapter from ref %s: %s", + adapter_ref, + exc, + ) - if not adapter and adapter_ref: - # If we had info but failed to load, that's an error + if not adapter and adapter_ref and not is_sealed: raise ValueError("Snapshot contained adapter info but adapter could not be loaded") if adapter: - adapter.setup_child_environment(snapshot) + if not is_sealed: + adapter.setup_child_environment(snapshot) registry = SerializerRegistry.get_instance() adapter.register_serializers(registry) diff --git a/pyisolate/_internal/client.py b/pyisolate/_internal/client.py index 707d8de..4dee121 100644 --- a/pyisolate/_internal/client.py +++ b/pyisolate/_internal/client.py @@ -107,6 +107,10 @@ async def async_entrypoint( api_instance = cast(ProxiedSingleton, getattr(api, "instance", api)) _adapter.handle_api_registration(api_instance, rpc) + # Let the adapter wire child-side event hooks (e.g., progress bar) + if _adapter and hasattr(_adapter, "setup_child_event_hooks"): + _adapter.setup_child_event_hooks(extension) + # Sanitize module name for use as Python identifier. # Replace '-' and '.' with '_' to prevent import errors when module names contain # non-identifier characters (e.g., "my-node" β†’ "my_node", "my.node" β†’ "my_node"). 
diff --git a/pyisolate/_internal/cuda_wheels.py b/pyisolate/_internal/cuda_wheels.py index 668bab8..89a5899 100644 --- a/pyisolate/_internal/cuda_wheels.py +++ b/pyisolate/_internal/cuda_wheels.py @@ -9,16 +9,22 @@ from packaging.markers import default_environment from packaging.requirements import InvalidRequirement, Requirement -from packaging.tags import sys_tags +from packaging.tags import Tag, compatible_tags, cpython_tags, sys_tags from packaging.utils import canonicalize_name, parse_wheel_filename from packaging.version import Version from ..config import CUDAWheelConfig _TORCH_VERSION_RE = re.compile(r"^(?P\d+)\.(?P\d+)") +# TODO: Resolve on a single naming convention with Andrea's cuda-wheels index. +# Currently two conventions coexist: +# - dotted: cu130torch2.10 (Andrea's index for torch >= 2.10) +# - nodot: cu130torch210 (older wheels, some other indices) +# Both patterns accept either form; _matches_runtime normalizes the captured +# torch group by stripping dots before comparison. _CUDA_LOCAL_PATTERNS = ( - re.compile(r"(^|[.-])cu(?P\d+)torch(?P\d+)([.-]|$)"), - re.compile(r"(^|[.-])pt(?P\d+)cu(?P\d+)([.-]|$)"), + re.compile(r"(^|[.-])cu(?P\d+)torch(?P\d+(?:\.\d+)?)([.-]|$)"), + re.compile(r"(^|[.-])pt(?P\d+(?:\.\d+)?)cu(?P\d+)([.-]|$)"), ) @@ -55,7 +61,23 @@ def _parse_major_minor(version_text: str, label: str) -> str: return f"{match.group('major')}.{match.group('minor')}" -def get_cuda_wheel_runtime() -> CUDAWheelRuntime: +def _tags_for_python(target_python: tuple[int, int] | None = None) -> list[Tag]: + """Return wheel compatibility tags for a target Python version. + + When ``target_python`` is ``None``, returns tags for the running + interpreter via ``sys_tags()``. Otherwise generates CPython + + compatible tags for the specified ``(major, minor)`` tuple. 
+ """ + if target_python is None: + return list(sys_tags()) + return list(cpython_tags(python_version=target_python)) + list( + compatible_tags(python_version=target_python) + ) + + +def get_cuda_wheel_runtime( + target_python: tuple[int, int] | None = None, +) -> CUDAWheelRuntime: try: import torch except ImportError as exc: @@ -75,7 +97,7 @@ def get_cuda_wheel_runtime() -> CUDAWheelRuntime: "torch_nodot": torch_version.replace(".", ""), "cuda": cuda_major_minor, "cuda_nodot": cuda_major_minor.replace(".", ""), - "python_tags": [str(tag) for tag in sys_tags()], + "python_tags": [str(tag) for tag in _tags_for_python(target_python)], } @@ -91,12 +113,24 @@ def get_cuda_wheel_runtime_descriptor() -> dict[str, object]: def _normalize_cuda_wheel_config(config: CUDAWheelConfig) -> CUDAWheelConfig: + index_urls = config.get("index_urls") index_url = config.get("index_url") packages = config.get("packages") package_map = config.get("package_map", {}) - if not isinstance(index_url, str) or not index_url.strip(): - raise CUDAWheelResolutionError("cuda_wheels.index_url must be a non-empty string") + # Accept index_urls (plural list) or index_url (singular string). + # Normalize to index_urls (list) internally. 
+ if index_urls is not None: + if not isinstance(index_urls, list) or not all(isinstance(u, str) and u.strip() for u in index_urls): + raise CUDAWheelResolutionError("cuda_wheels.index_urls must be a list of non-empty strings") + normalized_urls = [u.rstrip("/") + "/" for u in index_urls] + elif isinstance(index_url, str) and index_url.strip(): + normalized_urls = [index_url.rstrip("/") + "/"] + else: + raise CUDAWheelResolutionError( + "cuda_wheels requires either index_url (string) or index_urls (list of strings)" + ) + if not isinstance(packages, list) or not all( isinstance(package_name, str) and package_name.strip() for package_name in packages ): @@ -113,7 +147,7 @@ def _normalize_cuda_wheel_config(config: CUDAWheelConfig) -> CUDAWheelConfig: normalized_map[canonicalize_name(dependency_name)] = index_package_name.strip() return { - "index_url": index_url.rstrip("/") + "/", + "index_urls": normalized_urls, "packages": [canonicalize_name(package_name) for package_name in packages], "package_map": normalized_map, } @@ -168,58 +202,67 @@ def _matches_runtime(local_version: str | None, runtime: CUDAWheelRuntime) -> bo match = pattern.search(normalized_local) if not match: continue - if match.group("torch") == runtime["torch_nodot"] and match.group("cuda") == runtime["cuda_nodot"]: + torch_match = match.group("torch").replace(".", "") == runtime["torch_nodot"] + cuda_match = match.group("cuda") == runtime["cuda_nodot"] + if torch_match and cuda_match: return True return False def resolve_cuda_wheel_url( - requirement: Requirement, config: CUDAWheelConfig, runtime: CUDAWheelRuntime | None = None + requirement: Requirement, + config: CUDAWheelConfig, + runtime: CUDAWheelRuntime | None = None, + *, + target_python: tuple[int, int] | None = None, ) -> str: normalized_config = _normalize_cuda_wheel_config(config) dependency_name = canonicalize_name(requirement.name) - runtime_info = runtime or get_cuda_wheel_runtime() - supported_tag_list = list(sys_tags()) + 
runtime_info = runtime or get_cuda_wheel_runtime(target_python=target_python) + supported_tag_list = _tags_for_python(target_python) supported_tags = set(supported_tag_list) tag_rank = {tag: idx for idx, tag in enumerate(supported_tag_list)} fetch_attempted = False candidates: list[tuple[Version, int, str]] = [] - for package_name in _candidate_package_names(dependency_name, normalized_config.get("package_map", {})): - page_url = urljoin(normalized_config["index_url"], package_name.rstrip("/") + "/") - html = _fetch_index_html(page_url) - if html is None: - continue - fetch_attempted = True - for wheel_url in _parse_index_links(page_url, html): - parsed_url = urlparse(wheel_url) - wheel_filename = unquote(parsed_url.path.rsplit("/", 1)[-1]) - if not wheel_filename.endswith(".whl"): - continue - try: - wheel_name, wheel_version, _, wheel_tags = parse_wheel_filename(wheel_filename) - except ValueError: + for base_url in normalized_config["index_urls"]: + for package_name in _candidate_package_names( + dependency_name, normalized_config.get("package_map", {}) + ): + page_url = urljoin(base_url, package_name.rstrip("/") + "/") + html = _fetch_index_html(page_url) + if html is None: continue - if canonicalize_name(wheel_name) != dependency_name: - continue - matching_tags = wheel_tags.intersection(supported_tags) - if not matching_tags: - continue - if not _matches_runtime(getattr(wheel_version, "local", None), runtime_info): - continue - if requirement.specifier and wheel_version not in requirement.specifier: - continue - candidates.append( - ( - wheel_version, - min(tag_rank[tag] for tag in matching_tags), - _normalize_wheel_url(wheel_url), + fetch_attempted = True + for wheel_url in _parse_index_links(page_url, html): + parsed_url = urlparse(wheel_url) + wheel_filename = unquote(parsed_url.path.rsplit("/", 1)[-1]) + if not wheel_filename.endswith(".whl"): + continue + try: + wheel_name, wheel_version, _, wheel_tags = parse_wheel_filename(wheel_filename) + except 
ValueError: + continue + if canonicalize_name(wheel_name) != dependency_name: + continue + matching_tags = wheel_tags.intersection(supported_tags) + if not matching_tags: + continue + if not _matches_runtime(getattr(wheel_version, "local", None), runtime_info): + continue + if requirement.specifier and wheel_version not in requirement.specifier: + continue + candidates.append( + ( + wheel_version, + min(tag_rank[tag] for tag in matching_tags), + _normalize_wheel_url(wheel_url), + ) ) - ) if not fetch_attempted: raise CUDAWheelResolutionError( - f"No CUDA wheel index page found for '{requirement.name}' under {normalized_config['index_url']}" + f"No CUDA wheel index page found for '{requirement.name}' under {normalized_config['index_urls']}" ) if not candidates: raise CUDAWheelResolutionError( @@ -231,11 +274,16 @@ def resolve_cuda_wheel_url( return candidates[-1][2] -def resolve_cuda_wheel_requirements(requirements: list[str], config: CUDAWheelConfig) -> list[str]: +def resolve_cuda_wheel_requirements( + requirements: list[str], + config: CUDAWheelConfig, + *, + target_python: tuple[int, int] | None = None, +) -> list[str]: normalized_config = _normalize_cuda_wheel_config(config) configured_packages = set(normalized_config["packages"]) environment = cast(dict[str, str], default_environment()) - runtime = get_cuda_wheel_runtime() + runtime = get_cuda_wheel_runtime(target_python=target_python) resolved_requirements: list[str] = [] for dependency in requirements: @@ -267,6 +315,8 @@ def resolve_cuda_wheel_requirements(requirements: list[str], config: CUDAWheelCo resolved_requirements.append(dependency) continue - resolved_requirements.append(resolve_cuda_wheel_url(requirement, normalized_config, runtime)) + resolved_requirements.append( + resolve_cuda_wheel_url(requirement, normalized_config, runtime, target_python=target_python) + ) return resolved_requirements diff --git a/pyisolate/_internal/environment.py b/pyisolate/_internal/environment.py index 634723f..0719f83 
100644 --- a/pyisolate/_internal/environment.py +++ b/pyisolate/_internal/environment.py @@ -22,6 +22,82 @@ ) from .torch_utils import get_torch_ecosystem_packages + +def validate_backend_config(config: ExtensionConfig) -> None: + """Validate backend-specific configuration. Fail loud on invalid combos.""" + package_manager = config.get("package_manager", "uv") + execution_model = config.get("execution_model") + + if execution_model is None: + execution_model = "sealed_worker" if package_manager == "conda" else "host-coupled" + + if execution_model not in {"host-coupled", "sealed_worker"}: + raise ValueError( + f"Unknown execution_model '{execution_model}'. Must be 'host-coupled' or 'sealed_worker'." + ) + + if config.get("share_cuda_ipc", False) and not config.get("share_torch", False): + raise ValueError( + "share_cuda_ipc=True requires share_torch=True. " + "CUDA IPC cannot be enabled without host torch sharing." + ) + + if package_manager == "uv" and execution_model == "sealed_worker" and config.get("share_torch", False): + raise ValueError( + "sealed_worker execution_model requires share_torch=False. " + "Sealed workers use explicit RPC serialization rather than host-coupled tensor sharing." + ) + + sealed_host_ro_paths = config.get("sealed_host_ro_paths") + if sealed_host_ro_paths is not None: + if execution_model != "sealed_worker": + raise ValueError("sealed_host_ro_paths requires execution_model='sealed_worker'.") + if not isinstance(sealed_host_ro_paths, list): + raise ValueError("sealed_host_ro_paths must be a list of absolute paths.") + for path in sealed_host_ro_paths: + if not isinstance(path, str) or not path: + raise ValueError("sealed_host_ro_paths entries must be non-empty strings.") + if not os.path.isabs(path): + raise ValueError("sealed_host_ro_paths entries must be absolute paths.") + + if package_manager == "uv": + return + + if package_manager != "conda": + raise ValueError(f"Unknown package_manager '{package_manager}'. 
Must be 'uv' or 'conda'.") + + if execution_model != "sealed_worker": + raise ValueError( + "conda backend requires execution_model='sealed_worker'. " + "Conda always runs as a sealed foreign interpreter." + ) + + # conda + share_torch is incompatible + if config.get("share_torch", False): + raise ValueError( + "conda backend requires share_torch=False. Conda uses its own Python " + "interpreter, which is incompatible with zero-copy tensor sharing." + ) + + # cuda_wheels for conda: resolved post-pixi-install via pip --no-deps + # (same wheel resolution as uv, just installed into the pixi env after provisioning) + + # conda requires conda_channels + channels = config.get("conda_channels") + if not channels: + raise ValueError( + "conda_channels is required when package_manager='conda'. " + "Specify at least one channel (e.g. ['conda-forge'])." + ) + + # conda requires pixi on PATH + if not shutil.which("pixi"): + raise ValueError( + "pixi is required for conda backend but not found. " + "Install: curl -fsSL https://pixi.sh/install.sh | bash" + ) + + logger = logging.getLogger(__name__) _DANGEROUS_PATTERNS = ("&&", "||", "|", "`", "$", "\n", "\r", "\0") @@ -149,6 +225,7 @@ def build_extension_snapshot(module_path: str) -> dict[str, object]: "adapter_name": adapter.identifier if adapter else None, "preferred_root": path_config.get("preferred_root"), "additional_paths": path_config.get("additional_paths", []), + "filtered_subdirs": path_config.get("filtered_subdirs"), "context_data": {"module_path": module_path}, } ) @@ -301,6 +378,8 @@ def install_dependencies(venv_path: Path, config: ExtensionConfig, name: str) -> ) safe_deps: list[str] = [] + if config.get("execution_model") == "sealed_worker": + safe_deps.append(str(Path(__file__).resolve().parents[2])) for dep in config["dependencies"]: validate_dependency(dep) safe_deps.append(dep) @@ -339,7 +418,8 @@ def install_dependencies(venv_path: Path, config: ExtensionConfig, name: str) -> common_args: list[str] = 
["--cache-dir", str(cache_dir)] torch_spec: str | None = None - if not config["share_torch"]: + needs_child_torch = not config["share_torch"] and config.get("execution_model") != "sealed_worker" + if needs_child_torch: import torch torch_version: str = str(torch.__version__) diff --git a/pyisolate/_internal/environment_conda.py b/pyisolate/_internal/environment_conda.py new file mode 100644 index 0000000..b752de2 --- /dev/null +++ b/pyisolate/_internal/environment_conda.py @@ -0,0 +1,464 @@ +"""Conda/pixi environment creation for pyisolate extensions.""" + +from __future__ import annotations + +import hashlib +import json +import logging +import os +import shutil +import subprocess +import sys +from pathlib import Path + +from packaging.utils import canonicalize_name + +from ..config import CUDAWheelConfig, ExtensionConfig + +logger = logging.getLogger(__name__) + + +def _pyisolate_source_path() -> Path: + """Return the local pyisolate source tree for sealed-worker child installs.""" + return Path(__file__).resolve().parents[2] + + +def _toml_path_string(path: Path) -> str: + """Serialize a local path safely for pixi TOML.""" + return path.as_posix() + + +def _detect_glibc_version() -> str | None: + """Detect the host glibc version for pixi system-requirements.""" + try: + import platform as plat + + ver = plat.libc_ver()[1] + if ver: + return ver + except Exception: + pass + return None + + +def _generate_pixi_toml(config: ExtensionConfig) -> str: + """Generate a pixi.toml manifest from an ExtensionConfig. + + Maps conda_dependencies β†’ [dependencies], pip dependencies β†’ [pypi-dependencies]. 
+ """ + lines: list[str] = [] + + # [workspace] section + name = config.get("module", "extension") + lines.append("[workspace]") + lines.append(f'name = "{name}"') + lines.append('version = "0.1.0"') + + # channels + channels = config.get("conda_channels", []) + if channels: + channels_str = ", ".join(f'"{c}"' for c in channels) + lines.append(f"channels = [{channels_str}]") + + # platforms + platforms = config.get("conda_platforms", []) + if not platforms: + # Auto-detect current platform + if sys.platform == "linux": + platforms = ["linux-64"] + elif sys.platform == "darwin": + import platform as plat + + arch = plat.machine() + platforms = ["osx-arm64"] if arch == "arm64" else ["osx-64"] + elif sys.platform == "win32": + platforms = ["win-64"] + else: + platforms = ["linux-64"] + platforms_str = ", ".join(f'"{p}"' for p in platforms) + lines.append(f"platforms = [{platforms_str}]") + lines.append("") + + # [system-requirements] β€” detect host glibc for correct wheel matching + if sys.platform == "linux": + glibc_version = _detect_glibc_version() + if glibc_version: + lines.append("[system-requirements]") + lines.append(f'libc = {{ family = "glibc", version = "{glibc_version}" }}') + lines.append("") + + # [dependencies] β€” conda packages + conda_deps = config.get("conda_dependencies", []) + lines.append("[dependencies]") + python_version = config.get("conda_python", "*") + lines.append(f'python = "{python_version}"') + if conda_deps: + for dep in conda_deps: + # Parse "numpy>=1.20" β†’ name="numpy", spec=">=1.20" + # Parse "numpy" β†’ name="numpy", spec="*" + name_part, sep, version_part, _extras, _marker = _parse_dep(dep) + if version_part: + lines.append(f'{name_part} = "{version_part}"') + else: + lines.append(f'{name_part} = "*"') + lines.append("") + + # [pypi-dependencies] β€” pip packages + pip_deps = list(config.get("dependencies", [])) + + cuda_wheels_config = config.get("cuda_wheels") + cuda_wheel_packages: set[str] = set() + if cuda_wheels_config: + 
cuda_wheel_packages = { + canonicalize_name(package_name) for package_name in cuda_wheels_config.get("packages", []) + } + + if config.get("package_manager") == "conda": + # [pypi-options] β€” extra index URLs and find-links for local wheels + pypi_options_lines: list[str] = [] + if cuda_wheels_config: + # Support both single index_url and multiple index_urls + index_urls = cuda_wheels_config.get("index_urls", []) + if not index_urls: + single = cuda_wheels_config.get("index_url", "") + if single: + index_urls = [single] + if index_urls: + urls_str = ", ".join(f'"{u}"' for u in index_urls) + pypi_options_lines.append(f"extra-index-urls = [{urls_str}]") + find_links_raw = config.get("find_links", []) + if isinstance(find_links_raw, str): + find_links: list[str] = [find_links_raw] + elif isinstance(find_links_raw, list): + find_links = find_links_raw + else: + find_links = [] + if find_links: + # Resolve relative paths against the extension's module_path + resolved = [] + module_path = config.get("module_path") + for link in find_links: + link_path = Path(link) + if not link_path.is_absolute() and module_path: + link_path = Path(module_path) / link_path + resolved.append(f'{{ path = "{_toml_path_string(link_path)}" }}') + links_str = ", ".join(resolved) + pypi_options_lines.append(f"find-links = [{links_str}]") + # Pixi treats find-links as exclusive; add PyPI so regular deps resolve + has_extra_index = any("extra-index-urls" in line for line in pypi_options_lines) + if not has_extra_index: + pypi_options_lines.insert(0, 'extra-index-urls = ["https://pypi.org/simple/"]') + if pypi_options_lines: + lines.append("[pypi-options]") + lines.extend(pypi_options_lines) + lines.append("") + + lines.append("[pypi-dependencies]") + lines.append(f'pyisolate = {{ path = "{_toml_path_string(_pyisolate_source_path())}" }}') + for dep in pip_deps: + name_part, sep, version_part, extras, marker = _parse_dep(dep) + if cuda_wheel_packages and canonicalize_name(name_part) in 
cuda_wheel_packages: + continue + # Build pixi inline table fields + fields: list[str] = [] + if sep == "@": + fields.append(f'url = "{version_part}"') + else: + ver = version_part if version_part else "*" + fields.append(f'version = "{ver}"') + if extras: + extras_str = ", ".join(f'"{e}"' for e in extras) + fields.append(f"extras = [{extras_str}]") + if marker: + fields.append(f'markers = "{marker}"') + # Emit: simple string form when only version, inline table otherwise + if len(fields) == 1 and fields[0].startswith("version"): + ver = version_part if version_part else "*" + lines.append(f'{name_part} = "{ver}"') + else: + lines.append(f"{name_part} = {{ {', '.join(fields)} }}") + lines.append("") + + return "\n".join(lines) + "\n" + + +def _build_pixi_install_env(env_path: Path) -> dict[str, str]: + """Build a stable subprocess env for pixi installs. + + Harness-backed tests may leave ambient TMPDIR pointing at a deleted + directory. Give pixi its own guaranteed-writable temp root instead of + inheriting whatever process-global temp state happens to exist. + """ + env = os.environ.copy() + pixi_tmp = env_path / ".tmp" + pixi_tmp.mkdir(parents=True, exist_ok=True) + env["TMPDIR"] = str(pixi_tmp) + env["TMP"] = str(pixi_tmp) + env["TEMP"] = str(pixi_tmp) + return env + + +def _parse_dep(dep: str) -> tuple[str, str, str, list[str], str]: + """Parse a PEP 508 dependency string into (name, separator, version_spec, extras, marker). + + Handles extras (trimesh[easy]>=4.0.0), URL deps (pkg @ https://...), + and PEP 508 environment markers (jax>=0.4.30; sys_platform == 'linux'). + Extras are extracted and returned separately for pixi ``extras = [...]`` syntax. + Markers are split off before version parsing so they don't contaminate the + version spec. 
+ + Examples: + "numpy>=1.20" → ("numpy", ">=", ">=1.20", [], "") + "numpy" → ("numpy", "", "", [], "") + "scipy==1.10.0" → ("scipy", "==", "==1.10.0", [], "") + "trimesh[easy]>=4.0.0" → ("trimesh", ">=", ">=4.0.0", ["easy"], "") + "jax[cuda12]>=0.4.30" → ("jax", ">=", ">=0.4.30", ["cuda12"], "") + "jax[cuda12]>=0.4.30; sys_platform == 'linux'" + → ("jax", ">=", ">=0.4.30", ["cuda12"], "sys_platform == 'linux'") + "pkg @ https://example.com/pkg.whl" + → ("pkg", "@", "https://example.com/pkg.whl", [], "") + "pkg @ https://example.com/pkg.whl ; python_version >= '3.12'" + → ("pkg", "@", "https://example.com/pkg.whl", [], "python_version >= '3.12'") + """ + # Split off PEP 508 marker before any other parsing. + # Markers follow a semicolon: "dep_spec ; marker_expr" + marker = "" + if ";" in dep: + dep, _, marker = dep.partition(";") + dep = dep.strip() + marker = marker.strip() + + # Extract extras if present + extras: list[str] = [] + if "[" in dep: + bracket_start = dep.index("[") + bracket_end = dep.index("]", bracket_start) + extras_str = dep[bracket_start + 1 : bracket_end] + extras = [e.strip() for e in extras_str.split(",") if e.strip()] + + # Handle URL deps: "name @ url" + if " @ " in dep: + name_part, _, url = dep.partition(" @ ") + name_part = name_part.strip() + if "[" in name_part: + name_part = name_part[: name_part.index("[")] + return name_part.strip(), "@", url.strip(), extras, marker + + # Strip extras from dep before parsing version + clean_dep = dep + if "[" in dep: + bracket_start = dep.index("[") + bracket_end = dep.index("]", bracket_start) + clean_dep = dep[:bracket_start] + dep[bracket_end + 1 :] + + for sep in (">=", "<=", "==", "!=", "~=", ">", "<"): + idx = clean_dep.find(sep) + if idx > 0: + return clean_dep[:idx].strip(), sep, clean_dep[idx:].strip(), extras, marker + return clean_dep.strip(), "", "", extras, marker + + +def create_conda_env(env_path: Path, config: ExtensionConfig, name: str) -> None: + """Create a
conda/pixi environment for an extension. + + Writes pixi.toml, runs pixi install, and writes a fingerprint lock file. + Skips install if the fingerprint matches a previous run. + """ + env_path.mkdir(parents=True, exist_ok=True) + + pixi_path = shutil.which("pixi") + if not pixi_path: + raise RuntimeError( + "pixi is required for conda backend but not found on PATH. " + "Install: curl -fsSL https://pixi.sh/install.sh | bash" + ) + + cuda_wheels_config = config.get("cuda_wheels") + + # Generate pixi.toml content + toml_content = _generate_pixi_toml(config) + + # Build fingerprint descriptor + descriptor = { + "conda_dependencies": config.get("conda_dependencies", []), + "pip_dependencies": config.get("dependencies", []), + "channels": config.get("conda_channels", []), + "platforms": config.get("conda_platforms", []), + "cuda_wheels": config.get("cuda_wheels"), + "find_links": config.get("find_links", []), + "pixi_toml": toml_content, + } + fingerprint = hashlib.sha256(json.dumps(descriptor, sort_keys=True).encode()).hexdigest() + + # Check fingerprint β€” skip install if unchanged + lock_path = env_path / ".pyisolate_deps.json" + if lock_path.exists(): + try: + cached = json.loads(lock_path.read_text(encoding="utf-8")) + if cached.get("fingerprint") == fingerprint and cached.get("descriptor") == descriptor: + # Verify python exe still exists before skipping + _resolve_pixi_python(env_path) + logger.debug( + "Conda env fingerprint match for %s, skipping pixi install", + name, + ) + return + except Exception as exc: + logger.debug("Conda fingerprint cache read failed: %s", exc) + + # Write pixi.toml + toml_path = env_path / "pixi.toml" + toml_path.write_text(toml_content, encoding="utf-8") + + # Run pixi install + pixi_env = _build_pixi_install_env(env_path) + subprocess.check_call( + [pixi_path, "install", "--manifest-path", str(toml_path)], + env=pixi_env, # noqa: S603 + ) + + # Verify python exists after install + python_exe = _resolve_pixi_python(env_path) + + if 
cuda_wheels_config: + _install_cuda_wheels_into_pixi( + python_exe, + config, + cuda_wheels_config, + name, + ) + + # Install local wheels from find_links directories (post-pixi, --no-deps) + fl_raw = config.get("find_links", []) + if isinstance(fl_raw, str): + fl_list: list[str] = [fl_raw] + elif isinstance(fl_raw, list): + fl_list = fl_raw + else: + fl_list = [] + if fl_list: + _install_local_wheels(python_exe, config, fl_list, name) + + # Write fingerprint + lock_path.write_text( + json.dumps({"fingerprint": fingerprint, "descriptor": descriptor}), + encoding="utf-8", + ) + + +def _parse_conda_python_target(conda_python: str) -> tuple[int, int] | None: + """Parse a conda_python spec like ``"3.12.*"`` into ``(3, 12)``. + + Returns ``None`` for wildcard (``"*"``) or unparseable values, + which causes the CUDA wheel resolver to fall back to host tags. + """ + if not conda_python or conda_python == "*": + return None + import re + + match = re.match(r"(\d+)\.(\d+)", conda_python) + if not match: + return None + return (int(match.group(1)), int(match.group(2))) + + +def _resolve_uv_exe(python_exe: Path) -> str: + """Resolve the uv executable path, preferring the pixi env's bin dir. + + On Linux, pixi places uv alongside python in `bin/`. On Windows, python + is at `.pixi/envs/default/python.exe` (no bin/ subdirectory), so the + sibling path does not exist. Falls back to system PATH via shutil.which. + """ + local_uv = python_exe.parent / "uv" + if local_uv.exists(): + return str(local_uv) + found = shutil.which("uv") + if found: + return found + raise RuntimeError(f"uv is required but not found. Checked pixi env ({local_uv}) and system PATH.") + + +def _install_cuda_wheels_into_pixi( + python_exe: Path, + config: ExtensionConfig, + cuda_wheels_config: CUDAWheelConfig, + name: str, +) -> None: + """Install CUDA wheels into a pixi environment via pip --no-deps. 
+ + Uses the same resolver as the uv path (cuda_wheels.py) but installs + into the pixi env's Python instead of a uv venv. + """ + from .cuda_wheels import resolve_cuda_wheel_requirements + + target_python = _parse_conda_python_target(str(config.get("conda_python", "*"))) + deps = list(config.get("dependencies", [])) + resolved = resolve_cuda_wheel_requirements(deps, cuda_wheels_config, target_python=target_python) + + wheel_urls = [] + for orig, res in zip(deps, resolved, strict=True): + if orig != res: + wheel_urls.append(res) + logger.info("][ CUDA_WHEEL_CONDA ext=%s dep=%s -> %s", name, orig, res) + + if not wheel_urls: + return + + uv_exe = _resolve_uv_exe(python_exe) + pip_cmd = [uv_exe, "pip", "install", "--no-deps", "--python", str(python_exe)] + pip_cmd.extend(wheel_urls) + + logger.info("][ CUDA_WHEEL_CONDA_INSTALL ext=%s count=%d", name, len(wheel_urls)) + subprocess.check_call(pip_cmd) # noqa: S603 + + +def _install_local_wheels( + python_exe: Path, + config: ExtensionConfig, + find_links: list[str], + name: str, +) -> None: + """Install all .whl files from find_links directories with --no-deps. + + Used for pre-built CUDA/extension wheels that are shipped in-repo + and have internal cross-dependencies that pixi's resolver can't handle. 
+ """ + module_path = config.get("module_path") + wheel_files: list[str] = [] + for link_dir in find_links: + link_path = Path(link_dir) + if not link_path.is_absolute() and module_path: + link_path = Path(module_path) / link_path + if link_path.is_dir(): + for whl in sorted(link_path.glob("*.whl")): + wheel_files.append(str(whl)) + + if not wheel_files: + return + + uv_exe = _resolve_uv_exe(python_exe) + pip_cmd = [uv_exe, "pip", "install", "--no-deps", "--python", str(python_exe)] + pip_cmd.extend(wheel_files) + + logger.info("][ LOCAL_WHEELS ext=%s count=%d files=%s", name, len(wheel_files), wheel_files) + subprocess.check_call(pip_cmd) # noqa: S603 + + +def _resolve_pixi_python(env_path: Path) -> Path: + """Resolve the Python interpreter inside a pixi environment. + + Returns the path to the pixi-managed Python, NEVER the host interpreter. + Raises RuntimeError if the Python executable does not exist. + """ + if os.name == "nt": + python_exe = env_path / ".pixi" / "envs" / "default" / "python.exe" + else: + python_exe = env_path / ".pixi" / "envs" / "default" / "bin" / "python" + + if not python_exe.exists(): + raise RuntimeError( + f"Python executable not found at {python_exe}. " + "pixi install may have failed or the environment is corrupted." + ) + + return python_exe diff --git a/pyisolate/_internal/event_bridge.py b/pyisolate/_internal/event_bridge.py new file mode 100644 index 0000000..b6d38f3 --- /dev/null +++ b/pyisolate/_internal/event_bridge.py @@ -0,0 +1,30 @@ +"""Internal event bridge for child-to-host event dispatch.""" + +import logging +from collections.abc import Callable +from typing import Any + +logger = logging.getLogger(__name__) + + +class _EventBridge: + """RPC callee registered on the host to receive events from the child. + + The child calls ``dispatch(name, payload)`` via RPC. The host looks up + the registered handler for ``name`` and invokes it with ``payload``. 
+ """ + + def __init__(self) -> None: + self._handlers: dict[str, Callable[..., Any]] = {} + + def register_handler(self, name: str, handler: Callable[..., Any]) -> None: + self._handlers[name] = handler + + async def dispatch(self, name: str, payload: Any) -> None: + if name not in self._handlers: + raise ValueError(f"No handler registered for event '{name}'") + handler = self._handlers[name] + result = handler(payload) + # Support both sync and async handlers + if hasattr(result, "__await__"): + await result diff --git a/pyisolate/_internal/host.py b/pyisolate/_internal/host.py index f0fb764..ecce6d6 100644 --- a/pyisolate/_internal/host.py +++ b/pyisolate/_internal/host.py @@ -18,9 +18,14 @@ create_venv, install_dependencies, normalize_extension_name, + validate_backend_config, validate_dependency, validate_path_within_root, ) +from .environment_conda import ( + _resolve_pixi_python, + create_conda_env, +) from .rpc_protocol import AsyncRPC from .rpc_transports import JSONSocketTransport from .sandbox import build_bwrap_command @@ -99,6 +104,7 @@ def __init__( self.config = config self.extension_type = extension_type self._cuda_ipc_enabled = False + self._host_rpc_services: list[type[Any]] = [] # Auto-populate APIs from adapter if not already in config if "apis" not in self.config: @@ -117,6 +123,19 @@ def __init__( logger.warning("[Extension] Could not load adapter RPC services: %s", exc) self.config["apis"] = [] + if self.config.get("execution_model") == "sealed_worker": + try: + from .adapter_registry import AdapterRegistry + + adapter = AdapterRegistry.get() + if adapter: + self._host_rpc_services = list(adapter.provide_rpc_services()) + except Exception as exc: + logger.warning("[Extension] Could not load sealed-worker host RPC services: %s", exc) + self._host_rpc_services = [] + else: + self._host_rpc_services = list(self.config["apis"]) + self.mp: Any if self.config["share_torch"]: torch, _ = get_torch_optional() @@ -141,6 +160,23 @@ def __init__( 
self.extension_proxy: T | None = None + def _package_manager(self) -> str: + return self.config.get("package_manager", "uv") + + def _execution_model(self) -> str: + execution_model = self.config.get("execution_model") + if execution_model is not None: + return execution_model + return "sealed_worker" if self._package_manager() == "conda" else "host-coupled" + + def _is_sealed_worker(self) -> bool: + return self._execution_model() == "sealed_worker" + + def _tensor_transport_mode(self) -> str: + if self._is_sealed_worker(): + return "json" + return "shared_memory" + def ensure_process_started(self) -> None: """Start the isolated process if it has not been initialized.""" if self._process_initialized: @@ -194,13 +230,15 @@ def _initialize_process(self) -> None: self.log_listener.start() torch, _ = get_torch_optional() + tensor_transport = self._tensor_transport_mode() if torch is not None: # Register tensor serializer for JSON-RPC only when torch is available. from .serialization_registry import SerializerRegistry - register_tensor_serializer(SerializerRegistry.get_instance()) + register_tensor_serializer(SerializerRegistry.get_instance(), mode=tensor_transport) # Ensure file_system strategy for CPU tensors. - torch.multiprocessing.set_sharing_strategy("file_system") + if tensor_transport == "shared_memory": + torch.multiprocessing.set_sharing_strategy("file_system") elif self.config.get("share_torch", False): raise RuntimeError( "share_torch=True requires PyTorch. Install 'torch' to use tensor-sharing features." 
@@ -208,9 +246,15 @@ def _initialize_process(self) -> None: self.proc = self.__launch() - for api in self.config["apis"]: + for api in self._host_rpc_services: api()._register(self.rpc) + # Register event bridge for childβ†’host event dispatch + from .event_bridge import _EventBridge + + self._event_bridge = _EventBridge() + self.rpc.register_callee(self._event_bridge, "_event_bridge") + self.rpc.run() def get_proxy(self) -> T: @@ -289,8 +333,17 @@ def stop(self) -> None: def __launch(self) -> Any: """Launch the extension in a separate process after venv + deps are ready.""" - create_venv(self.venv_path, self.config) - install_dependencies(self.venv_path, self.config, self.name) + validate_backend_config(self.config) + + if self._package_manager() == "conda": + # Conda backend: force share_cuda_ipc=False (pixi envs don't support IPC) + self.config["share_cuda_ipc"] = False + self._cuda_ipc_enabled = False + create_conda_env(self.venv_path, self.config, self.name) + else: + create_venv(self.venv_path, self.config) + install_dependencies(self.venv_path, self.config, self.name) + return self._launch_with_uds() def _launch_with_uds(self) -> Any: @@ -298,7 +351,9 @@ def _launch_with_uds(self) -> Any: from .socket_utils import ensure_ipc_socket_dir, has_af_unix # Determine Python executable - if os.name == "nt": + if self._package_manager() == "conda": + python_exe = str(_resolve_pixi_python(self.venv_path)) + elif os.name == "nt": python_exe = str(self.venv_path / "Scripts" / "python.exe") else: python_exe = str(self.venv_path / "bin" / "python") @@ -327,6 +382,9 @@ def _launch_with_uds(self) -> Any: # Prepare environment env = os.environ.copy() + pyisolate_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + if "env" in self.config: + env.update(self.config["env"]) # Get sandbox mode (default: REQUIRED) sandbox_mode = self.config.get("sandbox_mode", SandboxMode.REQUIRED) @@ -336,39 +394,37 @@ def _launch_with_uds(self) -> Any: # Check 
platform for sandbox requirement use_sandbox = False + is_sealed_worker = self._is_sealed_worker() + use_sealed_worker_bwrap = is_sealed_worker if sys.platform == "linux": - cap = detect_sandbox_capability() - if sandbox_mode == SandboxMode.DISABLED: - # User explicitly disabled sandbox - emit LOUD warning - logger.warning("=" * 78) - logger.warning("SECURITY WARNING: Sandbox DISABLED for extension '%s'", self.name) - logger.warning( - "The isolated process will have FULL ACCESS to your filesystem, " - "network, and GPU memory. This is STRONGLY DISCOURAGED for any " - "code you did not write yourself." - ) - logger.warning( - "To enable sandbox protection, remove 'sandbox_mode: disabled' " - "from your extension config." - ) - logger.warning("=" * 78) use_sandbox = False - elif not cap.available: - # REQUIRED mode (default) but bwrap unavailable - fail loud - raise RuntimeError( - f"Process isolation on Linux REQUIRES bubblewrap.\n" - f"Error: {cap.remediation}\n" - f"Details: {cap.restriction_model} - {cap.raw_error}\n\n" - f"If you understand the security risks and want to proceed without " - f"sandbox protection, set sandbox_mode='disabled' in your extension config." - ) - else: + elif use_sealed_worker_bwrap: + cap = detect_sandbox_capability() + if not cap.available: + raise RuntimeError( + f"Process isolation on Linux REQUIRES bubblewrap.\n" + f"Error: {cap.remediation}\n" + f"Details: {cap.restriction_model} - {cap.raw_error}\n\n" + f"If you understand the security risks and want to proceed without " + f"sandbox protection, set sandbox_mode='disabled' in your extension config." 
+ ) use_sandbox = True - - # Apply env overrides BEFORE building cmd or bwrap env - if "env" in self.config: - env.update(self.config["env"]) + elif is_sealed_worker: + use_sandbox = False + else: + cap = detect_sandbox_capability() + if not cap.available: + # REQUIRED mode (default) but bwrap unavailable - fail loud + raise RuntimeError( + f"Process isolation on Linux REQUIRES bubblewrap.\n" + f"Error: {cap.remediation}\n" + f"Details: {cap.restriction_model} - {cap.raw_error}\n\n" + f"If you understand the security risks and want to proceed without " + f"sandbox protection, set sandbox_mode='disabled' in your extension config." + ) + else: + use_sandbox = True if use_sandbox: # Build Bwrap Command @@ -376,21 +432,13 @@ def _launch_with_uds(self) -> Any: if isinstance(sandbox_config, bool): sandbox_config = {} - # Detect host site-packages to allow access to Torch/Comfy dependencies - import site + adapter = None + try: + from .adapter_registry import AdapterRegistry - extra_binds = [] - - # Add standard site-packages - site_packages = site.getsitepackages() - for sp in site_packages: - if os.path.exists(sp): - extra_binds.append(sp) - - # Also add user site-packages just in case - user_site = site.getusersitepackages() - if isinstance(user_site, str) and os.path.exists(user_site): - extra_binds.append(user_site) + adapter = AdapterRegistry.get() + except Exception as exc: + logger.warning("[Extension] Could not load adapter for sandbox paths: %s", exc) cmd = build_bwrap_command( python_exe=python_exe, @@ -401,6 +449,9 @@ def _launch_with_uds(self) -> Any: allow_gpu=True, # Default to allowing GPU for ComfyUI nodes restriction_model=cap.restriction_model, env_overrides=self.config.get("env"), + adapter=adapter, + execution_model=self._execution_model(), + sealed_host_ro_paths=cast(list[str] | None, self.config.get("sealed_host_ro_paths")), ) else: # Linux without sandbox (DISABLED mode) @@ -410,6 +461,9 @@ def _launch_with_uds(self) -> Any: 
env["PYISOLATE_EXTENSION"] = self.name env["PYISOLATE_MODULE_PATH"] = self.module_path env["PYISOLATE_ENABLE_CUDA_IPC"] = "1" if self._cuda_ipc_enabled else "0" + if is_sealed_worker: + env["PYTHONPATH"] = pyisolate_root + env["PYTHONNOUSERSITE"] = "1" else: # Non-Linux (Windows/Mac) - Fallback to direct launch @@ -420,13 +474,23 @@ def _launch_with_uds(self) -> Any: env["PYISOLATE_EXTENSION"] = self.name env["PYISOLATE_MODULE_PATH"] = self.module_path env["PYISOLATE_ENABLE_CUDA_IPC"] = "1" if self._cuda_ipc_enabled else "0" + if is_sealed_worker: + env["PYTHONPATH"] = pyisolate_root + env["PYTHONNOUSERSITE"] = "1" + + # Tell the child whether to import torch (controls adapter import chain) + if not self.config.get("share_torch", False): + env["PYISOLATE_IMPORT_TORCH"] = "0" + # Also inject into config env for bwrap path which uses env_overrides + if "env" not in self.config: + self.config["env"] = {} + self.config["env"]["PYISOLATE_IMPORT_TORCH"] = "0" # Launch process - # logger.error(f"[BWRAP-DEBUG] Final subprocess.Popen args: {cmd}") - proc = subprocess.Popen( cmd, env=env, + cwd=self.module_path if is_sealed_worker else None, stdout=None, # Inherit stdout/stderr for now so we see logs stderr=None, close_fds=True, @@ -461,11 +525,26 @@ def accept_connection() -> None: raise RuntimeError(f"Child connection is None for {self.name}") # Setup JSON-RPC + tensor_transport = self._tensor_transport_mode() transport = JSONSocketTransport(client_sock) + if hasattr(transport, "set_tensor_transport_mode"): + transport.set_tensor_transport_mode(tensor_transport) logger.debug("Child connected, sending bootstrap data") # Send bootstrap snapshot = build_extension_snapshot(self.module_path) + if is_sealed_worker: + snapshot["apply_host_sys_path"] = False + snapshot["preferred_root"] = None + snapshot["additional_paths"] = [] + sealed_host_ro_paths = self.config.get("sealed_host_ro_paths") or [] + snapshot["sealed_host_ro_paths"] = list(cast(list[str], sealed_host_ro_paths)) + 
# If RO paths are configured, preserve adapter_ref so the sealed + # child can rehydrate the adapter and register serializers. + # Without RO paths, null out adapter to maintain hermetic boundary. + if not sealed_host_ro_paths: + snapshot["adapter_ref"] = None + snapshot["adapter_name"] = None ext_type_ref = f"{self.extension_type.__module__}.{self.extension_type.__name__}" # Sanitize config for JSON serialization (convert API classes to string refs) @@ -481,6 +560,7 @@ def accept_connection() -> None: "snapshot": snapshot, "config": safe_config, "extension_type_ref": ext_type_ref, + "tensor_transport": tensor_transport, } transport.send(bootstrap_data) @@ -492,3 +572,7 @@ def accept_connection() -> None: def join(self) -> None: """Join the child process, blocking until it exits.""" self.proc.join() + + def register_event_handler(self, name: str, handler: Any) -> None: + """Register a handler for named events emitted by the child process.""" + self._event_bridge.register_handler(name, handler) diff --git a/pyisolate/_internal/model_serialization.py b/pyisolate/_internal/model_serialization.py index 6f7c2ef..0922b1d 100644 --- a/pyisolate/_internal/model_serialization.py +++ b/pyisolate/_internal/model_serialization.py @@ -10,7 +10,6 @@ coupling pyisolate to any specific framework. """ -import contextlib import logging import os import sys @@ -27,34 +26,24 @@ logger = logging.getLogger(__name__) -def serialize_for_isolation(data: Any) -> Any: - """Serialize data for transmission to an isolated process (host side). - - Adapter-registered objects are converted to reference dictionaries so the - isolated process can fetch them lazily. RemoteObjectHandle instances are passed - through to preserve identity without pickling heavyweight objects. 
- """ +def _serialize_for_isolation_impl( + data: Any, + *, + registry: SerializerRegistry, + torch_module: Any, + remote_handle_type: type[Any], +) -> Any: type_name = type(data).__name__ - # Adapter-registered serializers take precedence over built-in handlers - registry = SerializerRegistry.get_instance() - if registry.has_handler(type_name): - serializer = registry.get_serializer(type_name) - if serializer: - return serializer(data) - - # If this object originated as a RemoteObjectHandle, send the original - # handle only when no adapter serializer is available for this type. - # This avoids cross-extension stale handle reuse for serializer-backed - # objects (e.g. CLIP/ModelPatcher/VAE refs). - from .remote_handle import RemoteObjectHandle - handle = getattr(data, "_pyisolate_remote_handle", None) - if isinstance(handle, RemoteObjectHandle): + if isinstance(handle, remote_handle_type): return handle - torch, _ = get_torch_optional() - if torch is not None and isinstance(data, torch.Tensor): + serializer = registry.get_serializer(type_name) + if serializer is not None: + return serializer(data) + + if torch_module is not None and isinstance(data, torch_module.Tensor): if data.is_cuda: if _cuda_ipc_enabled: return data @@ -62,15 +51,50 @@ def serialize_for_isolation(data: Any) -> Any: return data if isinstance(data, dict): - return {k: serialize_for_isolation(v) for k, v in data.items()} + return { + k: _serialize_for_isolation_impl( + v, + registry=registry, + torch_module=torch_module, + remote_handle_type=remote_handle_type, + ) + for k, v in data.items() + } if isinstance(data, (list, tuple)): - result = [serialize_for_isolation(item) for item in data] + result = [ + _serialize_for_isolation_impl( + item, + registry=registry, + torch_module=torch_module, + remote_handle_type=remote_handle_type, + ) + for item in data + ] return type(data)(result) return data +def serialize_for_isolation(data: Any) -> Any: + """Serialize data for transmission to an isolated 
process (host side). + + Adapter-registered objects are converted to reference dictionaries so the + isolated process can fetch them lazily. RemoteObjectHandle instances are passed + through to preserve identity without pickling heavyweight objects. + """ + registry = SerializerRegistry.get_instance() + from .remote_handle import RemoteObjectHandle + + torch, _ = get_torch_optional() + return _serialize_for_isolation_impl( + data, + registry=registry, + torch_module=torch, + remote_handle_type=RemoteObjectHandle, + ) + + async def deserialize_from_isolation(data: Any, extension: Any = None, _nested: bool = False) -> Any: """Deserialize data received from an isolated process (host side). @@ -85,19 +109,13 @@ async def deserialize_from_isolation(data: Any, extension: Any = None, _nested: registry = SerializerRegistry.get_instance() if isinstance(data, RemoteObjectHandle): - if _nested or extension is None: - return data - if registry.has_handler(data.type_name): - return data - try: - resolved = await extension.get_remote_object(data.object_id) - with contextlib.suppress(Exception): - resolved._pyisolate_remote_handle = data - return resolved - except Exception: - return data - - # Check for adapter-registered deserializers by type name (e.g., NodeOutput). + # Handles with a registered handler are returned opaque for the caller + # to process. Handles with NO registered handler are pack-local proxy + # handles β€” keep them opaque so they round-trip back to the originating + # child without a wasteful (and doomed) RPC resolution attempt. + return data + + # Check for adapter-registered deserializers by type name. # Only apply to dicts (serialized form). Objects already deserialized by the # JSON transport layer (e.g., PLY reconstructed via _json_object_hook) are # passed through as-is. 
diff --git a/pyisolate/_internal/perf_trace.py b/pyisolate/_internal/perf_trace.py new file mode 100644 index 0000000..cd5f7a8 --- /dev/null +++ b/pyisolate/_internal/perf_trace.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import json +import os +import threading +import time +from pathlib import Path +from typing import Any + +_TRACE_ENV = "PYISOLATE_TRACE_FILE" +_LOCK = threading.Lock() + + +def trace_path() -> str | None: + path = os.environ.get(_TRACE_ENV) + if not path: + return None + return path + + +def tracing_enabled() -> bool: + return trace_path() is not None + + +def estimate_payload_bytes(payload: Any) -> int: + try: + encoded = json.dumps( + payload, + separators=(",", ":"), + sort_keys=True, + default=str, + ).encode("utf-8") + return len(encoded) + except Exception: + return 0 + + +def record_event(event: dict[str, Any]) -> None: + path = trace_path() + if not path: + return + + enriched = { + "ts_ns": time.time_ns(), + "pid": os.getpid(), + **event, + } + + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with _LOCK, output_path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(enriched, sort_keys=True)) + handle.write("\n") diff --git a/pyisolate/_internal/rpc_serialization.py b/pyisolate/_internal/rpc_serialization.py index 078affb..d57e49d 100644 --- a/pyisolate/_internal/rpc_serialization.py +++ b/pyisolate/_internal/rpc_serialization.py @@ -20,6 +20,8 @@ TypedDict, ) +from .torch_gate import get_torch_optional + if TYPE_CHECKING: # Avoid circular imports for type checking if possible # But here we just need types that might be used in annotations @@ -222,6 +224,7 @@ class RPCPendingRequest(TypedDict): # Removed static _cuda_ipc_env_enabled to allow runtime updates _cuda_ipc_warned = False _ipc_metrics: dict[str, int] = {"send_cuda_ipc": 0, "send_cuda_fallback": 0} +_SERIALIZER_BY_TYPE: dict[type[Any], Any | None] = {} def debugprint(*args: Any, **kwargs: Any) -> None: @@ -234,6 
+237,63 @@ def debugprint(*args: Any, **kwargs: Any) -> None: # --------------------------------------------------------------------------- +def _resolve_serializer_for_type(registry: Any, obj_type: type[Any]) -> Any | None: + cached = _SERIALIZER_BY_TYPE.get(obj_type) + if obj_type in _SERIALIZER_BY_TYPE: + return cached + + serializer = registry.get_serializer(obj_type.__name__) + if serializer is not None: + _SERIALIZER_BY_TYPE[obj_type] = serializer + return serializer + + for base in obj_type.__mro__[1:]: + serializer = registry.get_serializer(base.__name__) + if serializer is not None: + _SERIALIZER_BY_TYPE[obj_type] = serializer + return serializer + + _SERIALIZER_BY_TYPE[obj_type] = None + return None + + +def _prepare_for_rpc_impl( + obj: Any, + *, + registry: Any, + torch_module: Any, +) -> Any: + obj_type = type(obj) + serializer = _resolve_serializer_for_type(registry, obj_type) + if serializer is not None: + return serializer(obj) + + if torch_module is not None and isinstance(obj, torch_module.Tensor): + if obj.is_cuda: + if os.environ.get("PYISOLATE_ENABLE_CUDA_IPC") == "1": + _ipc_metrics["send_cuda_ipc"] += 1 + return obj + _ipc_metrics["send_cuda_fallback"] += 1 + return obj.cpu() + return obj + + if isinstance(obj, dict): + return { + k: _prepare_for_rpc_impl(v, registry=registry, torch_module=torch_module) for k, v in obj.items() + } + + if isinstance(obj, (list, tuple)): + converted = [ + _prepare_for_rpc_impl(item, registry=registry, torch_module=torch_module) for item in obj + ] + return tuple(converted) if isinstance(obj, tuple) else converted + + if isinstance(obj, (str, int, float, bool, type(None), bytes)): + return obj + + return obj + + def _prepare_for_rpc(obj: Any) -> Any: """Recursively prepare objects for RPC transport. @@ -245,53 +305,11 @@ def _prepare_for_rpc(obj: Any) -> Any: Adapter-registered types are serialized via SerializerRegistry. Unpicklable custom containers are downgraded into plain serializable forms. 
""" - type_name = type(obj).__name__ - - # Check for adapter-registered serializers first from .serialization_registry import SerializerRegistry registry = SerializerRegistry.get_instance() - - # Try exact type name first (fast path) - if registry.has_handler(type_name): - serializer = registry.get_serializer(type_name) - if serializer: - return serializer(obj) - - # Check base classes for inheritance support - for base in type(obj).__mro__[1:]: # Skip obj itself - if registry.has_handler(base.__name__): - serializer = registry.get_serializer(base.__name__) - if serializer: - return serializer(obj) - - try: - import torch - - if isinstance(obj, torch.Tensor): - if obj.is_cuda: - # Dynamic check to respect runtime activation in host.py - if os.environ.get("PYISOLATE_ENABLE_CUDA_IPC") == "1": - _ipc_metrics["send_cuda_ipc"] += 1 - return obj # allow CUDA IPC path - _ipc_metrics["send_cuda_fallback"] += 1 - return obj.cpu() - return obj - except ImportError: - pass - - if isinstance(obj, dict): - return {k: _prepare_for_rpc(v) for k, v in obj.items()} - - if isinstance(obj, (list, tuple)): - converted = [_prepare_for_rpc(item) for item in obj] - return tuple(converted) if isinstance(obj, tuple) else converted - - # Primitives pass through - if isinstance(obj, (str, int, float, bool, type(None), bytes)): - return obj - - return obj + torch_module, _ = get_torch_optional() + return _prepare_for_rpc_impl(obj, registry=registry, torch_module=torch_module) def _tensor_to_cuda(obj: Any, device: Any | None = None) -> Any: @@ -324,10 +342,28 @@ def _tensor_to_cuda(obj: Any, device: Any | None = None) -> Any: if isinstance(obj, dict): ref_type = obj.get("__type__") - if ref_type and registry.has_handler(ref_type): - deserializer = registry.get_deserializer(ref_type) - if deserializer: - return deserializer(obj) + if ref_type: + has = registry.has_handler(ref_type) + if has: + deserializer = registry.get_deserializer(ref_type) + if deserializer: + result = deserializer(obj) + 
logger.warning( + "][ DIAG: __type__=%s -> %s", + ref_type, + type(result).__name__, + ) + return result + else: + logger.warning( + "][ DIAG: __type__=%s no deserializer", + ref_type, + ) + else: + logger.warning( + "][ DIAG: __type__=%s no handler", + ref_type, + ) # Handle pyisolate internal container types if obj.get("__pyisolate_attribute_container__") and "data" in obj: diff --git a/pyisolate/_internal/rpc_transports.py b/pyisolate/_internal/rpc_transports.py index 31f09fe..eea332f 100644 --- a/pyisolate/_internal/rpc_transports.py +++ b/pyisolate/_internal/rpc_transports.py @@ -107,10 +107,14 @@ class JSONSocketTransport: Used for ALL Linux isolation modes (sandbox and non-sandbox). """ - def __init__(self, sock: socket.socket) -> None: + def __init__(self, sock: socket.socket, tensor_transport: str = "shared_memory") -> None: self._sock = sock self._lock = threading.Lock() self._recv_lock = threading.Lock() + self._tensor_transport = tensor_transport + + def set_tensor_transport_mode(self, tensor_transport: str) -> None: + self._tensor_transport = tensor_transport def send(self, obj: Any) -> None: """Serialize to JSON with length prefix.""" @@ -149,7 +153,7 @@ def recv(self) -> Any: raise ValueError(f"Message too large: {msg_len} bytes") if msg_len > 100 * 1024 * 1024: # 100MB β€” flag large payloads logger.warning( - "Large RPC message: %.1fMB β€” consider SHM-backed transfer for this type", + "Large RPC message: %.1fMB", msg_len / (1024 * 1024), ) data = self._recvall(msg_len) @@ -248,6 +252,21 @@ def _json_default(self, obj: Any) -> Any: "type_name": obj.type_name, } + # Handle numpy types (scalars and arrays) β€” must be before torch check + try: + import numpy as np + + if isinstance(obj, np.integer): + return int(obj) # type: ignore[arg-type] + if isinstance(obj, np.floating): + return float(obj) # type: ignore[arg-type] + if isinstance(obj, np.bool_): + return bool(obj) # type: ignore[arg-type] + if isinstance(obj, np.ndarray): + return obj.tolist() 
+ except ImportError: + pass + # Handle PyTorch tensors BEFORE __dict__ check (tensors have __dict__ but shouldn't use it) try: import torch @@ -255,7 +274,7 @@ def _json_default(self, obj: Any) -> Any: if isinstance(obj, torch.Tensor): from .tensor_serializer import serialize_tensor - return serialize_tensor(obj) + return serialize_tensor(obj, mode=self._tensor_transport) except ImportError: pass @@ -355,39 +374,36 @@ def _json_object_hook(self, dct: dict) -> Any: return RemoteObjectHandle(dct["object_id"], dct["type_name"]) + # Handle Tensor payloads during JSON parsing + if dct.get("__type__") in {"TensorRef", "TensorValue"}: + try: + from .tensor_serializer import deserialize_tensor + + return deserialize_tensor(dct, mode=self._tensor_transport) + except Exception: + return dct + # Generic Registry Lookup for __type__ if "__type__" in dct: type_name = dct["__type__"] - # Skip TensorRef here as it has special handling below (or generic can handle it if registered) - if type_name != "TensorRef": - from .serialization_registry import SerializerRegistry - - registry = SerializerRegistry.get_instance() - deserializer = registry.get_deserializer(type_name) - if deserializer: - try: - return deserializer(dct) - except Exception as e: - # Log error but don't crash - return dict as fallback - logger.warning(f"Failed to deserialize {type_name}: {e}") - - # Handle TensorRef - deserialize tensors during JSON parsing - if dct.get("__type__") == "TensorRef": from .serialization_registry import SerializerRegistry registry = SerializerRegistry.get_instance() - if registry.has_handler("TensorRef"): + deserializer = registry.get_deserializer(type_name) + if deserializer: + try: + return deserializer(dct) + except Exception as e: + logger.error( + "Deserialization failed for __type__=%s: %s", + type_name, + e, + ) + raise + if type_name == "TensorRef": deserializer = registry.get_deserializer("TensorRef") if deserializer: return deserializer(dct) - # Fallback: direct import if 
registry not yet populated - try: - from .tensor_serializer import deserialize_tensor - - return deserialize_tensor(dct) - except Exception: - pass - return dct # Last resort fallback # Reconstruct Enums if dct.get("__pyisolate_enum__"): diff --git a/pyisolate/_internal/sandbox.py b/pyisolate/_internal/sandbox.py index 8634a19..217116f 100644 --- a/pyisolate/_internal/sandbox.py +++ b/pyisolate/_internal/sandbox.py @@ -55,6 +55,17 @@ "dri", # Direct Rendering Infrastructure ] +FORBIDDEN_WRITABLE_BIND_PATHS: frozenset[str] = frozenset({"/tmp"}) + + +def _linuxbrew_root(path: Path) -> str | None: + """Return the Linuxbrew root for a path under ``.../.linuxbrew/...``.""" + parts = path.parts + if ".linuxbrew" not in parts: + return None + idx = parts.index(".linuxbrew") + return str(Path(*parts[: idx + 1])) + def _validate_adapter_path(path: str) -> bool: """Validate that an adapter-provided path doesn't weaken sandbox security. @@ -90,6 +101,8 @@ def build_bwrap_command( restriction_model: RestrictionModel = RestrictionModel.NONE, env_overrides: dict[str, str] | None = None, adapter: "IsolationAdapter | None" = None, + execution_model: str = "host-coupled", + sealed_host_ro_paths: list[str] | None = None, ) -> list[str]: """Build the bubblewrap command for launching a sandboxed process. 
@@ -119,6 +132,7 @@ def build_bwrap_command( sandbox_config = {} cmd = ["bwrap"] + is_sealed_worker = execution_model == "sealed_worker" # Namespace Isolation Logic # ------------------------- @@ -152,15 +166,52 @@ def build_bwrap_command( # Start with default paths system_paths = list(SANDBOX_SYSTEM_PATHS) - # Query adapter for additional paths (safe access for structural typing) - get_adapter_paths = getattr(adapter, "get_sandbox_system_paths", lambda: None) - adapter_paths = get_adapter_paths() - if adapter_paths: - for path in adapter_paths: - if _validate_adapter_path(path): - system_paths.append(path) - else: - logger.warning("Adapter path '%s' rejected: would weaken sandbox security", path) + # Always explicitly bind the python base prefix to ensure the interpreter + # and its lib dependencies are accessible if installed in non-standard locations (e.g. linuxbrew) + if sys.base_prefix not in system_paths: + system_paths.append(sys.base_prefix) + + # uv-created venv interpreters can be symlinks into a non-standard + # installation prefix (for example Homebrew/Linuxbrew cellar paths). + # Bind that resolved prefix too, or bwrap will fail to exec the venv's + # python binary even though the venv itself is mounted. 
+ python_path = Path(python_exe) + try: + if python_path.is_symlink(): + raw_link_target = os.readlink(python_exe) + raw_link_path = Path(raw_link_target) + if not raw_link_path.is_absolute(): + raw_link_path = (python_path.parent / raw_link_path).resolve() + raw_python_prefix = str(raw_link_path.parent.parent) + if raw_python_prefix not in system_paths: + system_paths.append(raw_python_prefix) + raw_linuxbrew_root = _linuxbrew_root(raw_link_path) + if raw_linuxbrew_root and raw_linuxbrew_root not in system_paths: + system_paths.append(raw_linuxbrew_root) + except OSError: + pass + + try: + resolved_python_path = python_path.resolve() + resolved_python_prefix = str(resolved_python_path.parent.parent) + if resolved_python_prefix not in system_paths: + system_paths.append(resolved_python_prefix) + resolved_linuxbrew_root = _linuxbrew_root(resolved_python_path) + if resolved_linuxbrew_root and resolved_linuxbrew_root not in system_paths: + system_paths.append(resolved_linuxbrew_root) + except OSError: + pass + + if not is_sealed_worker: + # Query adapter for additional paths only for host-coupled execution. 
+ get_adapter_paths = getattr(adapter, "get_sandbox_system_paths", lambda: None) + adapter_paths = get_adapter_paths() + if adapter_paths: + for path in adapter_paths: + if _validate_adapter_path(path): + system_paths.append(path) + else: + logger.warning("Adapter path '%s' rejected: would weaken sandbox security", path) for sys_path in system_paths: if os.path.exists(sys_path): @@ -172,6 +223,12 @@ def build_bwrap_command( # Module path: READ-ONLY cmd.extend(["--ro-bind", str(module_path), str(module_path)]) + if is_sealed_worker and sealed_host_ro_paths: + for ro_path in sealed_host_ro_paths: + normalized = os.path.normpath(ro_path) + if os.path.exists(normalized): + cmd.extend(["--ro-bind", normalized, normalized]) + # GPU passthrough (if enabled) if allow_gpu: cmd.extend(["--ro-bind", "/sys", "/sys"]) @@ -180,11 +237,12 @@ def build_bwrap_command( # Start with default GPU patterns gpu_patterns = list(GPU_PASSTHROUGH_PATTERNS) - # Query adapter for additional GPU patterns (safe access) - get_gpu_patterns = getattr(adapter, "get_sandbox_gpu_patterns", lambda: None) - adapter_gpu_patterns = get_gpu_patterns() - if adapter_gpu_patterns: - gpu_patterns.extend(adapter_gpu_patterns) + if not is_sealed_worker: + # Query adapter for additional GPU patterns only for host-coupled execution. + get_gpu_patterns = getattr(adapter, "get_sandbox_gpu_patterns", lambda: None) + adapter_gpu_patterns = get_gpu_patterns() + if adapter_gpu_patterns: + gpu_patterns.extend(adapter_gpu_patterns) for pattern in gpu_patterns: for dev in dev_path.glob(pattern): @@ -225,40 +283,28 @@ def build_bwrap_command( # MOVED: path bindings moved to end to prevent masking by RO binds - # 1. 
Host venv site-packages: READ-ONLY (for share_torch inheritance via .pth file) - # The child venv has a .pth file pointing to host site-packages for torch sharing - # We find where 'torch' is likely installed (host site-packages) - host_site_packages = Path(sys.executable).parent.parent / "lib" - for sp in host_site_packages.glob("python*/site-packages"): - if sp.exists(): - cmd.extend(["--ro-bind", str(sp), str(sp)]) - break + pyisolate_path: Path | None = None + if not is_sealed_worker: + # 1. Host venv site-packages: READ-ONLY (for share_torch inheritance via .pth file) + # The child venv has a .pth file pointing to host site-packages for torch sharing + # We find where 'torch' is likely installed (host site-packages) + host_site_packages = Path(sys.executable).parent.parent / "lib" + for sp in host_site_packages.glob("python*/site-packages"): + if sp.exists(): + cmd.extend(["--ro-bind", str(sp), str(sp)]) + break - # 2. PyIsolate package path: READ-ONLY (needed for sandbox_client/uds_client) - import pyisolate as pyisolate_pkg + # 2. PyIsolate package path: READ-ONLY (needed for sandbox_client/uds_client) + import pyisolate as pyisolate_pkg - pyisolate_path = Path(pyisolate_pkg.__file__).parent.parent.resolve() - cmd.extend(["--ro-bind", str(pyisolate_path), str(pyisolate_path)]) + pyisolate_path = Path(pyisolate_pkg.__file__).parent.parent.resolve() + cmd.extend(["--ro-bind", str(pyisolate_path), str(pyisolate_path)]) - # 3. 
ComfyUI package path: READ-ONLY (needed for comfy.isolation.adapter) - try: - import comfy # type: ignore[import] - - if hasattr(comfy, "__file__") and comfy.__file__: - comfy_path = Path(comfy.__file__).parent.parent.resolve() - elif hasattr(comfy, "__path__"): - # Namespace package support - comfy_path = Path(list(comfy.__path__)[0]).parent.resolve() - else: - comfy_path = None - - if comfy_path: - cmd.extend(["--ro-bind", str(comfy_path), str(comfy_path)]) - except Exception: - pass + # Application paths (e.g., ComfyUI) are provided by the adapter via + # get_sandbox_system_paths() at lines 161-166. No direct framework imports. - # Shared Memory (REQUIRED for zero-copy tensors via SharedMemory Lease) - if Path("/dev/shm").exists(): + # Shared memory is only needed for host-coupled shared-memory tensor transport. + if not is_sealed_worker and Path("/dev/shm").exists(): cmd.extend(["--bind", "/dev/shm", "/dev/shm"]) # UDS socket directory must be accessible @@ -281,6 +327,10 @@ def build_bwrap_command( # 1. Writable paths from config (user-specified) # Placed here so they can punch holes in RO binds (e.g. ComfyUI/temp inside RO ComfyUI) for path in sandbox_config.get("writable_paths", []): + normalized_path = os.path.normpath(path) + if normalized_path in FORBIDDEN_WRITABLE_BIND_PATHS: + logger.warning("Skipping forbidden writable sandbox path: %s", normalized_path) + continue if os.path.exists(path): cmd.extend(["--bind", path, path]) @@ -295,49 +345,72 @@ def build_bwrap_command( if os.path.exists(src): cmd.extend(["--ro-bind", src, dst]) + if is_sealed_worker: + cmd.append("--clearenv") + # Environment variables cmd.extend(["--setenv", "PYISOLATE_UDS_ADDRESS", uds_address]) cmd.extend(["--setenv", "PYISOLATE_CHILD", "1"]) - - # 4. 
Set PYTHONPATH to include pyisolate package - # This ensures the child can find 'pyisolate' even if not installed in its venv - pyisolate_parent = str(pyisolate_path) - # Start with our explicitly bound package - new_pythonpath_parts = [pyisolate_parent] - - # Check existing PYTHONPATH - existing_pythonpath = os.environ.get("PYTHONPATH", "") - if existing_pythonpath: - new_pythonpath_parts.append(existing_pythonpath) - - cmd.extend(["--setenv", "PYTHONPATH", ":".join(new_pythonpath_parts)]) - - # Inherit select environment variables - # Standard environment - for env_var in ["PATH", "HOME", "LANG", "LC_ALL"]: - if env_var in os.environ: - cmd.extend(["--setenv", env_var, os.environ[env_var]]) - - # CUDA/GPU environment variables (critical for GPU access) - cuda_env_vars = [ - "CUDA_HOME", - "CUDA_PATH", - "CUDA_VISIBLE_DEVICES", - "NVIDIA_VISIBLE_DEVICES", - "LD_LIBRARY_PATH", - "PYTORCH_CUDA_ALLOC_CONF", - "TORCH_CUDA_ARCH_LIST", - "PYISOLATE_ENABLE_CUDA_IPC", - "PYISOLATE_ENABLE_CUDA_IPC", - ] - for env_var in cuda_env_vars: - if env_var in os.environ: - cmd.extend(["--setenv", env_var, os.environ[env_var]]) - - # Coverage / Profiling forwarding - for key, val in os.environ.items(): - if key.startswith(("COV_", "COVERAGE_")): - cmd.extend(["--setenv", key, val]) + # Propagate PYISOLATE_IMPORT_TORCH from env_overrides if set + if env_overrides and env_overrides.get("PYISOLATE_IMPORT_TORCH") == "0": + cmd.extend(["--setenv", "PYISOLATE_IMPORT_TORCH", "0"]) + + if is_sealed_worker: + # Hermetic sealed workers get an explicit env allowlist. 
+ for env_var in ["PATH", "LANG", "LC_ALL", "CUDA_HOME", "CUDA_PATH", "CUDA_VISIBLE_DEVICES"]: + if env_var in os.environ: + cmd.extend(["--setenv", env_var, os.environ[env_var]]) + if "NVIDIA_VISIBLE_DEVICES" in os.environ: + cmd.extend(["--setenv", "NVIDIA_VISIBLE_DEVICES", os.environ["NVIDIA_VISIBLE_DEVICES"]]) + if "LD_LIBRARY_PATH" in os.environ: + cmd.extend(["--setenv", "LD_LIBRARY_PATH", os.environ["LD_LIBRARY_PATH"]]) + if "PYTORCH_CUDA_ALLOC_CONF" in os.environ: + cmd.extend(["--setenv", "PYTORCH_CUDA_ALLOC_CONF", os.environ["PYTORCH_CUDA_ALLOC_CONF"]]) + if "TORCH_CUDA_ARCH_LIST" in os.environ: + cmd.extend(["--setenv", "TORCH_CUDA_ARCH_LIST", os.environ["TORCH_CUDA_ARCH_LIST"]]) + cmd.extend(["--setenv", "HOME", "/tmp"]) + cmd.extend(["--setenv", "TMPDIR", "/tmp"]) + cmd.extend(["--setenv", "PYTHONNOUSERSITE", "1"]) + else: + # Set PYTHONPATH to include pyisolate package + # This ensures the child can find 'pyisolate' even if not installed in its venv + pyisolate_parent = str(pyisolate_path) + # Start with our explicitly bound package + new_pythonpath_parts = [pyisolate_parent] + + # Check existing PYTHONPATH + existing_pythonpath = os.environ.get("PYTHONPATH", "") + if existing_pythonpath: + new_pythonpath_parts.append(existing_pythonpath) + + cmd.extend(["--setenv", "PYTHONPATH", ":".join(new_pythonpath_parts)]) + + # Inherit select environment variables + # Standard environment + for env_var in ["PATH", "HOME", "LANG", "LC_ALL"]: + if env_var in os.environ: + cmd.extend(["--setenv", env_var, os.environ[env_var]]) + + # CUDA/GPU environment variables (critical for GPU access) + cuda_env_vars = [ + "CUDA_HOME", + "CUDA_PATH", + "CUDA_VISIBLE_DEVICES", + "NVIDIA_VISIBLE_DEVICES", + "LD_LIBRARY_PATH", + "PYTORCH_CUDA_ALLOC_CONF", + "TORCH_CUDA_ARCH_LIST", + "PYISOLATE_ENABLE_CUDA_IPC", + "PYISOLATE_ENABLE_CUDA_IPC", + ] + for env_var in cuda_env_vars: + if env_var in os.environ: + cmd.extend(["--setenv", env_var, os.environ[env_var]]) + + # Coverage / 
Profiling forwarding + for key, val in os.environ.items(): + if key.startswith(("COV_", "COVERAGE_")): + cmd.extend(["--setenv", key, val]) # Env overrides from config if env_overrides: diff --git a/pyisolate/_internal/sandbox_detect.py b/pyisolate/_internal/sandbox_detect.py index e7a6595..d865bed 100644 --- a/pyisolate/_internal/sandbox_detect.py +++ b/pyisolate/_internal/sandbox_detect.py @@ -14,6 +14,7 @@ from __future__ import annotations import logging +import os import shutil import subprocess import sys @@ -140,28 +141,28 @@ def _test_bwrap(bwrap_path: str) -> tuple[bool, str]: """ try: # S603: bwrap_path comes from shutil.which(), not user input + cmd = [ + bwrap_path, + "--unshare-user-try", + "--dev", + "/dev", + "--proc", + "/proc", + "--ro-bind", + "/usr", + "/usr", + "--ro-bind", + "/bin", + "/bin", + "--ro-bind", + "/lib", + "/lib", + ] + if os.path.exists("/lib64"): + cmd.extend(["--ro-bind", "/lib64", "/lib64"]) + cmd.append("/usr/bin/true") result = subprocess.run( # noqa: S603 - [ - bwrap_path, - "--unshare-user-try", - "--dev", - "/dev", - "--proc", - "/proc", - "--ro-bind", - "/usr", - "/usr", - "--ro-bind", - "/bin", - "/bin", - "--ro-bind", - "/lib", - "/lib", - "--ro-bind", - "/lib64", - "/lib64", - "/usr/bin/true", - ], + cmd, capture_output=True, timeout=10, ) @@ -182,27 +183,27 @@ def _test_bwrap_degraded(bwrap_path: str) -> tuple[bool, str]: """ try: # S603: bwrap_path comes from shutil.which(), not user input + cmd = [ + bwrap_path, + "--dev", + "/dev", + "--proc", + "/proc", + "--ro-bind", + "/usr", + "/usr", + "--ro-bind", + "/bin", + "/bin", + "--ro-bind", + "/lib", + "/lib", + ] + if os.path.exists("/lib64"): + cmd.extend(["--ro-bind", "/lib64", "/lib64"]) + cmd.append("/usr/bin/true") result = subprocess.run( # noqa: S603 - [ - bwrap_path, - "--dev", - "/dev", - "--proc", - "/proc", - "--ro-bind", - "/usr", - "/usr", - "--ro-bind", - "/bin", - "/bin", - "--ro-bind", - "/lib", - "/lib", - "--ro-bind", - "/lib64", - "/lib64", - 
"/usr/bin/true", - ], + cmd, capture_output=True, timeout=10, ) diff --git a/pyisolate/_internal/tensor_serializer.py b/pyisolate/_internal/tensor_serializer.py index fdf89ee..3fc8b95 100644 --- a/pyisolate/_internal/tensor_serializer.py +++ b/pyisolate/_internal/tensor_serializer.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import Any +from .perf_trace import record_event, tracing_enabled from .torch_gate import require_torch logger = logging.getLogger(__name__) @@ -221,7 +222,13 @@ def _handler(signum: int, _frame: Any) -> None: finally: os._exit(128 + signum) - for sig in (signal.SIGHUP, signal.SIGTERM): + cleanup_signals = [ + getattr(signal, "SIGHUP", None), + getattr(signal, "SIGTERM", None), + ] + for sig in cleanup_signals: + if sig is None: + continue try: signal.signal(sig, _handler) except Exception: @@ -231,12 +238,59 @@ def _handler(signum: int, _frame: Any) -> None: _install_signal_cleanup_handlers() -def serialize_tensor(t: Any) -> dict[str, Any]: - """Serialize a tensor to JSON-compatible format using shared memory.""" +def serialize_tensor(t: Any, mode: str = "shared_memory") -> dict[str, Any]: + """Serialize a tensor for the configured transport mode.""" + started_at = time.perf_counter() + if mode == "json": + payload = _serialize_tensor_json(t) + _record_tensor_trace(t, mode, started_at) + return payload + torch, _ = require_torch("serialize_tensor") if t.is_cuda: - return _serialize_cuda_tensor(t) - return _serialize_cpu_tensor(t) + payload = _serialize_cuda_tensor(t) + _record_tensor_trace(t, mode, started_at) + return payload + payload = _serialize_cpu_tensor(t) + _record_tensor_trace(t, mode, started_at) + return payload + + +def _record_tensor_trace(t: Any, mode: str, started_at: float) -> None: + if not tracing_enabled(): + return + try: + payload_bytes = int(t.numel()) * int(t.element_size()) + except Exception: + payload_bytes = 0 + try: + device = str(t.device) + except Exception: + device = "unknown" + record_event( + { + 
"event_kind": "tensor_transport", + "type_name": "Tensor", + "serialize_ms": (time.perf_counter() - started_at) * 1000.0, + "payload_bytes": payload_bytes, + "tensor_transport_mode": mode, + "device": device, + "process_role": "child" if os.environ.get("PYISOLATE_CHILD") == "1" else "host", + } + ) + + +def _serialize_tensor_json(t: Any) -> dict[str, Any]: + """Serialize a tensor into plain JSON data with no shared-memory side effects.""" + require_torch("JSON tensor serialization") + cpu_tensor = t.detach().cpu() + return { + "__type__": "TensorValue", + "dtype": str(cpu_tensor.dtype), + "tensor_size": list(cpu_tensor.size()), + "requires_grad": bool(getattr(t, "requires_grad", False)), + "data": cpu_tensor.tolist(), + } def _serialize_cpu_tensor(t: Any) -> dict[str, Any]: @@ -367,16 +421,43 @@ def _serialize_cuda_tensor(t: Any) -> dict[str, Any]: } -def deserialize_tensor(data: dict[str, Any]) -> Any: - """Deserialize a tensor from TensorRef format.""" - torch, _ = require_torch("deserialize_tensor") +def deserialize_tensor(data: dict[str, Any], mode: str = "shared_memory") -> Any: + """Deserialize a tensor payload for the configured transport mode.""" + try: + torch, _ = require_torch("deserialize_tensor") + except Exception: + torch = None + # If this is already a tensor (e.g., passed through by shared memory), return as-is - if isinstance(data, torch.Tensor): + if torch is not None and isinstance(data, torch.Tensor): return data - # All formats now use TensorRef + + if data.get("__type__") == "TensorValue" or mode == "json": + return _deserialize_json_tensor(data) + return _deserialize_legacy_tensor(data) +def _deserialize_json_tensor(data: dict[str, Any]) -> Any: + """Deserialize a JSON tensor payload. + + If torch is unavailable in the receiving process, leave the JSON payload intact + so sealed workers can still inspect or echo it without importing torch. 
+ """ + try: + torch, _ = require_torch("JSON tensor deserialization") + except Exception: + return data + + dtype_str = data["dtype"] + dtype = getattr(torch, dtype_str.split(".")[-1]) + tensor = torch.tensor(data["data"], dtype=dtype) + tensor = tensor.reshape(tuple(data["tensor_size"])) + if data.get("requires_grad"): + tensor.requires_grad_(True) + return tensor + + def _convert_lists_to_tuples(obj: Any) -> Any: """Recursively convert lists to tuples (PyTorch requires tuples for size/stride).""" if isinstance(obj, list): @@ -465,15 +546,23 @@ def _deserialize_legacy_tensor(data: dict[str, Any]) -> Any: raise RuntimeError(f"Unsupported device: {device}") -def register_tensor_serializer(registry: Any) -> None: +def register_tensor_serializer(registry: Any, mode: str = "shared_memory") -> None: + def serializer(obj: Any) -> dict[str, Any]: + return serialize_tensor(obj, mode=mode) + + def deserializer(data: dict[str, Any]) -> Any: + return deserialize_tensor(data, mode=mode) + require_torch("register_tensor_serializer") + # Register both "Tensor" (type name) and "torch.Tensor" (full name) just in case - registry.register("Tensor", serialize_tensor, deserialize_tensor) - registry.register("torch.Tensor", serialize_tensor, deserialize_tensor) + registry.register("Tensor", serializer, deserializer) + registry.register("torch.Tensor", serializer, deserializer) + registry.register("TensorValue", None, deserializer) # Also register TensorRef for deserialization - registry.register("TensorRef", None, deserialize_tensor) + registry.register("TensorRef", None, deserializer) # Register TorchReduction for recursive deserialization - registry.register("TorchReduction", None, deserialize_tensor) + registry.register("TorchReduction", None, deserializer) # Register PyTorch atom types for recursive serialization def serialize_dtype(obj: Any) -> str: diff --git a/pyisolate/_internal/uds_client.py b/pyisolate/_internal/uds_client.py index e3446b3..4da9d20 100644 --- 
a/pyisolate/_internal/uds_client.py +++ b/pyisolate/_internal/uds_client.py @@ -33,6 +33,32 @@ logger = logging.getLogger(__name__) +def _resolve_api_classes_from_config(config: ExtensionConfig) -> list[Any]: + if config.get("execution_model") == "sealed_worker": + return [] + + apis = config.get("apis", []) + resolved_apis = [] + + for api_item in apis: + if isinstance(api_item, str): + try: + import importlib + + parts = api_item.rsplit(".", 1) + if len(parts) == 2: + mod = importlib.import_module(parts[0]) + resolved_apis.append(getattr(mod, parts[1])) + else: + logger.warning("Invalid API reference format: %s", api_item) + except Exception as e: + logger.warning("Failed to resolve API %s: %s", api_item, e) + else: + resolved_apis.append(api_item) + + return resolved_apis + + def main() -> None: """Main entry point for isolated child processes.""" @@ -78,6 +104,9 @@ def handle_signal(signum: int, frame: Any) -> None: # Receive bootstrap data from host via JSON bootstrap_data = transport.recv() logger.debug("Received bootstrap data") + tensor_transport = bootstrap_data.get("tensor_transport", "shared_memory") + if hasattr(transport, "set_tensor_transport_mode"): + transport.set_tensor_transport_mode(tensor_transport) # Apply host snapshot to environment snapshot = bootstrap_data.get("snapshot", {}) @@ -120,6 +149,7 @@ def handle_signal(signum: int, frame: Any) -> None: module_path=module_path, extension_type=extension_type, config=config, + tensor_transport=tensor_transport, ) ) @@ -129,6 +159,7 @@ async def _async_uds_entrypoint( module_path: str, extension_type: type[Any], config: ExtensionConfig, + tensor_transport: str, ) -> None: """Async entrypoint for isolated processes using JSON-RPC transport.""" from ..interfaces import IsolationAdapter @@ -147,9 +178,10 @@ async def _async_uds_entrypoint( # Register tensor serializer only when torch is available. 
from .serialization_registry import SerializerRegistry - register_tensor_serializer(SerializerRegistry.get_instance()) + register_tensor_serializer(SerializerRegistry.get_instance(), mode=tensor_transport) # Ensure file_system strategy for CPU tensors. - torch.multiprocessing.set_sharing_strategy("file_system") + if tensor_transport == "shared_memory": + torch.multiprocessing.set_sharing_strategy("file_system") elif config.get("share_torch", False): raise RuntimeError( "share_torch=True requires PyTorch. Install 'torch' to use tensor-sharing features." @@ -189,34 +221,17 @@ async def _async_uds_entrypoint( with context: rpc.register_callee(extension, "extension") - # Register APIs from config - apis = config.get("apis", []) - resolved_apis = [] - - # Resolve string references back to classes if needed - for api_item in apis: - if isinstance(api_item, str): - try: - import importlib - - parts = api_item.rsplit(".", 1) - if len(parts) == 2: - mod = importlib.import_module(parts[0]) - resolved_apis.append(getattr(mod, parts[1])) - else: - logger.warning("Invalid API reference format: %s", api_item) - except Exception as e: - logger.warning("Failed to resolve API %s: %s", api_item, e) - else: - resolved_apis.append(api_item) - - for api in resolved_apis: + for api in _resolve_api_classes_from_config(config): api.use_remote(rpc) if adapter: api_instance = cast(ProxiedSingleton, getattr(api, "instance", api)) logger.debug("Calling handle_api_registration for %s", api_instance.__class__.__name__) adapter.handle_api_registration(api_instance, rpc) + # Let the adapter wire child-side event hooks (e.g., progress bar) + if adapter and hasattr(adapter, "setup_child_event_hooks"): + adapter.setup_child_event_hooks(extension) + # Import and load the extension module import importlib.util diff --git a/pyisolate/config.py b/pyisolate/config.py index 0005dd2..b2505a5 100644 --- a/pyisolate/config.py +++ b/pyisolate/config.py @@ -45,8 +45,12 @@ class SandboxConfig(TypedDict, 
total=False): class CUDAWheelConfig(TypedDict): """Configuration for custom CUDA wheel resolution.""" - index_url: str - """Base URL containing per-package simple index directories.""" + index_url: NotRequired[str] + """Base URL containing per-package simple index directories (single index).""" + + index_urls: NotRequired[list[str]] + """Multiple index URLs for CUDA wheel resolution (used by conda backend + to emit pixi [pypi-options] extra-index-urls).""" packages: list[str] """Canonicalized dependency names that must resolve via the custom index.""" @@ -89,5 +93,23 @@ class ExtensionConfig(TypedDict): env: dict[str, str] """Environment variable overrides for the child process.""" + package_manager: NotRequired[str] + """Backend package manager: 'uv' (default) or 'conda'.""" + + execution_model: NotRequired[str] + """Runtime boundary: 'host-coupled' (default for uv) or 'sealed_worker'.""" + + sealed_host_ro_paths: NotRequired[list[str]] + """Optional sealed-worker-only absolute host paths to mount read-only for imports.""" + + conda_channels: NotRequired[list[str]] + """Conda channels to use (required when package_manager='conda').""" + + conda_dependencies: NotRequired[list[str]] + """Conda-forge dependency specifications.""" + + conda_platforms: NotRequired[list[str]] + """Target platforms for conda environment (defaults to current platform).""" + cuda_wheels: NotRequired[CUDAWheelConfig] """Optional custom CUDA wheel resolution configuration for selected dependencies.""" diff --git a/pyisolate/host.py b/pyisolate/host.py index 8736667..6263e20 100644 --- a/pyisolate/host.py +++ b/pyisolate/host.py @@ -56,6 +56,7 @@ def __init__(self, extension_instance: Extension[T]) -> None: super().__init__() self._extension = extension_instance self._proxy: Any = None + self._pending_event_handlers: list[tuple[str, Any]] = [] @property def proxy(self) -> Any: @@ -66,6 +67,9 @@ def proxy(self) -> Any: if self._proxy is None: if hasattr(self._extension, 
"ensure_process_started"): self._extension.ensure_process_started() + # Re-register event handlers after process restart + for name, handler in self._pending_event_handlers: + self._extension.register_event_handler(name, handler) self._proxy = self._extension.get_proxy() self._initialize_rpc(self._extension.rpc) return self._proxy @@ -75,6 +79,12 @@ def __getattr__(self, item: str) -> Any: return getattr(self._extension, item) return getattr(self.proxy, item) + def register_event_handler(self, name: str, handler: Any) -> None: + """Register a handler for named events from the child process.""" + self._pending_event_handlers.append((name, handler)) + if self._extension._process_initialized: + self._extension.register_event_handler(name, handler) + return cast(T, HostExtension(extension)) def stop_all_extensions(self) -> None: diff --git a/pyisolate/interfaces.py b/pyisolate/interfaces.py index e256c19..f617042 100644 --- a/pyisolate/interfaces.py +++ b/pyisolate/interfaces.py @@ -62,12 +62,23 @@ def setup_child_environment(self, snapshot: dict[str, Any]) -> None: def register_serializers(self, registry: SerializerRegistryProtocol) -> None: """Register custom type serializers for RPC transport.""" + def setup_web_directory(self, module: Any) -> None: + """Detect and populate web directory for a loaded module.""" + def provide_rpc_services(self) -> list[type[ProxiedSingleton]]: """Return ProxiedSingleton classes to expose via RPC.""" def handle_api_registration(self, api: ProxiedSingleton, rpc: AsyncRPC) -> None: """Optional post-registration hook for API-specific setup.""" + def setup_child_event_hooks(self, extension: Any) -> None: + """Wire child-side event hooks (e.g., progress bar) using the extension's emit_event. + + Called once in the child process after the extension is initialized and RPC + is available. The extension's ``emit_event(name, payload)`` method should be + used to forward UI events to host-registered handlers. 
+ """ + def get_sandbox_system_paths(self) -> list[str] | None: """Return additional system paths for sandbox. diff --git a/pyisolate/path_helpers.py b/pyisolate/path_helpers.py index 87fb80e..8cfab64 100644 --- a/pyisolate/path_helpers.py +++ b/pyisolate/path_helpers.py @@ -57,12 +57,18 @@ def build_child_sys_path( host_paths: Sequence[str], extra_paths: Sequence[str], preferred_root: str | None = None, + filtered_subdirs: Sequence[str] | None = None, ) -> list[str]: """Construct ``sys.path`` for an isolated child interpreter. Host paths retain order, an optional preferred root is prepended, and child - venv site-packages are appended while avoiding duplicates and code subdirs - that would shadow imports (e.g., package subfolders like ``utils``). + venv site-packages are appended while avoiding duplicates. When + ``filtered_subdirs`` is provided, those named subdirectories of + ``preferred_root`` are excluded from the reconstructed path. When + ``filtered_subdirs`` is ``None``, no subdirectory filtering is applied. + The caller (typically an :class:`~pyisolate.interfaces.IsolationAdapter`) + is responsible for supplying the appropriate list via + ``get_path_config()["filtered_subdirs"]``. """ def _norm(path: str) -> str: @@ -74,12 +80,9 @@ def _norm(path: str) -> str: ordered_host = list(host_paths) if preferred_root: root_norm = _norm(preferred_root) - code_subdirs = { - os.path.join(root_norm, "comfy"), - os.path.join(root_norm, "app"), - os.path.join(root_norm, "comfy_execution"), - os.path.join(root_norm, "utils"), - } + code_subdirs: set[str] = set() + if filtered_subdirs is not None: + code_subdirs = {os.path.join(root_norm, name) for name in filtered_subdirs} filtered_host = [] for p in ordered_host: p_norm = _norm(p) diff --git a/pyisolate/sealed.py b/pyisolate/sealed.py new file mode 100644 index 0000000..b0a72f2 --- /dev/null +++ b/pyisolate/sealed.py @@ -0,0 +1,195 @@ +"""Child-safe extension wrappers for sealed worker runtimes. 
+ +These wrappers avoid importing host application runtime modules at import time. +They are intended for foreign-interpreter workers such as the conda backend. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import inspect +import logging +import uuid +from types import ModuleType +from typing import Any, cast + +from ._internal.remote_handle import RemoteObjectHandle +from .shared import ExtensionBase + +logger = logging.getLogger(__name__) + + +def _sanitize_for_transport(value: Any) -> Any: + primitives = (str, int, float, bool, type(None)) + if isinstance(value, primitives): + return value + if isinstance(value, dict): + return {k: _sanitize_for_transport(v) for k, v in value.items()} + if isinstance(value, tuple): + return tuple(_sanitize_for_transport(v) for v in value) + if isinstance(value, list): + return [_sanitize_for_transport(v) for v in value] + return str(value) + + +class SealedNodeExtension(ExtensionBase): + """Minimal node wrapper for sealed workers. + + The wrapper supports V1-style ``NODE_CLASS_MAPPINGS`` nodes without importing + ComfyUI runtime modules into the child interpreter. 
+ """ + + def __init__(self) -> None: + super().__init__() + self.node_classes: dict[str, type[Any]] = {} + self.display_names: dict[str, str] = {} + self.node_instances: dict[str, Any] = {} + self.remote_objects: dict[str, Any] = {} + self._module: ModuleType | None = None + + async def on_module_loaded(self, module: ModuleType) -> None: + self._module = module + self.node_classes = getattr(module, "NODE_CLASS_MAPPINGS", {}) or {} + self.display_names = getattr(module, "NODE_DISPLAY_NAME_MAPPINGS", {}) or {} + self.node_instances = {} + + # Web directory handling β€” delegate to adapter + if getattr(module, "WEB_DIRECTORY", None) is not None: + from ._internal.adapter_registry import AdapterRegistry + + adapter = AdapterRegistry.get() + if adapter and hasattr(adapter, "setup_web_directory"): + adapter.setup_web_directory(module) + + async def list_nodes(self) -> dict[str, str]: + return {name: self.display_names.get(name, name) for name in self.node_classes} + + async def get_node_info(self, node_name: str) -> dict[str, Any]: + return await self.get_node_details(node_name) + + async def get_node_details(self, node_name: str) -> dict[str, Any]: + node_cls = self._get_node_class(node_name) + input_types_raw = node_cls.INPUT_TYPES() if hasattr(node_cls, "INPUT_TYPES") else {} + output_is_list = getattr(node_cls, "OUTPUT_IS_LIST", None) + if output_is_list is not None: + output_is_list = tuple(bool(x) for x in output_is_list) + + return { + "input_types": _sanitize_for_transport(input_types_raw), + "return_types": tuple(str(t) for t in getattr(node_cls, "RETURN_TYPES", ())), + "return_names": getattr(node_cls, "RETURN_NAMES", None), + "function": str(getattr(node_cls, "FUNCTION", "execute")), + "category": str(getattr(node_cls, "CATEGORY", "")), + "output_node": bool(getattr(node_cls, "OUTPUT_NODE", False)), + "output_is_list": output_is_list, + "is_v3": False, + } + + async def get_input_types(self, node_name: str) -> dict[str, Any]: + node_cls = 
self._get_node_class(node_name) + if hasattr(node_cls, "INPUT_TYPES"): + return cast(dict[str, Any], node_cls.INPUT_TYPES()) + return {} + + def _wrap_for_transport(self, data: Any) -> Any: + """Wrap non-primitive objects as RemoteObjectHandle for proxy round-trip. + + Objects registered as ``data_type=True`` in the SerializerRegistry are + passed through for inline RPC serialization (e.g., ndarray β†’ TensorValue). + All other non-primitive, non-container objects are wrapped as handles. + """ + if isinstance(data, (str, int, float, bool, type(None))): + return data + + if isinstance(data, dict): + return {k: self._wrap_for_transport(v) for k, v in data.items()} + if isinstance(data, (list, tuple)): + wrapped = [self._wrap_for_transport(item) for item in data] + return type(data)(wrapped) + + # Let data_type serializers handle inline transport (e.g., ndarray, PLY, TRIMESH) + from ._internal.serialization_registry import SerializerRegistry + + registry = SerializerRegistry.get_instance() + type_name = type(data).__name__ + if registry.has_handler(type_name) and registry.is_data_type(type_name): + return data + + object_id = str(uuid.uuid4()) + self.remote_objects[object_id] = data + logger.info( + "][ PROXY_HANDLE_WRAP type=%s id=%s remote_objects_count=%d", + type_name, + object_id[:8], + len(self.remote_objects), + ) + return RemoteObjectHandle(object_id, type_name=type_name) + + def _resolve_handles(self, data: Any) -> Any: + """Resolve incoming RemoteObjectHandle values from ``remote_objects``.""" + if isinstance(data, RemoteObjectHandle): + if data.object_id not in self.remote_objects: + raise KeyError(f"Remote object {data.object_id} not found") + return self.remote_objects[data.object_id] + + if isinstance(data, dict): + return {k: self._resolve_handles(v) for k, v in data.items()} + if isinstance(data, (list, tuple)): + resolved = [self._resolve_handles(item) for item in data] + return type(data)(resolved) + return data + + async def execute_node(self, 
node_name: str, **inputs: Any) -> tuple[Any, ...]: + instance = self._get_node_instance(node_name) + node_cls = self._get_node_class(node_name) + function_name = getattr(node_cls, "FUNCTION", "execute") + if not hasattr(instance, function_name): + raise AttributeError(f"Node {node_name} missing callable '{function_name}'") + + # Resolve any proxy handles arriving as inputs from prior nodes. + inputs = {k: self._resolve_handles(v) for k, v in inputs.items()} + + if getattr(node_cls, "INPUT_IS_LIST", False): + inputs = {k: [v] if not isinstance(v, list) else v for k, v in inputs.items()} + + handler = getattr(instance, function_name) + if inspect.iscoroutinefunction(handler): + result = await handler(**inputs) + else: + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, lambda: handler(**inputs)) + + if not isinstance(result, tuple): + result = (result,) + + # Wrap unregistered objects as proxy handles for transport. + return tuple(self._wrap_for_transport(item) for item in result) + + async def flush_transport_state(self) -> int: + flushed = 0 + _flush_fn: Any = None + with contextlib.suppress(Exception): + from . import flush_tensor_keeper as _flush_fn + if callable(_flush_fn): + flushed = int(_flush_fn()) + # Clear pack-local proxy handles to prevent memory accumulation + # across workflow runs. 
+ if hasattr(self, "remote_objects"): + self.remote_objects.clear() + return flushed + + async def get_remote_object(self, object_id: str) -> Any: + if object_id not in self.remote_objects: + raise KeyError(f"Remote object {object_id} not found") + return self.remote_objects[object_id] + + def _get_node_class(self, node_name: str) -> type[Any]: + if node_name not in self.node_classes: + raise KeyError(f"Node {node_name} not found") + return self.node_classes[node_name] + + def _get_node_instance(self, node_name: str) -> Any: + if node_name not in self.node_instances: + self.node_instances[node_name] = self._get_node_class(node_name)() + return self.node_instances[node_name] diff --git a/pyisolate/shared.py b/pyisolate/shared.py index b384d99..f1b23fb 100644 --- a/pyisolate/shared.py +++ b/pyisolate/shared.py @@ -1,5 +1,6 @@ """Public host/extension shared interfaces for PyIsolate.""" +import json from types import ModuleType from typing import TypeVar, final @@ -37,6 +38,36 @@ def use_remote(self, proxied_singleton: type[ProxiedSingleton]) -> None: """Configure a ProxiedSingleton class to resolve to remote instances.""" proxied_singleton.use_remote(self._rpc) + @final + def emit_event(self, name: str, payload: dict) -> None: + """Emit a named event to the host process via RPC. + + Payload must be JSON-serializable (dicts, lists, strings, numbers, + booleans, None). Non-serializable payloads raise immediately. 
+ """ + json.dumps(payload) # fail-loud if not JSON-serializable + from ._internal.event_bridge import _EventBridge + + caller = self._rpc.create_caller(_EventBridge, "_event_bridge") + import asyncio + + coro = caller.dispatch(name, payload) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop is not None and loop.is_running(): + future = asyncio.run_coroutine_threadsafe(coro, loop) + future.result(timeout=10.0) + else: + # No running loop in this thread β€” use the RPC's event loop + rpc_loop = self._rpc.default_loop + if rpc_loop is not None and rpc_loop.is_running(): + future = asyncio.run_coroutine_threadsafe(coro, rpc_loop) + future.result(timeout=10.0) + else: + asyncio.run(coro) + class ExtensionBase(ExtensionLocal): """Base class for all PyIsolate extensions, providing lifecycle hooks and RPC wiring.""" diff --git a/pyproject.toml b/pyproject.toml index b22170d..18e03af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pyisolate" -version = "0.9.2" +version = "0.10.0" description = "A Python library for dividing execution across multiple virtual environments" readme = "README.md" requires-python = ">=3.10" @@ -126,12 +126,15 @@ disallow_untyped_defs = true [tool.pytest.ini_options] minversion = "6.0" -addopts = "-ra -q --cov=pyisolate --cov-report=html --cov-report=term-missing" +addopts = "-ra -q --cov=pyisolate --cov-report=html --cov-report=term-missing -m 'not network'" testpaths = ["tests"] asyncio_mode = "auto" filterwarnings = [ "ignore:The pynvml package is deprecated:FutureWarning", ] +markers = [ + "network: tests that require network access to external wheel indices", +] [tool.coverage.run] source = ["pyisolate"] diff --git a/tests/conftest.py b/tests/conftest.py index 9243beb..2c7b8df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,24 +5,13 @@ """ import logging -import os import sys -from pathlib import Path from types 
import SimpleNamespace import pytest from pyisolate._internal.singleton_context import singleton_scope -# Add ComfyUI to sys.path BEFORE any tests run -# This is required because pyisolate is now ComfyUI-integrated -COMFYUI_ROOT = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") -if COMFYUI_ROOT not in sys.path: - sys.path.insert(0, COMFYUI_ROOT) - -# Set environment variable so child processes know ComfyUI location -os.environ.setdefault("COMFYUI_ROOT", COMFYUI_ROOT) - @pytest.fixture(autouse=True) def clean_singletons(): diff --git a/tests/fixtures/conda_sealed_node/__init__.py b/tests/fixtures/conda_sealed_node/__init__.py new file mode 100644 index 0000000..5257ee0 --- /dev/null +++ b/tests/fixtures/conda_sealed_node/__init__.py @@ -0,0 +1,81 @@ +"""Simple V1-style node fixture for sealed conda integration tests.""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path +from typing import Any + + +class InspectRuntimeNode: + RETURN_TYPES = ("STRING", "STRING", "STRING") + RETURN_NAMES = ( + "path_dump", + "host_leak_report", + "python_exe", + ) + FUNCTION = "inspect" + CATEGORY = "tests" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def inspect(self) -> tuple[str, str, str]: + path_dump = "\n".join(sys.path) + host_leak_report = f"sys_path_count={len(sys.path)}" + return (path_dump, host_leak_report, sys.executable) + + +class EchoTensorNode: + RETURN_TYPES = ("TENSOR", "BOOLEAN") + RETURN_NAMES = ("tensor", "saw_json_tensor") + FUNCTION = "echo" + CATEGORY = "tests" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 - Comfy node API requires this name + return {"required": {"tensor": ("TENSOR",)}} + + def echo(self, tensor: Any) -> tuple[Any, bool]: + saw_json_tensor = isinstance(tensor, dict) and tensor.get("__type__") == "TensorValue" + return (tensor, saw_json_tensor) + + +class OpenWeatherDatasetNode: + RETURN_TYPES = ("FLOAT", 
"STRING") + RETURN_NAMES = ("sum_value", "grib_path") + FUNCTION = "open_dataset" + CATEGORY = "tests" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 - Comfy node API requires this name + return {"required": {}} + + def open_dataset(self) -> tuple[float, str]: + from boltons import strutils + from packaging.version import Version + + artifact_dir = Path(os.environ["PYISOLATE_ARTIFACT_DIR"]) + artifact_dir.mkdir(parents=True, exist_ok=True) + artifact_path = artifact_dir / "weather_fixture.txt" + label = strutils.slugify("Open Weather Dataset", delim="_") + total = 10.0 + artifact_path.write_text( + f"label={label}\npackaging={Version('1.0')}\nsum={total}\n", + encoding="utf-8", + ) + return (total, str(artifact_path)) + + +NODE_CLASS_MAPPINGS = { + "InspectRuntime": InspectRuntimeNode, + "EchoTensor": EchoTensorNode, + "OpenWeatherDataset": OpenWeatherDatasetNode, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "InspectRuntime": "Inspect Runtime", + "EchoTensor": "Echo Tensor", + "OpenWeatherDataset": "Open Weather Dataset", +} diff --git a/tests/fixtures/conda_sealed_node/pyproject.toml b/tests/fixtures/conda_sealed_node/pyproject.toml new file mode 100644 index 0000000..a077841 --- /dev/null +++ b/tests/fixtures/conda_sealed_node/pyproject.toml @@ -0,0 +1,11 @@ +[project] +name = "conda-sealed-node" +version = "0.1.0" +dependencies = ["packaging"] + +[tool.pyisolate] +can_isolate = true +package_manager = "conda" +share_torch = false +conda_channels = ["conda-forge"] +conda_dependencies = ["boltons"] diff --git a/tests/fixtures/uv_sealed_worker/__init__.py b/tests/fixtures/uv_sealed_worker/__init__.py new file mode 100644 index 0000000..e90afd5 --- /dev/null +++ b/tests/fixtures/uv_sealed_worker/__init__.py @@ -0,0 +1,161 @@ +"""Repo-owned uv sealed worker fixture for pyisolate integration tests.""" + +from __future__ import annotations + +import os +import site +import sys +from pathlib import Path +from typing import Any + + +def _artifact_dir() 
-> Path: + artifact_dir = Path(os.environ["PYISOLATE_ARTIFACT_DIR"]) + artifact_dir.mkdir(parents=True, exist_ok=True) + return artifact_dir + + +class UVSealedRuntimeProbeNode: + RETURN_TYPES = ("STRING", "STRING", "STRING", "BOOLEAN") + RETURN_NAMES = ( + "path_dump", + "boltons_origin", + "report", + "saw_user_site", + ) + FUNCTION = "probe" + CATEGORY = "tests" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def probe(self) -> tuple[str, str, str, bool]: + from boltons import strutils + + artifact_dir = _artifact_dir() + path_dump = "\n".join(sys.path) + user_site = site.getusersitepackages() + saw_user_site = user_site in sys.path + report = f"python={sys.executable}\nuser_site={user_site}\npaths={len(sys.path)}" + + (artifact_dir / "child_bootstrap_paths.txt").write_text(path_dump, encoding="utf-8") + (artifact_dir / "child_dependency_dump.txt").write_text(strutils.__file__, encoding="utf-8") + return ( + path_dump, + strutils.__file__, + report, + saw_user_site, + ) + + +class UVSealedBoltonsSlugifyNode: + RETURN_TYPES = ("STRING", "STRING") + RETURN_NAMES = ("slug", "slug_origin") + FUNCTION = "slug" + CATEGORY = "tests" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"text": ("STRING",)}} + + def slug(self, text: str) -> tuple[str, str]: + from boltons import strutils + + return (strutils.slugify(text, delim="_"), strutils.__file__) + + +class UVSealedFilesystemBarrierNode: + RETURN_TYPES = ("STRING", "BOOLEAN", "BOOLEAN", "BOOLEAN") + RETURN_NAMES = ("report", "outside_blocked", "module_mutation_blocked", "artifact_write_ok") + FUNCTION = "probe" + CATEGORY = "tests" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {}} + + def probe(self) -> tuple[str, bool, bool, bool]: + artifact_dir = _artifact_dir() + fixture_dir = Path(__file__).resolve().parent + outside_probe = ( + Path("/usr/.__outside_probe.txt") + if 
os.name != "nt" + else fixture_dir.parent / ".__outside_probe.txt" + ) + module_probe = fixture_dir / ".__module_probe.txt" + artifact_probe = artifact_dir / "filesystem_barrier_probe.txt" + + outside_blocked = False + module_mutation_blocked = False + artifact_write_ok = False + + try: + outside_probe.write_text("probe", encoding="utf-8") + except Exception: + outside_blocked = True + else: + outside_probe.unlink(missing_ok=True) + + try: + module_probe.write_text("probe", encoding="utf-8") + except Exception: + module_mutation_blocked = True + else: + module_probe.unlink(missing_ok=True) + + artifact_probe.write_text("ok", encoding="utf-8") + artifact_write_ok = artifact_probe.exists() + + report = ( + f"outside_blocked={outside_blocked}\n" + f"module_mutation_blocked={module_mutation_blocked}\n" + f"artifact_write_ok={artifact_write_ok}" + ) + return (report, outside_blocked, module_mutation_blocked, artifact_write_ok) + + +class UVSealedTensorEchoNode: + RETURN_TYPES = ("TENSOR", "BOOLEAN") + RETURN_NAMES = ("tensor", "saw_json_tensor") + FUNCTION = "echo" + CATEGORY = "tests" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"tensor": ("TENSOR",)}} + + def echo(self, tensor: Any) -> tuple[Any, bool]: + saw_json_tensor = isinstance(tensor, dict) and tensor.get("__type__") == "TensorValue" + return (tensor, saw_json_tensor) + + +class UVSealedLatentEchoNode: + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("latent",) + FUNCTION = "echo" + CATEGORY = "tests" + + @classmethod + def INPUT_TYPES(cls) -> dict[str, Any]: # noqa: N802 + return {"required": {"latent": ("LATENT",)}} + + def echo(self, latent: Any) -> tuple[Any]: + return (latent,) + + +NODE_CLASS_MAPPINGS = { + "UVSealedRuntimeProbe": UVSealedRuntimeProbeNode, + "UVSealedBoltonsSlugify": UVSealedBoltonsSlugifyNode, + "UVSealedFilesystemBarrier": UVSealedFilesystemBarrierNode, + "UVSealedTensorEcho": UVSealedTensorEchoNode, + "UVSealedLatentEcho": 
UVSealedLatentEchoNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "UVSealedRuntimeProbe": "UV Sealed Runtime Probe", + "UVSealedBoltonsSlugify": "UV Sealed Boltons Slugify", + "UVSealedFilesystemBarrier": "UV Sealed Filesystem Barrier", + "UVSealedTensorEcho": "UV Sealed Tensor Echo", + "UVSealedLatentEcho": "UV Sealed Latent Echo", +} diff --git a/tests/fixtures/uv_sealed_worker/pyproject.toml b/tests/fixtures/uv_sealed_worker/pyproject.toml new file mode 100644 index 0000000..08dca98 --- /dev/null +++ b/tests/fixtures/uv_sealed_worker/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "uv-sealed-worker" +version = "0.1.0" +dependencies = ["boltons"] + +[tool.pyisolate] +can_isolate = true +package_manager = "uv" +execution_model = "sealed_worker" +share_torch = false diff --git a/tests/harness/host.py b/tests/harness/host.py index f2fde38..0fbe372 100644 --- a/tests/harness/host.py +++ b/tests/harness/host.py @@ -73,6 +73,8 @@ class ReferenceHost: def __init__(self, use_temp_dir: bool = True): self.temp_dir: tempfile.TemporaryDirectory | None = None self.root_dir: Path = Path(os.getcwd()) + self._had_previous_tmpdir = "TMPDIR" in os.environ + self._previous_tmpdir = os.environ.get("TMPDIR") if use_temp_dir: self.temp_dir = tempfile.TemporaryDirectory(prefix="pyisolate_harness_") self.root_dir = Path(self.temp_dir.name) @@ -163,7 +165,7 @@ def load_test_extension( isolated=isolated, dependencies=deps, apis=[], - env={}, + env={"PYISOLATE_SIGNAL_CLEANUP": "1"}, share_torch=share_torch, share_cuda_ipc=share_cuda, sandbox=sandbox_cfg, @@ -188,19 +190,30 @@ async def cleanup(self): # Stop processes for ext in self.extensions: try: + with contextlib.suppress(Exception): + proxy = ext.get_proxy() + await proxy.stop() ext.stop() except Exception as e: cleanup_errors.append(str(e)) if self._adapter_registered: AdapterRegistry.unregister() + self._adapter_registered = False if self.temp_dir: try: self.temp_dir.cleanup() + self.temp_dir = None except Exception as e: 
cleanup_errors.append(f"temp_dir: {e}") + if self._had_previous_tmpdir: + assert self._previous_tmpdir is not None + os.environ["TMPDIR"] = self._previous_tmpdir + else: + os.environ.pop("TMPDIR", None) + if cleanup_errors: pass diff --git a/tests/harness/test_package/__init__.py b/tests/harness/test_package/__init__.py index de88bdb..e30bec3 100644 --- a/tests/harness/test_package/__init__.py +++ b/tests/harness/test_package/__init__.py @@ -3,6 +3,7 @@ import sys from typing import Any +from pyisolate import flush_tensor_keeper from pyisolate.shared import ExtensionBase try: @@ -34,6 +35,15 @@ async def initialize(self) -> None: async def prepare_shutdown(self) -> None: logger.info("[TestPkg] Preparing shutdown.") + async def stop(self) -> None: + try: + flush_tensor_keeper() + if HAS_TORCH and torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.ipc_collect() + finally: + await super().stop() + async def ping(self) -> str: """Basic connectivity check.""" return "pong" diff --git a/tests/integration_v2/test_isolation.py b/tests/integration_v2/test_isolation.py index 257f2cb..c938d84 100644 --- a/tests/integration_v2/test_isolation.py +++ b/tests/integration_v2/test_isolation.py @@ -1,6 +1,7 @@ import os import sys import tempfile +from pathlib import Path import pytest @@ -82,3 +83,26 @@ async def test_module_path_ro(reference_host): write_success = False assert not write_success, "Module path should be mounted Read-Only" + + +@pytest.mark.asyncio +async def test_host_tmp_marker_hidden_from_child(reference_host): + """Verify host /tmp is hidden while child /tmp remains writable.""" + host_marker = Path(tempfile.mkstemp(prefix="pyisolate_host_tmp_", dir="/tmp")[1]) + child_scratch = "/tmp/child_scratch.txt" + + try: + host_marker.write_text("host-only", encoding="utf-8") + + ext = reference_host.load_test_extension("tmp_privacy", isolated=True) + proxy = ext.get_proxy() + + with pytest.raises(Exception, match="No such file or directory"): + await 
proxy.read_file(str(host_marker)) + + assert await proxy.write_file(child_scratch, "child-only") == "ok" + assert await proxy.read_file(child_scratch) == "child-only" + assert not Path(child_scratch).exists(), "Child /tmp scratch leaked into host /tmp" + finally: + if host_marker.exists(): + host_marker.unlink() diff --git a/tests/integration_v2/test_tensors.py b/tests/integration_v2/test_tensors.py index 5f74e3a..13cfb5c 100644 --- a/tests/integration_v2/test_tensors.py +++ b/tests/integration_v2/test_tensors.py @@ -1,5 +1,7 @@ +import gc + import pytest -import torch +import torch # noqa: E402 try: import numpy as np # noqa: F401 @@ -52,6 +54,9 @@ async def test_cuda_allocation(reference_host): assert "device" in info assert "cuda" in info["device"] assert info["allocated_bytes"] >= 10 * 1024 * 1024 + del info + gc.collect() + torch.cuda.synchronize() print("[TEST] CUDA allocation verified.") @@ -76,4 +81,8 @@ async def test_tensor_roundtrip_cuda(reference_host): assert isinstance(result, torch.Tensor) assert result.device.type == "cuda" assert torch.equal(result.cpu(), t.cpu()) + del result + del t + gc.collect() + torch.cuda.synchronize() print("[TEST] CUDA IPC verified.") diff --git a/tests/path_unification/test_path_helpers.py b/tests/path_unification/test_path_helpers.py index de400c1..68549da 100644 --- a/tests/path_unification/test_path_helpers.py +++ b/tests/path_unification/test_path_helpers.py @@ -93,60 +93,66 @@ def test_removes_duplicates(self): assert result.count("/host/lib") == 1 assert result[0] == "/host/lib" - def test_inserts_comfy_root_first_when_missing(self): - """If comfy_root provided and not in host_paths, prepend it.""" + def test_inserts_preferred_root_first_when_missing(self): + """If preferred_root provided and not in host_paths, prepend it.""" host = ["/host/lib1", "/host/lib2"] extras = ["/venv/lib"] - comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") + preferred = "/myapp/root" - result = 
build_child_sys_path(host, extras, comfy_root) + result = build_child_sys_path(host, extras, preferred) - assert result[0] == comfy_root + assert result[0] == preferred assert result[1:3] == host - def test_does_not_duplicate_comfy_root_if_present(self): - """If comfy_root already in host_paths, don't duplicate it.""" - comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") - host = [comfy_root, "/host/lib1"] + def test_does_not_duplicate_preferred_root_if_present(self): + """If preferred_root already in host_paths, don't duplicate it.""" + preferred = "/myapp/root" + host = [preferred, "/host/lib1"] extras = ["/venv/lib"] - result = build_child_sys_path(host, extras, comfy_root) + result = build_child_sys_path(host, extras, preferred) - # Should only appear once - assert result.count(comfy_root) == 1 - assert result[0] == comfy_root + assert result.count(preferred) == 1 + assert result[0] == preferred - def test_removes_comfy_subdirectories_when_root_specified(self): - """Subdirectories of comfy_root should be filtered to avoid shadowing.""" - comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") - host = [f"{comfy_root}/comfy", f"{comfy_root}/app", "/host/lib"] + def test_filtered_subdirs_removes_named_dirs_when_provided(self): + """Subdirectories in filtered_subdirs list should be excluded from output.""" + root = "/myapp/root" + host = [f"{root}/comfy", f"{root}/app", "/host/lib"] extras = ["/venv/lib"] - result = build_child_sys_path(host, extras, comfy_root) + result = build_child_sys_path(host, extras, root, filtered_subdirs=["comfy", "app"]) - # ComfyUI root should be first - assert result[0] == comfy_root - # Subdirectories should be removed - assert f"{comfy_root}/comfy" not in result - assert f"{comfy_root}/app" not in result - # Other paths should remain + assert result[0] == root + assert f"{root}/comfy" not in result + assert f"{root}/app" not in result assert "/host/lib" in result - def 
test_preserves_venv_site_packages_under_comfy_root(self): - """ComfyUI .venv site-packages should NOT be filtered out.""" - comfy_root = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") - venv_site = f"{comfy_root}/.venv/lib/python3.12/site-packages" - host = [f"{comfy_root}/comfy", venv_site, "/host/lib"] + def test_filtered_subdirs_none_preserves_all_subdirectory_paths(self): + """When filtered_subdirs is None, no subdirectory filtering is applied.""" + root = "/myapp/root" + host = [f"{root}/utils", f"{root}/app", "/host/lib"] extras = [] - result = build_child_sys_path(host, extras, comfy_root) + result = build_child_sys_path(host, extras, root, filtered_subdirs=None) - # ComfyUI root should be first - assert result[0] == comfy_root - # .venv site-packages MUST be preserved + assert result[0] == root + assert f"{root}/utils" in result + assert f"{root}/app" in result + assert "/host/lib" in result + + def test_filtered_subdirs_does_not_filter_deep_venv_paths(self): + """Paths deeper than one level under preferred_root are not filtered.""" + root = "/myapp/root" + venv_site = f"{root}/.venv/lib/python3.12/site-packages" + host = [f"{root}/comfy", venv_site, "/host/lib"] + extras = [] + + result = build_child_sys_path(host, extras, root, filtered_subdirs=["comfy"]) + + assert result[0] == root assert venv_site in result - # comfy subdir should be removed - assert f"{comfy_root}/comfy" not in result + assert f"{root}/comfy" not in result def test_appends_extra_paths(self): """Extra paths (isolated venv) should be appended after host paths.""" @@ -227,15 +233,13 @@ def test_round_trip_snapshot_and_rebuild(self): fake_venv = Path(tmpdir) / ".venv" / "lib" / "python3.12" / "site-packages" extras = [str(fake_venv)] - # Build child path + # Build child path using tmpdir as preferred_root (host-agnostic synthetic root) + preferred = tmpdir child_path = build_child_sys_path( snapshot["sys_path"], extras, - preferred_root=os.environ.get("COMFYUI_ROOT") or 
str(Path.home() / "ComfyUI"), + preferred_root=preferred, ) - # Verify structure - check that preferred_root is present - preferred = os.environ.get("COMFYUI_ROOT") or str(Path.home() / "ComfyUI") assert preferred in child_path assert str(fake_venv) in child_path - # Note: child_path may be shorter than snapshot["sys_path"] due to filtering of code subdirs diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index deb4fd2..4b3fac3 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -1,5 +1,6 @@ import json import sys +from importlib import import_module import pytest @@ -83,3 +84,148 @@ def test_bootstrap_missing_adapter(monkeypatch): ) with pytest.raises(ValueError): bootstrap.bootstrap_child() + + +def test_bootstrap_skips_host_sys_path_for_sealed_worker(monkeypatch, tmp_path): + host_only_path = str(tmp_path / "host_only") + snapshot = { + "sys_path": [host_only_path], + "adapter_ref": "fake:FakeAdapter", + "apply_host_sys_path": False, + } + monkeypatch.setenv("PYISOLATE_HOST_SNAPSHOT", json.dumps(snapshot)) + + original_sys_path = list(sys.path) + try: + adapter = bootstrap.bootstrap_child() + updated_sys_path = list(sys.path) + finally: + sys.path[:] = original_sys_path + + assert adapter is None + assert host_only_path not in updated_sys_path + + +def test_bootstrap_sealed_worker_skips_adapter_rehydration(monkeypatch): + snapshot = { + "adapter_ref": "bad.module:BadClass", + "apply_host_sys_path": False, + } + monkeypatch.setenv("PYISOLATE_HOST_SNAPSHOT", json.dumps(snapshot)) + monkeypatch.setattr( + bootstrap, "_rehydrate_adapter", lambda name: (_ for _ in ()).throw(ValueError("should not run")) + ) + + assert bootstrap.bootstrap_child() is None + + +def test_sealed_worker_host_policy_ro_paths_enable_import_without_host_sys_path(monkeypatch, tmp_path): + module_name = "sealed_opt_in_visible_module" + module_root = tmp_path / "opt_in_root" + module_root.mkdir(parents=True, exist_ok=True) + (module_root / 
f"{module_name}.py").write_text("VALUE = 42\n", encoding="utf-8") + + snapshot = { + "sys_path": [], + "apply_host_sys_path": False, + "sealed_host_ro_paths": [str(module_root)], + } + monkeypatch.setenv("PYISOLATE_HOST_SNAPSHOT", json.dumps(snapshot)) + + original_sys_path = list(sys.path) + try: + bootstrap.bootstrap_child() + imported = import_module(module_name) + finally: + sys.path[:] = original_sys_path + sys.modules.pop(module_name, None) + + assert imported.VALUE == 42 + + +def test_sealed_worker_without_opt_in_still_cannot_import_module(monkeypatch, tmp_path): + module_name = "sealed_no_opt_in_hidden_module" + blocked_root = tmp_path / "blocked_root" + blocked_root.mkdir(parents=True, exist_ok=True) + (blocked_root / f"{module_name}.py").write_text("VALUE = 7\n", encoding="utf-8") + + snapshot = { + "sys_path": [str(blocked_root)], + "apply_host_sys_path": False, + } + monkeypatch.setenv("PYISOLATE_HOST_SNAPSHOT", json.dumps(snapshot)) + + original_sys_path = list(sys.path) + try: + bootstrap.bootstrap_child() + with pytest.raises(ModuleNotFoundError): + import_module(module_name) + finally: + sys.path[:] = original_sys_path + sys.modules.pop(module_name, None) + + +def test_sealed_worker_attempts_adapter_rehydration_non_fatal(monkeypatch, tmp_path): + """Sealed workers attempt adapter rehydration for serializer registration. + + If rehydration fails, it is not fatal β€” the sealed worker continues + without an adapter. This changed from the previous behavior where + sealed workers skipped rehydration entirely. 
+ """ + module_name = "sealed_opt_in_without_adapter" + module_root = tmp_path / "adapter_guard_root" + module_root.mkdir(parents=True, exist_ok=True) + (module_root / f"{module_name}.py").write_text("VALUE = 99\n", encoding="utf-8") + + called = {"rehydrate": False} + + def _fail(_name: str): + called["rehydrate"] = True + raise ImportError("adapter module not available in sealed env") + + monkeypatch.setattr(bootstrap, "_rehydrate_adapter", _fail) + snapshot = { + "sys_path": [], + "apply_host_sys_path": False, + "adapter_ref": "fake.module:Adapter", + "sealed_host_ro_paths": [str(module_root)], + } + monkeypatch.setenv("PYISOLATE_HOST_SNAPSHOT", json.dumps(snapshot)) + + original_sys_path = list(sys.path) + try: + adapter = bootstrap.bootstrap_child() + imported = import_module(module_name) + finally: + sys.path[:] = original_sys_path + sys.modules.pop(module_name, None) + + assert adapter is None + assert called["rehydrate"] is True # rehydration was attempted + assert imported.VALUE == 99 + + +def test_sealed_worker_singleton_bootstrap_attempts_adapter_rehydration(monkeypatch): + """Sealed workers attempt adapter rehydration. 
Failure is non-fatal.""" + called = {"rehydrate": False} + + def _fail(_name: str): + called["rehydrate"] = True + raise ImportError("sealed singleton cannot import adapter module") + + monkeypatch.setattr(bootstrap, "_rehydrate_adapter", _fail) + monkeypatch.setenv( + "PYISOLATE_HOST_SNAPSHOT", + json.dumps( + { + "apply_host_sys_path": False, + "adapter_ref": "comfy.isolation.adapter:ComfyIsolationAdapter", + "sealed_host_ro_paths": ["/home/johnj/ComfyUI"], + } + ), + ) + + adapter = bootstrap.bootstrap_child() + + assert adapter is None + assert called["rehydrate"] is True diff --git a/tests/test_bwrap_command.py b/tests/test_bwrap_command.py index d5869f7..cb4f7ff 100644 --- a/tests/test_bwrap_command.py +++ b/tests/test_bwrap_command.py @@ -316,6 +316,104 @@ def test_dev_shm_always_bound_for_tensor_sharing(self) -> None: class TestFilesystemIsolation: """Test filesystem isolation properties.""" + def test_base_prefix_ro_bound(self) -> None: + """Verify sys.base_prefix is added to --ro-bind to support non-standard host Pythons.""" + with patch.object(sys, "base_prefix", "/opt/custom_python"): + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + ) + + cmd_str = " ".join(cmd) + assert "--ro-bind /opt/custom_python /opt/custom_python" in cmd_str + + def test_adapter_system_paths_ro_bound(self) -> None: + """Verify adapter-provided system paths are added to --ro-bind.""" + mock_adapter = MagicMock() + mock_adapter.get_sandbox_system_paths.return_value = ["/app/framework"] + + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + adapter=mock_adapter, + ) + + cmd_str = " ".join(cmd) + assert "--ro-bind 
/app/framework /app/framework" in cmd_str + + def test_adapter_none_no_framework_path(self) -> None: + """Verify adapter=None does NOT produce framework path binding.""" + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + adapter=None, + ) + + cmd_str = " ".join(cmd) + assert "/app/framework" not in cmd_str + + def test_resolved_python_prefix_ro_bound(self) -> None: + """Verify the resolved venv interpreter prefix is also bound read-only.""" + with patch( + "pathlib.Path.resolve", + return_value=Path("/home/linuxbrew/.linuxbrew/Cellar/python@3.13/3.13.12_1/bin/python3.13"), + ): + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + ) + + cmd_str = " ".join(cmd) + assert ( + "--ro-bind /home/linuxbrew/.linuxbrew/Cellar/python@3.13/3.13.12_1 " + "/home/linuxbrew/.linuxbrew/Cellar/python@3.13/3.13.12_1" + ) in cmd_str + + def test_python_symlink_prefix_ro_bound(self) -> None: + """Verify the raw symlink target prefix is also bound read-only.""" + with ( + patch("pathlib.Path.is_symlink", return_value=True), + patch( + "os.readlink", + return_value="/home/linuxbrew/.linuxbrew/opt/python@3.13/bin/python3.13", + ), + patch( + "pathlib.Path.resolve", + return_value=Path("/home/linuxbrew/.linuxbrew/Cellar/python@3.13/3.13.12_1/bin/python3.13"), + ), + ): + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + ) + + cmd_str = " ".join(cmd) + assert ( + "--ro-bind /home/linuxbrew/.linuxbrew/opt/python@3.13 " + 
"/home/linuxbrew/.linuxbrew/opt/python@3.13" + ) in cmd_str + assert "--ro-bind /home/linuxbrew/.linuxbrew /home/linuxbrew/.linuxbrew" in cmd_str + def test_venv_readonly(self) -> None: """Verify venv is bound read-only.""" cmd = _mockbuild_bwrap_command( @@ -365,6 +463,22 @@ def test_tmpfs_tmp(self) -> None: cmd_str = " ".join(cmd) assert "--tmpfs /tmp" in cmd_str + def test_tmpfs_tmp_and_no_host_tmp_bind(self) -> None: + """Verify host /tmp cannot override the private tmpfs.""" + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + sandbox_config={"writable_paths": ["/dev/shm", "/tmp", "/tmp/"]}, + ) + + cmd_str = " ".join(cmd) + assert "--tmpfs /tmp" in cmd_str + assert "--bind /tmp /tmp" not in cmd_str + def test_proc_dev_mounted(self) -> None: """Verify /proc and /dev are mounted.""" cmd = _mockbuild_bwrap_command( @@ -459,3 +573,121 @@ def test_ends_with_python_uds_client(self) -> None: assert cmd[-3] == "/venv/bin/python" assert cmd[-2] == "-m" assert cmd[-1] == "pyisolate._internal.uds_client" + + +class TestSealedWorkerCommand: + """Test strict sealed-worker sandbox policy.""" + + def test_sealed_worker_uses_clearenv(self) -> None: + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + execution_model="sealed_worker", + ) + assert "--clearenv" in cmd + + def test_sealed_worker_does_not_bind_host_site_packages(self) -> None: + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + execution_model="sealed_worker", + ) + 
cmd_str = " ".join(cmd) + assert "site-packages" not in cmd_str + + def test_sealed_worker_does_not_bind_host_pyisolate_or_comfy(self) -> None: + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + execution_model="sealed_worker", + ) + cmd_str = " ".join(cmd) + assert "/fake/pyisolate" not in cmd_str + assert "comfy" not in cmd_str + + def test_sealed_worker_does_not_set_pythonpath(self) -> None: + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + execution_model="sealed_worker", + ) + assert "PYTHONPATH" not in cmd + + def test_sealed_worker_does_not_bind_dev_shm(self) -> None: + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + execution_model="sealed_worker", + ) + for i, arg in enumerate(cmd): + if arg in ("--bind", "--dev-bind") and i + 1 < len(cmd): + assert "/dev/shm" not in cmd[i + 1] + + def test_sealed_worker_host_policy_ro_paths_add_ro_bind_and_keep_clearenv(self) -> None: + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + execution_model="sealed_worker", + sealed_host_ro_paths=["/home/johnj/ComfyUI"], + ) + cmd_str = " ".join(cmd) + assert "--ro-bind /home/johnj/ComfyUI /home/johnj/ComfyUI" in cmd_str + assert "--clearenv" in cmd + assert "PYTHONPATH" not in cmd + assert "--setenv PYTHONPATH " not in cmd_str + + def 
test_sealed_worker_sets_explicit_env_allowlist(self) -> None: + with patch.dict( + "os.environ", + { + "PATH": "/usr/bin:/bin", + "LANG": "C.UTF-8", + "LC_ALL": "C.UTF-8", + "HOME": "/home/johnj", + "PYTHONPATH": "/host/leak", + "SECRET_TOKEN": "leak", + }, + clear=True, + ): + cmd = _mockbuild_bwrap_command( + python_exe="/venv/bin/python", + module_path="/path/to/module", + venv_path="/venv", + uds_address="/run/user/1000/pyisolate/test.sock", + allow_gpu=False, + restriction_model=RestrictionModel.NONE, + execution_model="sealed_worker", + ) + cmd_str = " ".join(cmd) + assert "--setenv PATH /usr/bin:/bin" in cmd_str + assert "--setenv LANG C.UTF-8" in cmd_str + assert "--setenv LC_ALL C.UTF-8" in cmd_str + assert "--setenv HOME /tmp" in cmd_str + assert "--setenv TMPDIR /tmp" in cmd_str + assert "--setenv PYTHONNOUSERSITE 1" in cmd_str + assert "SECRET_TOKEN" not in cmd_str + assert "/host/leak" not in cmd_str diff --git a/tests/test_client_entrypoint_extra.py b/tests/test_client_entrypoint_extra.py index e0963f6..b5554ee 100644 --- a/tests/test_client_entrypoint_extra.py +++ b/tests/test_client_entrypoint_extra.py @@ -1,9 +1,10 @@ +import importlib import sys from types import ModuleType import pytest -from pyisolate._internal import client +from pyisolate._internal import client, uds_client from pyisolate._internal.rpc_protocol import ProxiedSingleton from pyisolate.config import ExtensionConfig from pyisolate.shared import ExtensionBase @@ -204,3 +205,25 @@ async def run_until_stopped(self): assert DummyAPI.last_rpc is not None assert dummy_adapter.calls + + +def test_sealed_worker_skips_api_class_import(monkeypatch): + config = { + "name": "demo-sealed", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "execution_model": "sealed_worker", + "apis": ["forbidden.module.ForbiddenAPI"], + } + + def _forbidden_import(name: str, package: str | None = None): # noqa: ARG001 + if name == "forbidden.module": + raise AssertionError("sealed 
worker must not import API classes from config") + return importlib.import_module(name) + + monkeypatch.setattr(importlib, "import_module", _forbidden_import) + + resolved = uds_client._resolve_api_classes_from_config(config) + + assert resolved == [] diff --git a/tests/test_conda_integration.py b/tests/test_conda_integration.py new file mode 100644 index 0000000..ca84517 --- /dev/null +++ b/tests/test_conda_integration.py @@ -0,0 +1,131 @@ +"""Process-level integration tests for the sealed conda runtime.""" + +from __future__ import annotations + +import contextlib +import gc +import os +import shutil +import site +import uuid +from pathlib import Path + +import pytest +import torch # noqa: E402 + +from pyisolate._internal.host import Extension # noqa: E402 +from pyisolate.sealed import SealedNodeExtension # noqa: E402 + +PIXI_AVAILABLE = shutil.which("pixi") is not None + +pytestmark = pytest.mark.skipif(not PIXI_AVAILABLE, reason="pixi not on PATH") + + +def _expected_pixi_python(env_path: Path) -> Path: + if os.name == "nt": + return env_path / ".pixi" / "envs" / "default" / "python.exe" + return env_path / ".pixi" / "envs" / "default" / "bin" / "python" + + +def _shm_snapshot() -> set[str]: + shm_root = Path("/dev/shm") + if os.name == "nt" or not shm_root.exists(): + return set() + return {path.name for path in shm_root.glob("torch_*")} + + +def _build_conda_config(fixture_path: Path, run_dir: Path) -> dict: + # Inlined from fixtures/conda_sealed_node/pyproject.toml β€” no TOML parser needed. 
+ return { + "name": "conda-sealed-node", + "module_path": str(fixture_path), + "isolated": True, + "dependencies": ["packaging"], + "apis": [], + "env": { + "PYISOLATE_ARTIFACT_DIR": str(run_dir / "artifacts"), + "PYISOLATE_SIGNAL_CLEANUP": "1", + }, + "share_torch": False, + "share_cuda_ipc": False, + "sandbox": {"writable_paths": [str(run_dir / "artifacts")]}, + "package_manager": "conda", + "execution_model": "sealed_worker", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["boltons"], + } + + +@pytest.mark.asyncio +async def test_conda_sealed_runtime_avoids_host_path_leakage() -> None: + fixture_path = Path(__file__).resolve().parent / "fixtures" / "conda_sealed_node" + run_root = Path(__file__).resolve().parent.parent / ".pytest_artifacts" / "conda_integration" + run_dir = run_root / uuid.uuid4().hex + (run_dir / "artifacts").mkdir(parents=True, exist_ok=True) + venv_root = run_dir / "venvs" + venv_root.mkdir(parents=True, exist_ok=True) + config = _build_conda_config(fixture_path, run_dir) + + ext = Extension( + module_path=str(fixture_path), + extension_type=SealedNodeExtension, + config=config, + venv_root_path=str(venv_root), + ) + + try: + ext.ensure_process_started() + proxy = ext.get_proxy() + + nodes = await proxy.list_nodes() + assert nodes == { + "InspectRuntime": "Inspect Runtime", + "EchoTensor": "Echo Tensor", + "OpenWeatherDataset": "Open Weather Dataset", + } + + pixi_manifest = (ext.venv_path / "pixi.toml").read_text(encoding="utf-8") + assert 'boltons = "*"' in pixi_manifest + assert 'packaging = "*"' in pixi_manifest + + ( + path_dump, + host_leak_report, + python_exe, + ) = await proxy.execute_node("InspectRuntime") + path_entries = path_dump.splitlines() + assert str(fixture_path) in path_entries + assert site.getusersitepackages() not in path_dump + assert python_exe == str(_expected_pixi_python(ext.venv_path)) + + weather_sum, grib_path = await proxy.execute_node("OpenWeatherDataset") + assert weather_sum == 
pytest.approx(10.0) + assert Path(grib_path).exists() + + shm_before = _shm_snapshot() + input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + echoed_tensor, saw_json_tensor = await proxy.execute_node("EchoTensor", tensor=input_tensor) + shm_after = _shm_snapshot() + + assert torch.equal(echoed_tensor, input_tensor) + assert saw_json_tensor is True + launch_args = " ".join(str(part) for part in ext.proc.args) + if os.name != "nt": + assert shm_after == shm_before + assert launch_args.startswith("bwrap ") + assert str(_expected_pixi_python(ext.venv_path)) in launch_args + assert "PYTHONPATH" not in launch_args + finally: + with contextlib.suppress(Exception): + if "proxy" in locals(): + await proxy.flush_transport_state() + with contextlib.suppress(UnboundLocalError): + del echoed_tensor + with contextlib.suppress(UnboundLocalError): + del input_tensor + gc.collect() + if torch.cuda.is_available(): + with contextlib.suppress(Exception): + torch.cuda.synchronize() + ext.stop() + shutil.rmtree(run_dir, ignore_errors=True) diff --git a/tests/test_conda_sealed_worker_contract.py b/tests/test_conda_sealed_worker_contract.py new file mode 100644 index 0000000..29fcb30 --- /dev/null +++ b/tests/test_conda_sealed_worker_contract.py @@ -0,0 +1,226 @@ +"""Generic conda/uv sealed-worker contract tests.""" + +from __future__ import annotations + +import importlib +import json +import os +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from pyisolate._internal import bootstrap +from pyisolate._internal.environment_conda import _generate_pixi_toml, _resolve_pixi_python +from pyisolate._internal.host import Extension +from pyisolate._internal.sandbox_detect import RestrictionModel +from pyisolate.shared import ExtensionBase + + +def _make_conda_config(**overrides): + config = { + "name": "contract_ext", + "module": "contract_module", + "dependencies": ["requests>=2.0"], + "share_torch": False, + 
"share_cuda_ipc": False, + "package_manager": "conda", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["numpy>=1.26"], + "conda_platforms": ["linux-64"], + } + config.update(overrides) + return config + + +def _make_extension(config: dict) -> Extension: + ext = Extension.__new__(Extension) + ext.name = config["name"] + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + ext._cuda_ipc_enabled = False + ext._uds_path = None + ext._uds_listener = None + ext._client_sock = None + return ext + + +def _capture_bootstrap_payload(config: dict) -> dict: + ext = _make_extension(config) + + listener = MagicMock() + listener.accept.return_value = (MagicMock(), None) + transport = MagicMock() + proc = MagicMock() + proc.pid = 1234 + + with ( + patch( + "pyisolate._internal.host._resolve_pixi_python", + return_value=Path("/fake/venv/.pixi/envs/default/bin/python"), + ), + patch("pyisolate._internal.host.socket") as mock_socket, + patch( + "pyisolate._internal.host.tempfile.mktemp", return_value="/run/user/1000/pyisolate/ext_test.sock" + ), + patch("pyisolate._internal.host.subprocess.Popen", return_value=proc), + patch( + "pyisolate._internal.host.detect_sandbox_capability", + return_value=MagicMock(available=True, restriction_model=RestrictionModel.NONE), + ), + patch("pyisolate._internal.host.build_bwrap_command", return_value=["bwrap", "--clearenv", "python"]), + patch("pyisolate._internal.host.JSONSocketTransport", return_value=transport), + patch("pyisolate._internal.host.AsyncRPC"), + patch("pyisolate._internal.socket_utils.has_af_unix", return_value=True), + patch( + "pyisolate._internal.socket_utils.ensure_ipc_socket_dir", + return_value=Path("/run/user/1000/pyisolate"), + ), + patch( + "pyisolate._internal.host.build_extension_snapshot", + return_value={"sys_path": ["/host/path"], "apply_host_sys_path": True}, + ), + patch("os.chmod"), + patch("sys.platform", "linux"), + ): + 
mock_socket.socket.return_value = listener + mock_socket.AF_UNIX = 1 + mock_socket.SOCK_STREAM = 1 + ext._launch_with_uds() + + return transport.send.call_args[0][0] + + +def test_conda_dependency_split(): + config = _make_conda_config( + conda_dependencies=["numpy>=1.26", "scipy"], + dependencies=["requests>=2.0", "pandas"], + ) + + toml_text = _generate_pixi_toml(config) + + assert "[dependencies]" in toml_text + assert 'numpy = ">=1.26"' in toml_text + assert 'scipy = "*"' in toml_text + assert "[pypi-dependencies]" in toml_text + assert 'requests = ">=2.0"' in toml_text + assert 'pandas = "*"' in toml_text + + +def test_conda_channels_platforms_pass_through(): + config = _make_conda_config( + conda_channels=["conda-forge", "nvidia"], + conda_platforms=["linux-64", "win-64"], + ) + + toml_text = _generate_pixi_toml(config) + + assert 'channels = ["conda-forge", "nvidia"]' in toml_text + assert 'platforms = ["linux-64", "win-64"]' in toml_text + + +def test_uv_defaults_unchanged(): + config = _make_conda_config(package_manager="uv") + ext = _make_extension(config) + + assert ext._execution_model() == "host-coupled" + assert ext._tensor_transport_mode() == "shared_memory" + + +def test_no_host_fallback(tmp_path: Path): + env_path = tmp_path / "conda_env" + if os.name == "nt": + python_path = env_path / ".pixi" / "envs" / "default" / "python.exe" + else: + python_path = env_path / ".pixi" / "envs" / "default" / "bin" / "python" + python_path.parent.mkdir(parents=True, exist_ok=True) + python_path.touch() + + resolved = _resolve_pixi_python(env_path) + + assert str(resolved) != sys.executable + assert ".pixi" in str(resolved) + + +def test_no_host_sys_path(): + payload = _capture_bootstrap_payload( + _make_conda_config(package_manager="conda", execution_model="sealed_worker") + ) + + snapshot = payload["snapshot"] + assert snapshot["apply_host_sys_path"] is False + assert snapshot["additional_paths"] == [] + assert snapshot["preferred_root"] is None + + +def 
test_no_extension_wrapper_import(): + payload = _capture_bootstrap_payload( + _make_conda_config(package_manager="conda", execution_model="sealed_worker") + ) + + snapshot = payload["snapshot"] + assert snapshot["adapter_ref"] is None + assert snapshot["adapter_name"] is None + assert "extension_wrapper" not in str(payload) + + +def test_sealed_worker_host_policy_ro_paths_default_block_and_opt_in_allow( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + payload_default = _capture_bootstrap_payload( + _make_conda_config(package_manager="conda", execution_model="sealed_worker") + ) + payload_opt_in = _capture_bootstrap_payload( + _make_conda_config( + package_manager="conda", + execution_model="sealed_worker", + sealed_host_ro_paths=["/opt/example/app"], + ) + ) + + assert payload_default["snapshot"].get("sealed_host_ro_paths", []) == [] + assert payload_opt_in["snapshot"]["sealed_host_ro_paths"] == ["/opt/example/app"] + + app_framework_root = tmp_path / "app_framework_root" + app_api_dir = app_framework_root / "app_framework_api" + app_api_dir.mkdir(parents=True, exist_ok=True) + (app_api_dir / "__init__.py").write_text("", encoding="utf-8") + (app_api_dir / "latest.py").write_text("MARKER = 'ok'\n", encoding="utf-8") + + module_name = "app_framework_api.latest" + original_sys_path = list(sys.path) + try: + monkeypatch.setenv( + "PYISOLATE_HOST_SNAPSHOT", + json.dumps( + { + "sys_path": [str(app_framework_root)], + "apply_host_sys_path": False, + } + ), + ) + bootstrap.bootstrap_child() + with pytest.raises(ModuleNotFoundError): + importlib.import_module(module_name) + + monkeypatch.setenv( + "PYISOLATE_HOST_SNAPSHOT", + json.dumps( + { + "sys_path": [str(app_framework_root)], + "apply_host_sys_path": False, + "sealed_host_ro_paths": [str(app_framework_root)], + } + ), + ) + bootstrap.bootstrap_child() + imported = importlib.import_module(module_name) + finally: + sys.path[:] = original_sys_path + sys.modules.pop("app_framework_api.latest", None) + 
sys.modules.pop("app_framework_api", None) + + assert imported.MARKER == "ok" diff --git a/tests/test_config_conda.py b/tests/test_config_conda.py new file mode 100644 index 0000000..9979314 --- /dev/null +++ b/tests/test_config_conda.py @@ -0,0 +1,117 @@ +"""Tests for conda backend configuration and validation.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from pyisolate._internal.environment import validate_backend_config + + +def _make_config(**overrides): + """Build a minimal ExtensionConfig dict with conda defaults.""" + base = { + "name": "test_ext", + "module_path": "/fake/path", + "isolated": True, + "dependencies": [], + "apis": [], + "share_torch": False, + "share_cuda_ipc": False, + "sandbox": {}, + "sandbox_mode": "disabled", + "env": {}, + "package_manager": "uv", + } + base.update(overrides) + return base + + +class TestDefaultPackageManager: + def test_default_package_manager_is_uv(self): + """Config without package_manager should default to 'uv' and pass validation.""" + config = _make_config() + del config["package_manager"] + # Should not raise β€” uv is the default + validate_backend_config(config) + + +class TestCondaShareTorchRaises: + def test_conda_share_torch_raises(self): + """conda + share_torch=True must raise ValueError.""" + config = _make_config( + package_manager="conda", + share_torch=True, + conda_channels=["conda-forge"], + ) + with pytest.raises(ValueError, match="share_torch=False"): + validate_backend_config(config) + + +class TestCondaCudaWheelsAllowed: + @patch("shutil.which", return_value="/usr/bin/pixi") + def test_conda_cuda_wheels_allowed(self, _mock_which): + """conda + cuda_wheels is valid β€” pixi resolves via [pypi-options].""" + config = _make_config( + package_manager="conda", + conda_channels=["conda-forge"], + cuda_wheels=["cu121"], + ) + # Should not raise β€” conda supports cuda_wheels via extra-index-urls + validate_backend_config(config) + + +class 
TestCondaMissingChannelsRaises: + def test_conda_missing_channels_raises(self): + """conda + empty/missing conda_channels must raise ValueError.""" + config = _make_config( + package_manager="conda", + ) + with pytest.raises(ValueError, match="conda_channels"): + validate_backend_config(config) + + def test_conda_empty_channels_raises(self): + """conda + empty conda_channels list must raise ValueError.""" + config = _make_config( + package_manager="conda", + conda_channels=[], + ) + with pytest.raises(ValueError, match="conda_channels"): + validate_backend_config(config) + + +class TestCondaMissingPixiRaises: + @patch("shutil.which", return_value=None) + def test_conda_missing_pixi_raises(self, mock_which): + """conda + pixi not on PATH must raise ValueError.""" + config = _make_config( + package_manager="conda", + conda_channels=["conda-forge"], + ) + with pytest.raises(ValueError, match="pixi is required"): + validate_backend_config(config) + + +class TestCondaValidConfigPasses: + @patch("shutil.which", return_value="/usr/bin/pixi") + def test_conda_valid_config_passes(self, mock_which): + """Valid conda config must pass validation without error.""" + config = _make_config( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["numpy"], + ) + # Should not raise + validate_backend_config(config) + + @patch("shutil.which", return_value="/usr/bin/pixi") + def test_conda_with_platforms_passes(self, mock_which): + """Valid conda config with platforms must pass.""" + config = _make_config( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["numpy"], + conda_platforms=["linux-64"], + ) + validate_backend_config(config) diff --git a/tests/test_config_sealed_worker.py b/tests/test_config_sealed_worker.py new file mode 100644 index 0000000..287c632 --- /dev/null +++ b/tests/test_config_sealed_worker.py @@ -0,0 +1,134 @@ +"""Tests for execution_model validation and backward-compatible defaults.""" + +from 
__future__ import annotations + +from unittest.mock import patch + +import pytest + +from pyisolate._internal.environment import validate_backend_config + + +def _make_config(**overrides): + base = { + "name": "test_ext", + "module_path": "/fake/path", + "isolated": True, + "dependencies": [], + "apis": [], + "share_torch": False, + "share_cuda_ipc": False, + "sandbox": {}, + "sandbox_mode": "disabled", + "env": {}, + "package_manager": "uv", + } + base.update(overrides) + return base + + +def test_uv_defaults_to_host_coupled() -> None: + config = _make_config() + validate_backend_config(config) + + +def test_uv_explicit_sealed_worker_passes() -> None: + config = _make_config(execution_model="sealed_worker") + validate_backend_config(config) + + +def test_sealed_worker_rejects_share_torch_true() -> None: + config = _make_config(execution_model="sealed_worker", share_torch=True) + with pytest.raises(ValueError, match="sealed_worker execution_model requires share_torch=False"): + validate_backend_config(config) + + +@pytest.mark.parametrize( + ("package_manager", "execution_model"), + [ + ("uv", "host-coupled"), + ("uv", "sealed_worker"), + ("conda", "sealed_worker"), + ], +) +@patch("shutil.which", return_value="/usr/bin/pixi") +def test_rejects_cuda_ipc_without_share_torch(mock_which, package_manager: str, execution_model: str) -> None: + config = _make_config( + package_manager=package_manager, + execution_model=execution_model, + share_torch=False, + share_cuda_ipc=True, + conda_channels=["conda-forge"] if package_manager == "conda" else None, + conda_dependencies=["numpy"] if package_manager == "conda" else None, + ) + with pytest.raises(ValueError, match="share_cuda_ipc=True requires share_torch=True"): + validate_backend_config(config) + + +@patch("shutil.which", return_value="/usr/bin/pixi") +def test_accepts_valid_mode_matrix(mock_which) -> None: + valid_configs = [ + _make_config(execution_model="host-coupled", share_torch=True, share_cuda_ipc=True), + 
_make_config(execution_model="host-coupled", share_torch=True, share_cuda_ipc=False), + _make_config(execution_model="host-coupled", share_torch=False, share_cuda_ipc=False), + _make_config( + package_manager="conda", + execution_model="sealed_worker", + share_torch=False, + share_cuda_ipc=False, + conda_channels=["conda-forge"], + conda_dependencies=["numpy"], + ), + ] + + for config in valid_configs: + validate_backend_config(config) + + +def test_sealed_host_ro_paths_defaults_off_and_validation() -> None: + config = _make_config(execution_model="sealed_worker") + validate_backend_config(config) + assert config.get("sealed_host_ro_paths") is None + + valid = _make_config( + execution_model="sealed_worker", + sealed_host_ro_paths=["/home/johnj/ComfyUI"], + ) + validate_backend_config(valid) + + wrong_mode = _make_config( + execution_model="host-coupled", + sealed_host_ro_paths=["/home/johnj/ComfyUI"], + ) + with pytest.raises(ValueError, match="sealed_host_ro_paths requires execution_model='sealed_worker'"): + validate_backend_config(wrong_mode) + + non_list = _make_config(execution_model="sealed_worker", sealed_host_ro_paths="/home/johnj/ComfyUI") + with pytest.raises(ValueError, match="sealed_host_ro_paths must be a list of absolute paths"): + validate_backend_config(non_list) + + relative = _make_config(execution_model="sealed_worker", sealed_host_ro_paths=["relative/path"]) + with pytest.raises(ValueError, match="sealed_host_ro_paths entries must be absolute paths"): + validate_backend_config(relative) + + +@patch("shutil.which", return_value="/usr/bin/pixi") +def test_conda_defaults_to_sealed_worker(mock_which) -> None: + config = _make_config( + package_manager="conda", + conda_channels=["conda-forge"], + conda_dependencies=["numpy"], + ) + validate_backend_config(config) + + +@patch("shutil.which", return_value="/usr/bin/pixi") +def test_conda_rejects_host_coupled_execution_model(mock_which) -> None: + config = _make_config( + package_manager="conda", + 
execution_model="host-coupled", + conda_channels=["conda-forge"], + conda_dependencies=["numpy"], + ) + with pytest.raises(ValueError, match="conda backend requires execution_model='sealed_worker'"): + validate_backend_config(config) diff --git a/tests/test_cuda_wheels.py b/tests/test_cuda_wheels.py index 40d78dc..10ad7db 100644 --- a/tests/test_cuda_wheels.py +++ b/tests/test_cuda_wheels.py @@ -15,8 +15,10 @@ from pyisolate._internal import environment from pyisolate._internal.cuda_wheels import ( CUDAWheelResolutionError, + _normalize_cuda_wheel_config, get_cuda_wheel_runtime, resolve_cuda_wheel_requirements, + resolve_cuda_wheel_url, ) @@ -45,7 +47,7 @@ def test_resolve_cuda_wheel_requirement_to_direct_url(monkeypatch): wheel = _wheel_filename("flash_attn", "1.1.0+cu128torch28") page_url = "https://example.invalid/cuda-wheels/flash-attn/" - monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) monkeypatch.setattr( "pyisolate._internal.cuda_wheels._fetch_index_html", lambda url: _simple_index_html(wheel) if url == page_url else None, @@ -68,7 +70,7 @@ def test_resolve_cuda_wheel_requirement_supports_underscore_index(monkeypatch): wheel = _wheel_filename("torch_generic_nms", "0.2.0+cu128torch28") page_url = "https://example.invalid/cuda-wheels/torch_generic_nms/" - monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) monkeypatch.setattr( "pyisolate._internal.cuda_wheels._fetch_index_html", lambda url: _simple_index_html(wheel) if url == page_url else None, @@ -92,7 +94,7 @@ def test_resolve_cuda_wheel_requirement_supports_percent_encoded_links(monkeypat encoded_wheel = wheel.replace("+", "%2B") page_url = "https://example.invalid/cuda-wheels/torch-generic-nms/" - 
monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) monkeypatch.setattr( "pyisolate._internal.cuda_wheels._fetch_index_html", lambda url: _simple_index_html(encoded_wheel) if url == page_url else None, @@ -115,7 +117,7 @@ def test_resolve_cuda_wheel_requirement_honors_package_map(monkeypatch): wheel = _wheel_filename("flash_attn", "1.2.0+cu128torch28") page_url = "https://example.invalid/cuda-wheels/flash_attn_special/" - monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) monkeypatch.setattr( "pyisolate._internal.cuda_wheels._fetch_index_html", lambda url: _simple_index_html(wheel) if url == page_url else None, @@ -141,7 +143,7 @@ def test_resolve_cuda_wheel_requirement_picks_highest_matching_version(monkeypat out_of_range = _wheel_filename("flash_attn", "2.0.0+cu128torch28") page_url = "https://example.invalid/cuda-wheels/flash-attn/" - monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) monkeypatch.setattr( "pyisolate._internal.cuda_wheels._fetch_index_html", lambda url: ( @@ -182,7 +184,7 @@ def test_resolve_cuda_wheel_requirement_prefers_better_supported_tag(monkeypatch preferred = f"torch_generic_nms-0.1+cu128torch28-{ml_tag.interpreter}-{ml_tag.abi}-{ml_tag.platform}.whl" fallback = f"torch_generic_nms-0.1+cu128torch28-{lx_tag.interpreter}-{lx_tag.abi}-{lx_tag.platform}.whl" - monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) monkeypatch.setattr( 
"pyisolate._internal.cuda_wheels._fetch_index_html", lambda url: ( @@ -211,7 +213,7 @@ def test_resolve_cuda_wheel_requirement_raises_when_no_match(monkeypatch): wheel = _wheel_filename("flash_attn", "1.1.0+cu127torch28") page_url = "https://example.invalid/cuda-wheels/flash-attn/" - monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) monkeypatch.setattr( "pyisolate._internal.cuda_wheels._fetch_index_html", lambda url: _simple_index_html(wheel) if url == page_url else None, @@ -315,3 +317,251 @@ def __exit__(self, exc_type, exc, tb): environment.install_dependencies(venv_path, config, "demo") assert len(popen_calls) == 2 + + +# ── target_python parameter tests ───────────────────────────────────── + + +def _wheel_filename_for_cpython(distribution: str, version: str, cpython: str) -> str: + """Build a wheel filename with a specific cpython tag (e.g. 
'cp312').""" + tag = next(iter(sys_tags())) + return f"{distribution}-{version}-{cpython}-{cpython}-{tag.platform}.whl" + + +def test_resolve_cuda_wheel_url_accepts_target_python_parameter(monkeypatch): + """AC-1: resolve_cuda_wheel_url accepts target_python=(3, 12) without TypeError.""" + runtime = _runtime() + wheel_312 = _wheel_filename_for_cpython("flash_attn", "1.0.0+cu128torch28", "cp312") + page_url = "https://example.invalid/cuda-wheels/flash-attn/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(wheel_312) if url == page_url else None, + ) + + from packaging.requirements import Requirement + + url = resolve_cuda_wheel_url( + Requirement("flash-attn"), + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn"], + "package_map": {}, + }, + runtime, + target_python=(3, 12), + ) + assert url.endswith(".whl") + assert "cp312" in url + + +def test_resolve_cuda_wheel_uses_target_python_tags(monkeypatch): + """AC-2: target_python=(3, 12) selects cp312 wheel, not cp313.""" + runtime = _runtime() + wheel_312 = _wheel_filename_for_cpython("flash_attn", "1.0.0+cu128torch28", "cp312") + wheel_313 = _wheel_filename_for_cpython("flash_attn", "1.0.0+cu128torch28", "cp313") + page_url = "https://example.invalid/cuda-wheels/flash-attn/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(wheel_312, wheel_313) if url == page_url else None, + ) + + from packaging.requirements import Requirement + + url = resolve_cuda_wheel_url( + Requirement("flash-attn"), + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn"], + "package_map": {}, + }, + runtime, + target_python=(3, 12), + ) + assert 
"cp312" in url + assert "cp313" not in url + + +def test_resolve_cuda_wheel_requirements_threads_target_python(monkeypatch): + """AC-3: resolve_cuda_wheel_requirements threads target_python to resolve_cuda_wheel_url.""" + runtime = _runtime() + wheel_312 = _wheel_filename_for_cpython("flash_attn", "1.0.0+cu128torch28", "cp312") + wheel_313 = _wheel_filename_for_cpython("flash_attn", "1.0.0+cu128torch28", "cp313") + page_url = "https://example.invalid/cuda-wheels/flash-attn/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(wheel_312, wheel_313) if url == page_url else None, + ) + + resolved = resolve_cuda_wheel_requirements( + ["flash-attn>=1.0"], + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn"], + "package_map": {}, + }, + target_python=(3, 12), + ) + assert len(resolved) == 1 + assert "cp312" in resolved[0] + assert "cp313" not in resolved[0] + + +def test_resolve_cuda_wheel_target_python_rejects_host_only_wheel(monkeypatch): + """AC-4: target_python=(3, 12) raises when only cp313 wheel available.""" + runtime = _runtime() + wheel_313_only = _wheel_filename_for_cpython("flash_attn", "1.0.0+cu128torch28", "cp313") + page_url = "https://example.invalid/cuda-wheels/flash-attn/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(wheel_313_only) if url == page_url else None, + ) + + from packaging.requirements import Requirement + + with pytest.raises(CUDAWheelResolutionError, match="No compatible CUDA wheel found"): + resolve_cuda_wheel_url( + Requirement("flash-attn"), + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn"], + "package_map": {}, + }, + runtime, + 
target_python=(3, 12), + ) + + +# ── index_urls (plural) support tests ───────────────────────────────── + + +def test_normalize_config_index_urls(): + """AC-1: _normalize_cuda_wheel_config accepts index_urls (plural list).""" + config = { + "index_urls": [ + "https://download.pytorch.org/whl/cu128", + "https://pozzettiandrea.github.io/cuda-wheels/", + ], + "packages": ["cumesh", "torch"], + "package_map": {}, + } + result = _normalize_cuda_wheel_config(config) + assert "index_urls" in result + assert isinstance(result["index_urls"], list) + assert len(result["index_urls"]) == 2 + assert result["index_urls"][0] == "https://download.pytorch.org/whl/cu128/" + assert result["index_urls"][1] == "https://pozzettiandrea.github.io/cuda-wheels/" + + +def test_normalize_config_singular_returns_index_urls_key(): + """AC-2: _normalize_cuda_wheel_config with index_url (singular) returns index_urls key (plural).""" + config = { + "index_url": "https://example.invalid/cuda-wheels", + "packages": ["flash-attn"], + "package_map": {}, + } + result = _normalize_cuda_wheel_config(config) + assert "index_urls" in result, "Expected 'index_urls' key in normalized config" + assert "index_url" not in result, "Expected NO 'index_url' key in normalized config" + assert result["index_urls"] == ["https://example.invalid/cuda-wheels/"] + + +def test_resolve_iterates_multiple_indexes(monkeypatch): + """AC-3: resolve_cuda_wheel_url iterates multiple index URLs to find a package.""" + runtime = _runtime() + wheel = _wheel_filename("cumesh", "0.0.1+cu128torch28") + # cumesh only exists on the second index URL + second_index_page = "https://pozzettiandrea.github.io/cuda-wheels/cumesh/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda **kw: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(wheel) if url == second_index_page else None, + ) + + from packaging.requirements import Requirement + 
+ url = resolve_cuda_wheel_url( + Requirement("cumesh"), + { + "index_urls": [ + "https://download.pytorch.org/whl/cu128/", + "https://pozzettiandrea.github.io/cuda-wheels/", + ], + "packages": ["cumesh"], + "package_map": {}, + }, + runtime, + ) + assert url.endswith(".whl") + assert "cumesh" in url + assert "pozzettiandrea" in url + + +# ── Live network tests ──────────────────────────────────────────────── + + +@pytest.mark.network +def test_resolve_sageattention_for_target_python_312(): + """AC-1/AC-2: Live index resolves cp312 wheel with correct torch+CUDA pattern.""" + from packaging.requirements import Requirement + + from pyisolate._internal.cuda_wheels import get_cuda_wheel_runtime + + runtime = get_cuda_wheel_runtime(target_python=(3, 12)) + url = resolve_cuda_wheel_url( + Requirement("sageattention"), + { + "index_url": "https://pozzettiandrea.github.io/cuda-wheels/", + "packages": ["sageattention"], + "package_map": {}, + }, + runtime, + target_python=(3, 12), + ) + assert "cp312" in url, f"Expected cp312 in URL: {url}" + assert "cp313" not in url, f"Unexpected cp313 in URL: {url}" + # AC-2: torch+CUDA pattern preserved (URL may use dotted "torch2.9" or nodot "torch29") + assert runtime["cuda_nodot"] in url, f"Expected cuda {runtime['cuda_nodot']} in URL: {url}" + assert f"torch{runtime['torch']}" in url or f"torch{runtime['torch_nodot']}" in url, ( + f"Expected torch {runtime['torch']} or {runtime['torch_nodot']} in URL: {url}" + ) + + +@pytest.mark.network +def test_resolve_sageattention_host_tags_selects_cp313(): + """AC-3: Without target_python selects host cpXXX; with target_python=(3, 11) selects cp311.""" + from packaging.requirements import Requirement + + from pyisolate._internal.cuda_wheels import get_cuda_wheel_runtime + + # Host tags (should select host interpreter's cp tag — cp312 on this venv) + runtime_host = get_cuda_wheel_runtime() + config = { + "index_url": "https://pozzettiandrea.github.io/cuda-wheels/", + "packages": 
["sageattention"], + "package_map": {}, + } + url_host = resolve_cuda_wheel_url(Requirement("sageattention"), config, runtime_host) + # The host cpXXX tag should be present + import sys + + host_cp = f"cp{sys.version_info.major}{sys.version_info.minor}" + assert host_cp in url_host, f"Expected {host_cp} in host URL: {url_host}" + + # Target tags (3, 11) — should select cp311, different from host + runtime_target = get_cuda_wheel_runtime(target_python=(3, 11)) + url_target = resolve_cuda_wheel_url( + Requirement("sageattention"), config, runtime_target, target_python=(3, 11) + ) + assert "cp311" in url_target, f"Expected cp311 in target URL: {url_target}" + assert host_cp not in url_target or host_cp == "cp311", ( + f"Unexpected {host_cp} in target URL: {url_target}" + ) diff --git a/tests/test_environment_conda.py b/tests/test_environment_conda.py new file mode 100644 index 0000000..eb63bdc --- /dev/null +++ b/tests/test_environment_conda.py @@ -0,0 +1,582 @@ +"""Tests for conda/pixi environment creation (environment_conda.py).""" + +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest + +from pyisolate._internal.environment_conda import ( + _generate_pixi_toml, + _install_cuda_wheels_into_pixi, + _install_local_wheels, + _parse_dep, + _pyisolate_source_path, + _resolve_pixi_python, + _resolve_uv_exe, + _toml_path_string, + create_conda_env, +) + + +def _make_conda_config(**overrides: object) -> dict: + """Minimal valid conda config for tests.""" + base: dict = { + "package_manager": "conda", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["numpy"], + "dependencies": ["requests"], + "share_torch": False, + "module": "test_ext", + } + base.update(overrides) + return base + + +def _pixi_python_path(env_path: Path) -> Path: + if os.name == "nt": + return env_path / ".pixi" / "envs" / "default" / "python.exe" + return env_path / ".pixi" / "envs" / "default" / 
"bin" / "python" + + +# ── _generate_pixi_toml ────────────────────────────────────────────── + + +class TestGeneratePixiToml: + def test_basic_toml_structure(self) -> None: + config = _make_conda_config() + toml_str = _generate_pixi_toml(config) + assert "[workspace]" in toml_str + assert "[project]" not in toml_str + assert "[dependencies]" in toml_str + assert 'python = "*"' in toml_str + assert "numpy" in toml_str + + def test_generate_pixi_toml_uses_workspace_header(self) -> None: + config = _make_conda_config() + toml_str = _generate_pixi_toml(config) + assert "[workspace]" in toml_str + assert "[project]" not in toml_str + + def test_conda_deps_in_dependencies_section(self) -> None: + config = _make_conda_config(conda_dependencies=["numpy", "scipy>=1.10"]) + toml_str = _generate_pixi_toml(config) + assert "numpy" in toml_str + assert "scipy" in toml_str + + def test_pip_deps_in_pypi_dependencies(self) -> None: + config = _make_conda_config(dependencies=["requests>=2.0", "flask"]) + toml_str = _generate_pixi_toml(config) + assert "[pypi-dependencies]" in toml_str + assert "requests" in toml_str + assert "flask" in toml_str + + def test_generate_pixi_toml_excludes_cuda_wheel_packages_from_pypi_dependencies( + self, + ) -> None: + config = _make_conda_config( + dependencies=["requests>=2.0", "spconv", "cumm", "flash-attn"], + cuda_wheels={ + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["spconv", "cumm", "flash-attn"], + }, + ) + toml_str = _generate_pixi_toml(config) + assert "[pypi-dependencies]" in toml_str + assert 'requests = ">=2.0"' in toml_str + assert "spconv =" not in toml_str + assert "cumm =" not in toml_str + assert "flash-attn =" not in toml_str + + def test_conda_manifest_installs_local_pyisolate(self) -> None: + config = _make_conda_config() + toml_str = _generate_pixi_toml(config) + assert f'pyisolate = {{ path = "{_toml_path_string(_pyisolate_source_path())}" }}' in toml_str + + def 
test_windows_pixi_manifest_path_is_toml_safe(self) -> None: + config = _make_conda_config() + toml_str = _generate_pixi_toml(config) + assert f'pyisolate = {{ path = "{_toml_path_string(_pyisolate_source_path())}" }}' in toml_str + if os.name == "nt": + assert f'pyisolate = {{ path = "{_pyisolate_source_path()}" }}' not in toml_str + + def test_channels_included(self) -> None: + config = _make_conda_config(conda_channels=["conda-forge", "nvidia"]) + toml_str = _generate_pixi_toml(config) + assert "conda-forge" in toml_str + assert "nvidia" in toml_str + + def test_platforms_included(self) -> None: + config = _make_conda_config(conda_platforms=["linux-64", "win-64"]) + toml_str = _generate_pixi_toml(config) + assert "linux-64" in toml_str + assert "win-64" in toml_str + + def test_no_pip_deps_omits_pypi_section(self) -> None: + config = _make_conda_config(dependencies=[]) + toml_str = _generate_pixi_toml(config) + assert "[pypi-dependencies]" in toml_str + assert f'pyisolate = {{ path = "{_toml_path_string(_pyisolate_source_path())}" }}' in toml_str + + def test_no_conda_deps_omits_dependencies_section(self) -> None: + config = _make_conda_config(conda_dependencies=[]) + toml_str = _generate_pixi_toml(config) + assert "[workspace]" in toml_str + assert "[dependencies]" in toml_str + assert 'python = "*"' in toml_str + + def test_generate_pixi_toml_preserves_extras(self) -> None: + config = _make_conda_config(dependencies=["jax[cuda12]>=0.4.30", "numpy>=2.2"]) + toml_str = _generate_pixi_toml(config) + assert 'jax = { version = ">=0.4.30", extras = ["cuda12"] }' in toml_str + assert 'numpy = ">=2.2"' in toml_str + + def test_generate_pixi_toml_marker_passthrough(self) -> None: + config = _make_conda_config( + dependencies=[ + "jax[cuda12]>=0.4.30; sys_platform == 'linux'", + "jax>=0.4.30; sys_platform == 'win32'", + ] + ) + toml_str = _generate_pixi_toml(config) + assert "sys_platform == 'linux'" in toml_str + assert "sys_platform == 'win32'" in toml_str + + def 
test_generate_pixi_toml_marker_not_in_version(self) -> None: + config = _make_conda_config(dependencies=["jax[cuda12]>=0.4.30; sys_platform == 'linux'"]) + toml_str = _generate_pixi_toml(config) + # The marker must NOT appear inside the version field + assert 'version = ">=0.4.30; sys_platform' not in toml_str + # It must appear in a separate markers field + assert 'markers = "sys_platform ==' in toml_str + + def test_generate_pixi_toml_marker_version_clean(self) -> None: + config = _make_conda_config(dependencies=["jax[cuda12]>=0.4.30; sys_platform == 'linux'"]) + toml_str = _generate_pixi_toml(config) + assert 'version = ">=0.4.30"' in toml_str + + +# ── _parse_dep ────────────────────────────────────────────────────── + + +class TestParseDep: + def test_parse_dep_preserves_extras(self) -> None: + name, sep, ver, extras, marker = _parse_dep("jax[cuda12]>=0.4.30") + assert name == "jax" + assert sep == ">=" + assert ver == ">=0.4.30" + assert extras == ["cuda12"] + assert marker == "" + + def test_parse_dep_preserves_easy_extra(self) -> None: + name, sep, ver, extras, marker = _parse_dep("trimesh[easy]>=4.0.0") + assert name == "trimesh" + assert sep == ">=" + assert ver == ">=4.0.0" + assert extras == ["easy"] + assert marker == "" + + def test_parse_dep_no_extras(self) -> None: + name, sep, ver, extras, marker = _parse_dep("numpy>=2.0") + assert name == "numpy" + assert sep == ">=" + assert ver == ">=2.0" + assert extras == [] + assert marker == "" + + def test_parse_dep_bare_name(self) -> None: + name, sep, ver, extras, marker = _parse_dep("requests") + assert name == "requests" + assert sep == "" + assert ver == "" + assert extras == [] + assert marker == "" + + def test_parse_dep_url(self) -> None: + name, sep, ver, extras, marker = _parse_dep("pkg @ https://example.com/pkg.whl") + assert name == "pkg" + assert sep == "@" + assert ver == "https://example.com/pkg.whl" + assert extras == [] + assert marker == "" + + def test_parse_dep_marker_extras(self) -> 
None: + name, sep, ver, extras, marker = _parse_dep("jax[cuda12]>=0.4.30; sys_platform == 'linux'") + assert name == "jax" + assert sep == ">=" + assert ver == ">=0.4.30" + assert extras == ["cuda12"] + assert marker == "sys_platform == 'linux'" + + def test_parse_dep_marker_version_only(self) -> None: + name, sep, ver, extras, marker = _parse_dep("numpy>=2.0; platform_system != 'Windows'") + assert name == "numpy" + assert sep == ">=" + assert ver == ">=2.0" + assert extras == [] + assert marker == "platform_system != 'Windows'" + + def test_parse_dep_marker_url(self) -> None: + name, sep, ver, extras, marker = _parse_dep( + "pkg @ https://example.com/pkg.whl ; python_version >= '3.12'" + ) + assert name == "pkg" + assert sep == "@" + assert ver == "https://example.com/pkg.whl" + assert extras == [] + assert marker == "python_version >= '3.12'" + + +# ── create_conda_env ───────────────────────────────────────────────── + + +class TestCreateCondaEnv: + def test_pixi_not_found_raises(self, tmp_path: Path) -> None: + config = _make_conda_config() + with patch("shutil.which", return_value=None), pytest.raises(RuntimeError, match="pixi.*not found"): + create_conda_env(tmp_path / "env", config, "test_ext") + + def test_pixi_install_called(self, tmp_path: Path) -> None: + env_path = tmp_path / "env" + config = _make_conda_config() + pixi_python = _pixi_python_path(env_path) + + with ( + patch("shutil.which", return_value="/usr/bin/pixi"), + patch("subprocess.check_call") as mock_call, + patch.object(Path, "exists", return_value=True), + ): + # Make the pixi python appear to exist + pixi_python.parent.mkdir(parents=True, exist_ok=True) + pixi_python.touch() + create_conda_env(env_path, config, "test_ext") + + # pixi install should have been called + assert mock_call.called + call_args = mock_call.call_args[0][0] + assert "pixi" in call_args[0] + assert "install" in call_args + + def test_pixi_install_failure_raises(self, tmp_path: Path) -> None: + import subprocess + + 
config = _make_conda_config() + with ( + patch("shutil.which", return_value="/usr/bin/pixi"), + patch( + "subprocess.check_call", + side_effect=subprocess.CalledProcessError(1, "pixi"), + ), + pytest.raises(subprocess.CalledProcessError), + ): + create_conda_env(tmp_path / "env", config, "test_ext") + + def test_create_conda_env_installs_cuda_wheels_post_pixi(self, tmp_path: Path) -> None: + env_path = tmp_path / "env" + config = _make_conda_config( + dependencies=["requests>=2.0", "spconv", "cumm"], + cuda_wheels={ + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["spconv", "cumm"], + }, + ) + pixi_python = _pixi_python_path(env_path) + + with ( + patch("shutil.which", return_value="/usr/bin/pixi"), + patch("subprocess.check_call") as mock_check_call, + patch( + "pyisolate._internal.environment_conda._install_cuda_wheels_into_pixi" + ) as mock_install_cuda_wheels, + ): + pixi_python.parent.mkdir(parents=True, exist_ok=True) + pixi_python.touch() + create_conda_env(env_path, config, "test_ext") + + assert mock_check_call.called + mock_install_cuda_wheels.assert_called_once_with( + pixi_python, + config, + config["cuda_wheels"], + "test_ext", + ) + + def test_writes_pixi_toml(self, tmp_path: Path) -> None: + env_path = tmp_path / "env" + config = _make_conda_config() + pixi_python = _pixi_python_path(env_path) + + with ( + patch("shutil.which", return_value="/usr/bin/pixi"), + patch("subprocess.check_call"), + ): + pixi_python.parent.mkdir(parents=True, exist_ok=True) + pixi_python.touch() + create_conda_env(env_path, config, "test_ext") + + toml_path = env_path / "pixi.toml" + assert toml_path.exists() + content = toml_path.read_text() + assert "[workspace]" in content + + def test_sanitizes_invalid_tmpdir_for_pixi_install( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + env_path = tmp_path / "env" + config = _make_conda_config() + stale_tmpdir = tmp_path / "deleted" / "ipc_shared" + monkeypatch.setenv("TMPDIR", 
str(stale_tmpdir)) + + with ( + patch("shutil.which", return_value="/usr/bin/pixi"), + patch("subprocess.check_call") as mock_call, + patch( + "pyisolate._internal.environment_conda._resolve_pixi_python", + return_value=_pixi_python_path(env_path), + ), + ): + create_conda_env(env_path, config, "test_ext") + + call_kwargs = mock_call.call_args.kwargs + passed_env = call_kwargs["env"] + assert passed_env["TMPDIR"] != str(stale_tmpdir) + assert Path(passed_env["TMPDIR"]).exists() + + def test_fingerprint_skip(self, tmp_path: Path) -> None: + """If fingerprint matches, pixi install should be skipped.""" + env_path = tmp_path / "env" + env_path.mkdir(parents=True) + config = _make_conda_config() + + # Pre-create a matching fingerprint + import hashlib + + toml_content = _generate_pixi_toml(config) + descriptor = { + "conda_dependencies": config.get("conda_dependencies", []), + "pip_dependencies": config.get("dependencies", []), + "channels": config.get("conda_channels", []), + "platforms": config.get("conda_platforms", []), + "cuda_wheels": config.get("cuda_wheels"), + "find_links": config.get("find_links", []), + "pixi_toml": toml_content, + } + fingerprint = hashlib.sha256(json.dumps(descriptor, sort_keys=True).encode()).hexdigest() + lock_path = env_path / ".pyisolate_deps.json" + lock_path.write_text(json.dumps({"fingerprint": fingerprint, "descriptor": descriptor})) + + pixi_python = _pixi_python_path(env_path) + pixi_python.parent.mkdir(parents=True, exist_ok=True) + pixi_python.touch() + + with ( + patch("shutil.which", return_value="/usr/bin/pixi"), + patch("subprocess.check_call") as mock_call, + ): + create_conda_env(env_path, config, "test_ext") + + # pixi install should NOT have been called + assert not mock_call.called + + +# ── _resolve_pixi_python ───────────────────────────────────────────── + + +class TestResolvePixiPython: + def test_returns_pixi_env_python(self, tmp_path: Path) -> None: + env_path = tmp_path / "env" + expected = 
_pixi_python_path(env_path) + expected.parent.mkdir(parents=True, exist_ok=True) + expected.touch() + result = _resolve_pixi_python(env_path) + assert result == expected + + def test_missing_python_raises(self, tmp_path: Path) -> None: + env_path = tmp_path / "env" + env_path.mkdir(parents=True) + with pytest.raises(RuntimeError, match="Python.*not found"): + _resolve_pixi_python(env_path) + + def test_never_returns_host_python(self, tmp_path: Path) -> None: + env_path = tmp_path / "env" + expected = _pixi_python_path(env_path) + expected.parent.mkdir(parents=True, exist_ok=True) + expected.touch() + result = _resolve_pixi_python(env_path) + assert str(result) != sys.executable + assert ".pixi" in str(result) + + +# ── _install_cuda_wheels_into_pixi target_python threading ───────────── + + +def test_install_cuda_wheels_passes_target_python(monkeypatch, tmp_path): + """AC-1: conda_python='3.12.*' is parsed and passed as target_python=(3, 12).""" + captured_kwargs: list[dict] = [] + + def mock_resolve(deps, config, **kwargs): + captured_kwargs.append(kwargs) + return deps # return unchanged (no wheel resolution) + + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels.resolve_cuda_wheel_requirements", + mock_resolve, + ) + + python_exe = tmp_path / "bin" / "python" + python_exe.parent.mkdir(parents=True) + python_exe.touch() + + config = _make_conda_config( + conda_python="3.12.*", + dependencies=["flash-attn"], + cuda_wheels={ + "index_url": "https://example.invalid/", + "packages": ["flash-attn"], + }, + ) + + _install_cuda_wheels_into_pixi(python_exe, config, config["cuda_wheels"], "test") + + assert len(captured_kwargs) == 1 + assert captured_kwargs[0]["target_python"] == (3, 12) + + +def test_install_cuda_wheels_wildcard_python_uses_host_tags(monkeypatch, tmp_path): + """AC-2: conda_python='*' passes target_python=None (host tags fallback).""" + captured_kwargs: list[dict] = [] + + def mock_resolve(deps, config, **kwargs): + captured_kwargs.append(kwargs) + 
return deps + + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels.resolve_cuda_wheel_requirements", + mock_resolve, + ) + + python_exe = tmp_path / "bin" / "python" + python_exe.parent.mkdir(parents=True) + python_exe.touch() + + config = _make_conda_config( + conda_python="*", + dependencies=["flash-attn"], + cuda_wheels={ + "index_url": "https://example.invalid/", + "packages": ["flash-attn"], + }, + ) + + _install_cuda_wheels_into_pixi(python_exe, config, config["cuda_wheels"], "test") + + assert len(captured_kwargs) == 1 + assert captured_kwargs[0]["target_python"] is None + + +def test_install_cuda_wheels_parses_311(monkeypatch, tmp_path): + """AC-3: conda_python='3.11.*' parses to target_python=(3, 11).""" + captured_kwargs: list[dict] = [] + + def mock_resolve(deps, config, **kwargs): + captured_kwargs.append(kwargs) + return deps + + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels.resolve_cuda_wheel_requirements", + mock_resolve, + ) + + python_exe = tmp_path / "bin" / "python" + python_exe.parent.mkdir(parents=True) + python_exe.touch() + + config = _make_conda_config( + conda_python="3.11.*", + dependencies=["flash-attn"], + cuda_wheels={ + "index_url": "https://example.invalid/", + "packages": ["flash-attn"], + }, + ) + + _install_cuda_wheels_into_pixi(python_exe, config, config["cuda_wheels"], "test") + + assert len(captured_kwargs) == 1 + assert captured_kwargs[0]["target_python"] == (3, 11) + + +# ── _resolve_uv_exe / uv path fallback ───────────────────────────── + + +class TestResolveUvExe: + def test_install_cuda_wheels_uv_exe_fallback(self, monkeypatch, tmp_path): + """When python_exe.parent/uv does not exist, falls back to shutil.which.""" + # Create python_exe in a dir without uv + python_exe = tmp_path / "no_uv_here" / "python" + python_exe.parent.mkdir(parents=True) + python_exe.touch() + + resolved = _resolve_uv_exe(python_exe) + # Should have fallen back to shutil.which since no local uv exists + import shutil + + system_uv = 
shutil.which("uv") + assert resolved == system_uv + + def test_install_local_wheels_uv_exe_fallback(self, monkeypatch, tmp_path): + """_install_local_wheels uses _resolve_uv_exe and falls back correctly.""" + python_exe = tmp_path / "no_uv_here" / "python" + python_exe.parent.mkdir(parents=True) + python_exe.touch() + + # Create a fake wheel file + wheel_dir = tmp_path / "wheels" + wheel_dir.mkdir() + (wheel_dir / "fake-1.0-py3-none-any.whl").touch() + + captured_cmds: list[list[str]] = [] + + def mock_check_call(cmd, **kwargs): + captured_cmds.append(cmd) + + monkeypatch.setattr("subprocess.check_call", mock_check_call) + + config = _make_conda_config(module_path=str(tmp_path)) + _install_local_wheels(python_exe, config, [str(wheel_dir)], "test") + + assert len(captured_cmds) == 1 + import shutil + + system_uv = shutil.which("uv") + assert captured_cmds[0][0] == system_uv + + def test_install_cuda_wheels_uv_exe_prefers_local(self, tmp_path): + """When python_exe.parent/uv exists, it is preferred over shutil.which.""" + python_exe = tmp_path / "bin" / "python" + python_exe.parent.mkdir(parents=True) + python_exe.touch() + local_uv = tmp_path / "bin" / "uv" + local_uv.touch() + + resolved = _resolve_uv_exe(python_exe) + assert resolved == str(local_uv) + + def test_install_cuda_wheels_uv_exe_windows_layout(self, tmp_path): + """Windows pixi layout: python at envs/default/python.exe, no bin/ dir.""" + # Simulate Windows pixi path structure + pixi_default = tmp_path / ".pixi" / "envs" / "default" + pixi_default.mkdir(parents=True) + python_exe = pixi_default / "python.exe" + python_exe.touch() + # No uv in python_exe.parent (Windows layout) + + resolved = _resolve_uv_exe(python_exe) + import shutil + + system_uv = shutil.which("uv") + assert resolved == system_uv diff --git a/tests/test_environment_sealed_worker.py b/tests/test_environment_sealed_worker.py new file mode 100644 index 0000000..ee5bcdf --- /dev/null +++ b/tests/test_environment_sealed_worker.py @@ -0,0 
+1,72 @@ +from __future__ import annotations + +import io +import os +from pathlib import Path + +from pyisolate._internal import environment + + +def _mock_venv_python(venv_path: Path) -> None: + python_exe = venv_path / "Scripts" / "python.exe" if os.name == "nt" else venv_path / "bin" / "python" + python_exe.parent.mkdir(parents=True, exist_ok=True) + python_exe.write_text("#!/usr/bin/env python\n", encoding="utf-8") + + +def _capture_install_commands(monkeypatch) -> list[list[str]]: + popen_calls: list[list[str]] = [] + + class MockPopen: + def __init__(self, cmd, **kwargs): + popen_calls.append(cmd) + self.stdout = io.StringIO("installed\n") + + def wait(self): + return 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(environment.shutil, "which", lambda binary: "/usr/bin/uv") + monkeypatch.setattr(environment.subprocess, "Popen", MockPopen) + return popen_calls + + +def test_sealed_worker_uv_does_not_auto_inject_torch(monkeypatch, tmp_path: Path) -> None: + venv_path = tmp_path / "venv" + _mock_venv_python(venv_path) + popen_calls = _capture_install_commands(monkeypatch) + + config = { + "dependencies": ["boltons"], + "share_torch": False, + "execution_model": "sealed_worker", + } + + environment.install_dependencies(venv_path, config, "demo") + + assert len(popen_calls) == 1 + cmd = popen_calls[0] + assert "boltons" in cmd + assert not any(str(part).startswith("torch==") for part in cmd) + + +def test_host_coupled_uv_still_auto_injects_torch(monkeypatch, tmp_path: Path) -> None: + venv_path = tmp_path / "venv" + _mock_venv_python(venv_path) + popen_calls = _capture_install_commands(monkeypatch) + + config = { + "dependencies": ["boltons"], + "share_torch": False, + } + + environment.install_dependencies(venv_path, config, "demo") + + assert len(popen_calls) == 1 + cmd = popen_calls[0] + assert "boltons" in cmd + assert any(str(part).startswith("torch==") for part in cmd) diff --git 
a/tests/test_event_channel.py b/tests/test_event_channel.py new file mode 100644 index 0000000..0a1de87 --- /dev/null +++ b/tests/test_event_channel.py @@ -0,0 +1,99 @@ +"""Tests for the pyisolate event channel (emit_event / register_event_handler). + +Tests verify: +1. Events dispatch from child to host handler +2. Unregistered events raise +3. Non-JSON payloads are rejected +4. API surface exists on ExtensionBase and SealedNodeExtension +""" + +import asyncio + +import pytest + +from pyisolate._internal.event_bridge import _EventBridge + + +class TestEventBridgeDispatch: + """Tests for _EventBridge RPC callee behavior.""" + + def test_emit_event_dispatches_to_handler(self): + """emit_event("progress", payload) calls the registered handler with exact payload.""" + bridge = _EventBridge() + received = [] + + def handler(payload): + received.append(payload) + + bridge.register_handler("progress", handler) + asyncio.get_event_loop().run_until_complete(bridge.dispatch("progress", {"value": 5, "total": 10})) + + assert len(received) == 1 + assert received[0] == {"value": 5, "total": 10} + + def test_emit_unregistered_event_raises(self): + """emit_event("unknown_event", {}) raises ValueError, not silently dropped.""" + bridge = _EventBridge() + + with pytest.raises(ValueError, match="No handler registered for event 'unknown_event'"): + asyncio.get_event_loop().run_until_complete(bridge.dispatch("unknown_event", {})) + + def test_emit_event_rejects_non_json_payload(self): + """emit_event with non-JSON-serializable payload raises immediately.""" + from pyisolate.shared import ExtensionLocal + + ext = ExtensionLocal() + + # ExtensionLocal.emit_event does json.dumps(payload) before RPC call + # Create a non-serializable object + class NotSerializable: + pass + + with pytest.raises(TypeError): + ext.emit_event("progress", NotSerializable()) + + def test_dispatch_with_async_handler(self): + """Async handlers are awaited correctly.""" + bridge = _EventBridge() + received = [] 
+ + async def async_handler(payload): + received.append(payload) + + bridge.register_handler("test", async_handler) + asyncio.get_event_loop().run_until_complete(bridge.dispatch("test", {"key": "value"})) + + assert received == [{"key": "value"}] + + def test_multiple_events_independent(self): + """Different event names dispatch to different handlers.""" + bridge = _EventBridge() + progress_calls = [] + preview_calls = [] + + bridge.register_handler("progress", lambda p: progress_calls.append(p)) + bridge.register_handler("preview", lambda p: preview_calls.append(p)) + + asyncio.get_event_loop().run_until_complete(bridge.dispatch("progress", {"value": 1})) + asyncio.get_event_loop().run_until_complete(bridge.dispatch("preview", {"image": "data"})) + + assert progress_calls == [{"value": 1}] + assert preview_calls == [{"image": "data"}] + + +class TestApiSurface: + """Tests that the event channel API exists on the right classes.""" + + def test_extension_base_has_emit_event(self): + """ExtensionBase has emit_event method.""" + from pyisolate.shared import ExtensionBase + + assert hasattr(ExtensionBase, "emit_event") + assert callable(ExtensionBase.emit_event) + + def test_sealed_node_extension_has_emit_event(self): + """SealedNodeExtension inherits emit_event from ExtensionBase.""" + from pyisolate.sealed import SealedNodeExtension + + assert hasattr(SealedNodeExtension, "emit_event") + assert callable(SealedNodeExtension.emit_event) diff --git a/tests/test_exact_proxy_bootstrap.py b/tests/test_exact_proxy_bootstrap.py new file mode 100644 index 0000000..7490a49 --- /dev/null +++ b/tests/test_exact_proxy_bootstrap.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock, patch + +from pyisolate._internal.host import Extension +from pyisolate.shared import ExtensionBase + + +class DummyProxy: + pass + + +DummyProxy.__module__ = "tests.test_exact_proxy_bootstrap" + + +def _make_extension(config: dict) -> 
Extension: + ext = Extension.__new__(Extension) + ext.name = config["name"] + ext.normalized_name = config["name"] + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + ext._cuda_ipc_enabled = False + ext._uds_path = None + ext._uds_listener = None + ext._client_sock = None + ext._host_rpc_services = [] + return ext + + +def _capture_bootstrap_payload(config: dict) -> dict: + ext = _make_extension(config) + + listener = MagicMock() + listener.accept.return_value = (MagicMock(), None) + transport = MagicMock() + proc = MagicMock() + proc.pid = 1234 + + with ( + patch("pyisolate._internal.host.socket") as mock_socket, + patch( + "pyisolate._internal.host.tempfile.mktemp", return_value="/run/user/1000/pyisolate/ext_test.sock" + ), + patch("pyisolate._internal.host.subprocess.Popen", return_value=proc), + patch( + "pyisolate._internal.host.detect_sandbox_capability", + return_value=MagicMock(available=True, restriction_model="none"), + ), + patch("pyisolate._internal.host.build_bwrap_command", return_value=["bwrap", "--clearenv", "python"]), + patch("pyisolate._internal.host.JSONSocketTransport", return_value=transport), + patch("pyisolate._internal.host.AsyncRPC"), + patch("pyisolate._internal.socket_utils.has_af_unix", return_value=True), + patch( + "pyisolate._internal.socket_utils.ensure_ipc_socket_dir", + return_value=Path("/run/user/1000/pyisolate"), + ), + patch( + "pyisolate._internal.host.build_extension_snapshot", + return_value={"sys_path": ["/host/path"], "apply_host_sys_path": True}, + ), + patch("os.chmod"), + patch("sys.platform", "linux"), + ): + mock_socket.socket.return_value = listener + mock_socket.AF_UNIX = 1 + mock_socket.SOCK_STREAM = 1 + ext._launch_with_uds() + + return transport.send.call_args[0][0] + + +def test_sealed_worker_exact_proxy_binding() -> None: + payload = _capture_bootstrap_payload( + { + "name": "test_ext", + "module": "test_module", + "module_path": 
"/fake/module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "uv", + "execution_model": "sealed_worker", + "apis": [DummyProxy], + } + ) + + assert payload["snapshot"]["apply_host_sys_path"] is False + assert payload["config"]["apis"] == ["tests.test_exact_proxy_bootstrap.DummyProxy"] diff --git a/tests/test_harness_host_env.py b/tests/test_harness_host_env.py new file mode 100644 index 0000000..2bffece --- /dev/null +++ b/tests/test_harness_host_env.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +from tests.harness.host import ReferenceHost + + +@pytest.mark.asyncio +async def test_reference_host_cleanup_restores_tmpdir( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + original_tmpdir = tmp_path / "original-tmpdir" + original_tmpdir.mkdir() + monkeypatch.setenv("TMPDIR", str(original_tmpdir)) + + host = ReferenceHost() + + assert os.environ["TMPDIR"] != str(original_tmpdir) + + await host.cleanup() + + assert os.environ["TMPDIR"] == str(original_tmpdir) diff --git a/tests/test_host_conda_dispatch.py b/tests/test_host_conda_dispatch.py new file mode 100644 index 0000000..21b6d8d --- /dev/null +++ b/tests/test_host_conda_dispatch.py @@ -0,0 +1,423 @@ +"""Tests for host.py dispatch to conda/uv backend (Slice 3).""" + +from __future__ import annotations + +import contextlib +from pathlib import Path +from unittest.mock import MagicMock, patch + +from pyisolate._internal.sandbox_detect import RestrictionModel + +# ── __launch dispatch ──────────────────────────────────────────────── + + +class TestLaunchDispatchConda: + """Verify __launch() dispatches to conda backend when package_manager='conda'.""" + + @patch("pyisolate._internal.host.validate_backend_config") + @patch("pyisolate._internal.host.create_conda_env") + @patch("pyisolate._internal.host.create_venv") + @patch("pyisolate._internal.host.install_dependencies") + def 
test_conda_calls_create_conda_env( + self, + mock_install_deps: MagicMock, + mock_create_venv: MagicMock, + mock_create_conda: MagicMock, + mock_validate: MagicMock, + ) -> None: + """When package_manager='conda', __launch should call create_conda_env, NOT create_venv.""" + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "conda", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["numpy"], + } + + ext = Extension.__new__(Extension) + ext.name = "test_ext" + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + + # Call the private __launch via name mangling + with patch.object(ext, "_launch_with_uds", return_value=MagicMock()): + ext._Extension__launch() + + mock_create_conda.assert_called_once() + mock_create_venv.assert_not_called() + mock_install_deps.assert_not_called() + + @patch("pyisolate._internal.host.validate_backend_config") + @patch("pyisolate._internal.host.create_conda_env") + @patch("pyisolate._internal.host.create_venv") + @patch("pyisolate._internal.host.install_dependencies") + def test_uv_calls_create_venv( + self, + mock_install_deps: MagicMock, + mock_create_venv: MagicMock, + mock_create_conda: MagicMock, + mock_validate: MagicMock, + ) -> None: + """When package_manager='uv' (default), __launch uses original uv path.""" + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "uv", + } + + ext = Extension.__new__(Extension) + ext.name = "test_ext" + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = 
ExtensionBase + + with patch.object(ext, "_launch_with_uds", return_value=MagicMock()): + ext._Extension__launch() + + mock_create_venv.assert_called_once() + mock_install_deps.assert_called_once() + mock_create_conda.assert_not_called() + + @patch("pyisolate._internal.host.validate_backend_config") + @patch("pyisolate._internal.host.create_conda_env") + def test_validate_called_before_conda_launch( + self, + mock_create_conda: MagicMock, + mock_validate: MagicMock, + ) -> None: + """validate_backend_config must be called before env creation.""" + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "conda", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["numpy"], + } + + ext = Extension.__new__(Extension) + ext.name = "test_ext" + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + + with patch.object(ext, "_launch_with_uds", return_value=MagicMock()): + ext._Extension__launch() + + mock_validate.assert_called_once_with(config) + + @patch("pyisolate._internal.host.validate_backend_config") + @patch("pyisolate._internal.host.create_venv") + @patch("pyisolate._internal.host.install_dependencies") + def test_validate_called_before_uv_launch( + self, + mock_install: MagicMock, + mock_create: MagicMock, + mock_validate: MagicMock, + ) -> None: + """validate_backend_config must also be called for uv backend.""" + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "uv", + } + + ext = Extension.__new__(Extension) + ext.name = "test_ext" + ext.config = config + ext.venv_path = 
Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + + with patch.object(ext, "_launch_with_uds", return_value=MagicMock()): + ext._Extension__launch() + + mock_validate.assert_called_once_with(config) + + +# ── Python exe resolution ──────────────────────────────────────────── + + +class TestPythonExeResolution: + """Verify _launch_with_uds resolves correct python for each backend.""" + + def test_conda_resolves_pixi_python(self) -> None: + """Conda backend must use .pixi/envs/default/bin/python, not venv/bin/python.""" + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "conda", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["numpy"], + } + + ext = Extension.__new__(Extension) + ext.name = "test_ext" + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + ext._cuda_ipc_enabled = False + + # We need to test that the python exe resolved inside _launch_with_uds + # uses _resolve_pixi_python for conda, not the standard venv path. + # We'll mock _resolve_pixi_python and verify it's called. 
+ with ( + patch( + "pyisolate._internal.host._resolve_pixi_python", + return_value=Path("/fake/venv/.pixi/envs/default/bin/python"), + ) as mock_resolve, + patch("pyisolate._internal.host.socket"), + patch("pyisolate._internal.host.tempfile"), + patch("pyisolate._internal.host.subprocess"), + patch("pyisolate._internal.host.threading"), + patch("pyisolate._internal.host.build_extension_snapshot"), + patch("pyisolate._internal.host.JSONSocketTransport"), + patch("pyisolate._internal.host.AsyncRPC"), + patch( + "pyisolate._internal.host.detect_sandbox_capability", + return_value=MagicMock(available=True), + ), + ): + # This will fail because we need more mocking, but the key assertion + # is that _resolve_pixi_python is called for conda backend + with contextlib.suppress(Exception): + ext._launch_with_uds() + + mock_resolve.assert_called_once() + + +class TestCondaSealedWorkerBwrapDispatch: + """Verify conda sealed_worker launches through bubblewrap with pixi python.""" + + @patch("pyisolate._internal.host.build_bwrap_command") + @patch("pyisolate._internal.host.subprocess.Popen") + def test_conda_sealed_worker_launches_through_bwrap( + self, + mock_popen: MagicMock, + mock_build_bwrap: MagicMock, + ) -> None: + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "conda", + "execution_model": "sealed_worker", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["numpy"], + "sandbox": {"writable_paths": ["/fake/artifacts"]}, + } + + ext = Extension.__new__(Extension) + ext.name = "test_ext" + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + ext._cuda_ipc_enabled = False + ext._uds_path = None + ext._uds_listener = None + ext._client_sock = None + + pixi_python = 
Path("/fake/venv/.pixi/envs/default/bin/python") + mock_proc = MagicMock() + mock_proc.pid = 12345 + mock_proc.args = ["bwrap", "--clearenv", str(pixi_python)] + mock_popen.return_value = mock_proc + mock_build_bwrap.return_value = [ + "bwrap", + "--clearenv", + str(pixi_python), + "-m", + "pyisolate._internal.uds_client", + ] + + transport = MagicMock() + transport.send = MagicMock() + + with ( + patch("pyisolate._internal.host._resolve_pixi_python", return_value=pixi_python), + patch("pyisolate._internal.host.socket") as mock_socket, + patch("pyisolate._internal.host.tempfile"), + patch("pyisolate._internal.host.detect_sandbox_capability") as mock_detect, + patch("sys.platform", "linux"), + patch("pyisolate._internal.host.JSONSocketTransport", return_value=transport), + patch("pyisolate._internal.host.AsyncRPC"), + ): + mock_detect.return_value = MagicMock( + available=True, + restriction_model=RestrictionModel.NONE, + ) + mock_listener = MagicMock() + mock_listener.accept.return_value = (MagicMock(), None) + mock_socket.socket.return_value = mock_listener + mock_socket.AF_UNIX = 1 + mock_socket.SOCK_STREAM = 1 + + with ( + patch("pyisolate._internal.socket_utils.has_af_unix", return_value=True), + patch("pyisolate._internal.socket_utils.ensure_ipc_socket_dir", return_value=Path("/run")), + patch("pyisolate._internal.host.build_extension_snapshot", return_value={}), + patch("os.chmod"), + ): + ext._launch_with_uds() + + mock_build_bwrap.assert_called_once() + kwargs = mock_build_bwrap.call_args.kwargs + assert kwargs["execution_model"] == "sealed_worker" + assert kwargs["sandbox_config"] == {"writable_paths": ["/fake/artifacts"]} + assert kwargs["python_exe"] == str(pixi_python) + transport.send.assert_called_once() + bootstrap_data = transport.send.call_args[0][0] + assert bootstrap_data["snapshot"]["apply_host_sys_path"] is False + + +class TestEnvPropagation: + """Verify child env overrides are applied on non-Linux launches too.""" + + 
@patch("pyisolate._internal.host.subprocess.Popen") + def test_windows_launch_propagates_config_env( + self, + mock_popen: MagicMock, + ) -> None: + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "conda", + "execution_model": "sealed_worker", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["boltons"], + "env": {"PYISOLATE_ARTIFACT_DIR": r"C:\artifacts"}, + } + + ext = Extension.__new__(Extension) + ext.name = "test_ext" + ext.config = config + ext.venv_path = Path(r"C:\fake\venv") + ext.module_path = r"C:\fake\module" + ext.extension_type = ExtensionBase + ext._cuda_ipc_enabled = False + ext._uds_path = None + ext._uds_listener = None + ext._client_sock = None + + mock_proc = MagicMock() + mock_proc.pid = 12345 + mock_popen.return_value = mock_proc + transport = MagicMock() + transport.send = MagicMock() + + with ( + patch( + "pyisolate._internal.host._resolve_pixi_python", + return_value=Path(r"C:\fake\venv\.pixi\envs\default\python.exe"), + ), + patch("pyisolate._internal.host.socket") as mock_socket, + patch("pyisolate._internal.host.JSONSocketTransport", return_value=transport), + patch("pyisolate._internal.host.AsyncRPC"), + patch("pyisolate._internal.host.build_extension_snapshot", return_value={}), + patch("pyisolate._internal.socket_utils.has_af_unix", return_value=False), + patch("os.name", "nt"), + patch("sys.platform", "win32"), + ): + mock_listener = MagicMock() + mock_listener.accept.return_value = (MagicMock(), None) + mock_listener.getsockname.return_value = ("127.0.0.1", 43210) + mock_socket.socket.return_value = mock_listener + mock_socket.AF_INET = 2 + mock_socket.SOCK_STREAM = 1 + mock_socket.SOL_SOCKET = 1 + mock_socket.SO_REUSEADDR = 2 + + ext._launch_with_uds() + + child_env = mock_popen.call_args.kwargs["env"] + assert 
child_env["PYISOLATE_ARTIFACT_DIR"] == r"C:\artifacts" + + +# ── share_cuda_ipc forced False ────────────────────────────────────── + + +class TestCondaCudaIpcForced: + """Conda backend must force share_cuda_ipc=False.""" + + @patch("pyisolate._internal.host.create_conda_env") + @patch("pyisolate._internal.host.validate_backend_config") + def test_conda_forces_cuda_ipc_false( + self, + mock_validate: MagicMock, + mock_conda: MagicMock, + ) -> None: + """Even if config says share_cuda_ipc=True, conda must override to False.""" + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": True, # Explicitly set, should be overridden + "package_manager": "conda", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["numpy"], + } + + ext = Extension.__new__(Extension) + ext.name = "test_ext" + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + ext._cuda_ipc_enabled = True + + with patch.object(ext, "_launch_with_uds", return_value=MagicMock()): + ext._Extension__launch() + + # After __launch, cuda_ipc should be forced False + assert ext._cuda_ipc_enabled is False + assert config["share_cuda_ipc"] is False diff --git a/tests/test_host_internal_ext.py b/tests/test_host_internal_ext.py index cedd172..113ff48 100644 --- a/tests/test_host_internal_ext.py +++ b/tests/test_host_internal_ext.py @@ -14,6 +14,9 @@ class DummyRPC: def __init__(self, *args, **kwargs): self.run_called = False + def register_callee(self, obj, object_id): + pass + def run(self): self.run_called = True diff --git a/tests/test_host_sealed_worker_dispatch.py b/tests/test_host_sealed_worker_dispatch.py new file mode 100644 index 0000000..e09780c --- /dev/null +++ b/tests/test_host_sealed_worker_dispatch.py @@ -0,0 +1,163 @@ +"""Tests for uv + sealed_worker 
host dispatch under bwrap (Issue 8 Slice 2).""" + +from __future__ import annotations + +import os +from pathlib import Path +from unittest.mock import MagicMock, patch + +from pyisolate._internal.sandbox_detect import RestrictionModel + + +def _make_extension(config: dict[str, object]): + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + ext = Extension.__new__(Extension) + ext.name = "test_ext" + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + ext._cuda_ipc_enabled = False + return ext + + +def _uv_python_path(venv_path: Path) -> Path: + if os.name == "nt": + return venv_path / "Scripts" / "python.exe" + return venv_path / "bin" / "python" + + +class TestLaunchDispatchSealedWorker: + @patch("pyisolate._internal.host.validate_backend_config") + @patch("pyisolate._internal.host.create_conda_env") + @patch("pyisolate._internal.host.create_venv") + @patch("pyisolate._internal.host.install_dependencies") + def test_uv_sealed_worker_uses_uv_env_path( + self, + mock_install_deps: MagicMock, + mock_create_venv: MagicMock, + mock_create_conda: MagicMock, + mock_validate: MagicMock, + ) -> None: + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "uv", + "execution_model": "sealed_worker", + } + + ext = _make_extension(config) + + with patch.object(ext, "_launch_with_uds", return_value=MagicMock()): + ext._Extension__launch() + + mock_validate.assert_called_once_with(config) + mock_create_venv.assert_called_once() + mock_install_deps.assert_called_once() + mock_create_conda.assert_not_called() + + def test_uv_sealed_worker_uses_json_tensor_transport(self) -> None: + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "uv", + "execution_model": 
"sealed_worker", + } + + ext = _make_extension(config) + + assert ext._tensor_transport_mode() == "json" + + @patch("pyisolate._internal.host.build_bwrap_command") + @patch("pyisolate._internal.host.subprocess.Popen") + def test_uv_sealed_worker_launches_through_bwrap_with_strict_policy( + self, + mock_popen: MagicMock, + mock_build_bwrap: MagicMock, + ) -> None: + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "uv", + "execution_model": "sealed_worker", + } + + ext = _make_extension(config) + ext._uds_path = None + ext._uds_listener = None + ext._client_sock = None + + mock_proc = MagicMock() + mock_proc.pid = 12345 + mock_popen.return_value = mock_proc + mock_build_bwrap.return_value = [ + "bwrap", + "--clearenv", + "--setenv", + "PYISOLATE_UDS_ADDRESS", + "/run/ext.sock", + ] + + transport = MagicMock() + transport.send = MagicMock() + + with ( + patch("pyisolate._internal.host.socket") as mock_socket, + patch("pyisolate._internal.host.tempfile"), + patch("pyisolate._internal.host.detect_sandbox_capability") as mock_detect, + patch("sys.platform", "linux"), + patch("pyisolate._internal.host.JSONSocketTransport", return_value=transport), + patch("pyisolate._internal.host.AsyncRPC"), + ): + mock_detect.return_value = MagicMock( + available=True, + restriction_model=RestrictionModel.NONE, + ) + mock_listener = MagicMock() + mock_listener.accept.return_value = (MagicMock(), None) + mock_socket.socket.return_value = mock_listener + mock_socket.AF_UNIX = 1 + mock_socket.SOCK_STREAM = 1 + + with ( + patch("pyisolate._internal.socket_utils.has_af_unix", return_value=True), + patch("pyisolate._internal.socket_utils.ensure_ipc_socket_dir", return_value=Path("/run")), + patch("pyisolate._internal.host.build_extension_snapshot", return_value={}), + patch("os.chmod"), + ): + ext._launch_with_uds() + + mock_build_bwrap.assert_called_once() + kwargs = 
mock_build_bwrap.call_args.kwargs + assert kwargs["execution_model"] == "sealed_worker" + assert kwargs["sandbox_config"] == {} + assert kwargs["python_exe"] == str(_uv_python_path(ext.venv_path)) + assert kwargs["module_path"] == ext.module_path + transport.send.assert_called_once() + bootstrap_data = transport.send.call_args[0][0] + assert bootstrap_data["snapshot"]["apply_host_sys_path"] is False + + def test_uv_host_coupled_keeps_shared_memory_tensor_transport(self) -> None: + config = { + "name": "test_ext", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "uv", + } + + ext = _make_extension(config) + + assert ext._tensor_transport_mode() == "shared_memory" diff --git a/tests/test_model_serialization.py b/tests/test_model_serialization.py index ba9fc6f..45688a5 100644 --- a/tests/test_model_serialization.py +++ b/tests/test_model_serialization.py @@ -119,3 +119,34 @@ async def test_int_passthrough(self) -> None: async def test_none_passthrough(self) -> None: result = await deserialize_from_isolation(None) assert result is None + + +class TestOpaqueHandlePreservation: + """Pack-local proxy handle tests (issue #58).""" + + async def test_deserialize_preserves_opaque_handle_no_rpc(self) -> None: + """RemoteObjectHandle with no handler stays opaque without RPC call.""" + from unittest.mock import AsyncMock + + from pyisolate._internal.remote_handle import RemoteObjectHandle + + handle = RemoteObjectHandle("test-id", "UnregisteredType") + mock_extension = AsyncMock() + + result = await deserialize_from_isolation(handle, extension=mock_extension) + + assert result is handle + mock_extension.get_remote_object.assert_not_called() + + async def test_flush_clears_remote_objects(self) -> None: + """flush_transport_state() empties remote_objects dict.""" + from pyisolate.sealed import SealedNodeExtension + + ext = SealedNodeExtension() + for i in range(5): + ext.remote_objects[f"obj-{i}"] = object() + 
assert len(ext.remote_objects) == 5 + + await ext.flush_transport_state() + + assert len(ext.remote_objects) == 0 diff --git a/tests/test_rpc_contract.py b/tests/test_rpc_contract.py index f193742..844011d 100644 --- a/tests/test_rpc_contract.py +++ b/tests/test_rpc_contract.py @@ -109,49 +109,61 @@ class TestEventLoopResilience: def test_singleton_survives_loop_recreation(self): """Singleton instance survives event loop recreation.""" - # Create initial loop - loop1 = asyncio.new_event_loop() - asyncio.set_event_loop(loop1) - - # Create singleton and store data - registry = MockRegistry() - obj_id = registry.register("loop1_object") - - # Close loop1 - loop1.close() + try: + previous_loop = asyncio.get_event_loop_policy().get_event_loop() + except RuntimeError: + previous_loop = None - # Create new loop - loop2 = asyncio.new_event_loop() - asyncio.set_event_loop(loop2) - - # Singleton should still work - result = registry.get(obj_id) - assert result == "loop1_object" - - # Cleanup - loop2.close() + loop1 = asyncio.new_event_loop() + loop2: asyncio.AbstractEventLoop | None = None + try: + asyncio.set_event_loop(loop1) + registry = MockRegistry() + obj_id = registry.register("loop1_object") + + loop1.close() + + loop2 = asyncio.new_event_loop() + asyncio.set_event_loop(loop2) + + result = registry.get(obj_id) + assert result == "loop1_object" + finally: + asyncio.set_event_loop(previous_loop) + if loop2 is not None: + loop2.close() + elif not loop1.is_closed(): + loop1.close() def test_singleton_data_persists_across_loops(self): """Data stored in singleton persists across event loops.""" - # First loop - loop1 = asyncio.new_event_loop() - asyncio.set_event_loop(loop1) - - registry = MockRegistry() - id1 = registry.register("first") - id2 = registry.register("second") - - loop1.close() + try: + previous_loop = asyncio.get_event_loop_policy().get_event_loop() + except RuntimeError: + previous_loop = None - # Second loop - loop2 = asyncio.new_event_loop() - 
asyncio.set_event_loop(loop2) - - # All data should still be accessible - assert registry.get(id1) == "first" - assert registry.get(id2) == "second" - - loop2.close() + loop1 = asyncio.new_event_loop() + loop2: asyncio.AbstractEventLoop | None = None + try: + asyncio.set_event_loop(loop1) + + registry = MockRegistry() + id1 = registry.register("first") + id2 = registry.register("second") + + loop1.close() + + loop2 = asyncio.new_event_loop() + asyncio.set_event_loop(loop2) + + assert registry.get(id1) == "first" + assert registry.get(id2) == "second" + finally: + asyncio.set_event_loop(previous_loop) + if loop2 is not None: + loop2.close() + elif not loop1.is_closed(): + loop1.close() class TestRpcErrorHandling: diff --git a/tests/test_rpc_transports.py b/tests/test_rpc_transports.py index e9d93f9..ce848ae 100644 --- a/tests/test_rpc_transports.py +++ b/tests/test_rpc_transports.py @@ -5,6 +5,7 @@ roundtrip and connection-error tests. """ +import contextlib import logging import socket import struct @@ -76,69 +77,90 @@ def test_send_does_not_enforce_2gb_limit( class TestRecvHardLimit: def test_2gb_minus_1_not_rejected(self) -> None: transport = _make_transport() - with ( - patch.object(transport, "_recvall", side_effect=_header_then_empty(2 * GB - 1)), - pytest.raises(ConnectionError), - ): - transport.recv() + try: + with ( + patch.object(transport, "_recvall", side_effect=_header_then_empty(2 * GB - 1)), + pytest.raises(ConnectionError), + ): + transport.recv() + finally: + transport.close() def test_2gb_exact_not_rejected(self) -> None: transport = _make_transport() - with ( - patch.object(transport, "_recvall", side_effect=_header_then_empty(2 * GB)), - pytest.raises(ConnectionError), - ): - transport.recv() + try: + with ( + patch.object(transport, "_recvall", side_effect=_header_then_empty(2 * GB)), + pytest.raises(ConnectionError), + ): + transport.recv() + finally: + transport.close() def test_2gb_plus_1_raises_value_error(self) -> None: # 2GB+1 = 
2147483649 fits in uint32 and is the exact guard boundary transport = _make_transport() - with ( - patch.object(transport, "_recvall", side_effect=_header_then_empty(2 * GB + 1)), - pytest.raises(ValueError, match="Message too large"), - ): - transport.recv() + try: + with ( + patch.object(transport, "_recvall", side_effect=_header_then_empty(2 * GB + 1)), + pytest.raises(ValueError, match="Message too large"), + ): + transport.recv() + finally: + transport.close() def test_3gb_raises_value_error(self) -> None: # 3GB fits in an unsigned 32-bit header and exceeds the 2GB hard limit transport = _make_transport() - with ( - patch.object(transport, "_recvall", side_effect=_header_then_empty(3 * GB)), - pytest.raises(ValueError, match="Message too large"), - ): - transport.recv() + try: + with ( + patch.object(transport, "_recvall", side_effect=_header_then_empty(3 * GB)), + pytest.raises(ValueError, match="Message too large"), + ): + transport.recv() + finally: + transport.close() class TestRecvWarningThreshold: def test_100mb_minus_1_no_warning(self, caplog: pytest.LogCaptureFixture) -> None: transport = _make_transport() - with ( - caplog.at_level(logging.WARNING, logger=TRANSPORT_LOGGER), - patch.object(transport, "_recvall", side_effect=_header_then_empty(100 * MB - 1)), - pytest.raises(ConnectionError), - ): - transport.recv() + try: + with ( + caplog.at_level(logging.WARNING, logger=TRANSPORT_LOGGER), + patch.object(transport, "_recvall", side_effect=_header_then_empty(100 * MB - 1)), + pytest.raises(ConnectionError), + ): + transport.recv() + finally: + transport.close() assert not any("Large RPC message" in r.message for r in caplog.records) def test_100mb_exact_no_warning(self, caplog: pytest.LogCaptureFixture) -> None: # Threshold is strictly >100MB; exactly 100MB must NOT trigger the warning transport = _make_transport() - with ( - caplog.at_level(logging.WARNING, logger=TRANSPORT_LOGGER), - patch.object(transport, "_recvall", 
side_effect=_header_then_empty(100 * MB)), - pytest.raises(ConnectionError), - ): - transport.recv() + try: + with ( + caplog.at_level(logging.WARNING, logger=TRANSPORT_LOGGER), + patch.object(transport, "_recvall", side_effect=_header_then_empty(100 * MB)), + pytest.raises(ConnectionError), + ): + transport.recv() + finally: + transport.close() assert not any("Large RPC message" in r.message for r in caplog.records) def test_100mb_plus_1_triggers_warning(self, caplog: pytest.LogCaptureFixture) -> None: transport = _make_transport() - with ( - caplog.at_level(logging.WARNING, logger=TRANSPORT_LOGGER), - patch.object(transport, "_recvall", side_effect=_header_then_empty(100 * MB + 1)), - pytest.raises(ConnectionError), - ): - transport.recv() + try: + with ( + caplog.at_level(logging.WARNING, logger=TRANSPORT_LOGGER), + patch.object(transport, "_recvall", side_effect=_header_then_empty(100 * MB + 1)), + pytest.raises(ConnectionError), + ): + transport.recv() + finally: + transport.close() assert any("Large RPC message" in r.message for r in caplog.records) @@ -149,11 +171,14 @@ def test_incomplete_length_header_raises(self) -> None: def fake_recvall(n: int) -> bytes: return b"\x00\x00" # only 2 bytes instead of 4 - with ( - patch.object(transport, "_recvall", side_effect=fake_recvall), - pytest.raises(ConnectionError, match="incomplete length header"), - ): - transport.recv() + try: + with ( + patch.object(transport, "_recvall", side_effect=fake_recvall), + pytest.raises(ConnectionError, match="incomplete length header"), + ): + transport.recv() + finally: + transport.close() def test_incomplete_message_body_raises(self) -> None: transport = _make_transport() @@ -164,15 +189,23 @@ def fake_recvall(n: int) -> bytes: call_count += 1 return struct.pack(">I", 100) if call_count == 1 else b"short" - with ( - patch.object(transport, "_recvall", side_effect=fake_recvall), - pytest.raises(ConnectionError, match="Incomplete message"), - ): - transport.recv() + try: + with ( 
+ patch.object(transport, "_recvall", side_effect=fake_recvall), + pytest.raises(ConnectionError, match="Incomplete message"), + ): + transport.recv() + finally: + transport.close() def test_socket_closed_mid_header_raises(self) -> None: a, b = socket.socketpair() transport = JSONSocketTransport(a) b.close() - with pytest.raises((ConnectionError, OSError)): - transport.recv() + try: + with pytest.raises((ConnectionError, OSError)): + transport.recv() + finally: + with contextlib.suppress(Exception): + b.close() + transport.close() diff --git a/tests/test_sandbox_detect.py b/tests/test_sandbox_detect.py index 6d74b9f..82a65fa 100644 --- a/tests/test_sandbox_detect.py +++ b/tests/test_sandbox_detect.py @@ -163,6 +163,18 @@ def test_bwrap_test_uses_unshare_user_try(self) -> None: args = mock_run.call_args[0][0] assert "--unshare-user-try" in args + def test_bwrap_test_skips_lib64_when_absent(self) -> None: + """Test that /lib64 bind is omitted when the path does not exist.""" + mock_result = MagicMock() + mock_result.returncode = 0 + with ( + patch("subprocess.run", return_value=mock_result) as mock_run, + patch("os.path.exists", side_effect=lambda path: path != "/lib64"), + ): + _test_bwrap("/usr/bin/bwrap") + args = mock_run.call_args[0][0] + assert "/lib64" not in args + def test_bwrap_test_failure_permission(self) -> None: """Test bwrap failure with permission denied.""" mock_result = MagicMock() @@ -196,6 +208,18 @@ def test_bwrap_degraded_test_success(self) -> None: assert success is True assert error == "" + def test_bwrap_degraded_test_skips_lib64_when_absent(self) -> None: + """Test degraded bwrap probe omits /lib64 when the path does not exist.""" + mock_result = MagicMock() + mock_result.returncode = 0 + with ( + patch("subprocess.run", return_value=mock_result) as mock_run, + patch("os.path.exists", side_effect=lambda path: path != "/lib64"), + ): + _test_bwrap_degraded("/usr/bin/bwrap") + args = mock_run.call_args[0][0] + assert "/lib64" not in args + 
class TestErrorClassification: """Test error message classification.""" diff --git a/tests/test_sealed_proxy_handle.py b/tests/test_sealed_proxy_handle.py new file mode 100644 index 0000000..cce4bcc --- /dev/null +++ b/tests/test_sealed_proxy_handle.py @@ -0,0 +1,71 @@ +"""Tests for SealedNodeExtension proxy handle mechanism (issue #58 Slice 2). + +Proves that sealed workers can wrap unregistered objects as +RemoteObjectHandle, store them in a child-local registry, and resolve +incoming handles back to the original objects by identity. +""" + +import numpy as np +import pytest + +from pyisolate._internal.remote_handle import RemoteObjectHandle +from pyisolate._internal.serialization_registry import SerializerRegistry +from pyisolate.sealed import SealedNodeExtension + + +@pytest.fixture(autouse=True) +def clean_registry(): + SerializerRegistry.get_instance().clear() + + +class _FakeWidget: + """Unregistered type used for proxy handle tests.""" + + def __init__(self, value: int) -> None: + self.value = value + + +def test_sealed_wraps_unregistered_object_as_handle(): + ext = SealedNodeExtension() + widget = _FakeWidget(42) + + result = ext._wrap_for_transport(widget) + + assert isinstance(result, RemoteObjectHandle) + assert result.type_name == "_FakeWidget" + assert len(ext.remote_objects) == 1 + assert ext.remote_objects[result.object_id] is widget + + +def test_sealed_resolves_handle_to_original_object(): + ext = SealedNodeExtension() + original = _FakeWidget(99) + ext.remote_objects["id-1"] = original + + handle = RemoteObjectHandle("id-1", "Foo") + result = ext._resolve_handles(handle) + + assert result is original + + +def test_sealed_stale_handle_raises_keyerror(): + ext = SealedNodeExtension() + + handle = RemoteObjectHandle("nonexistent", "Foo") + with pytest.raises(KeyError, match="nonexistent"): + ext._resolve_handles(handle) + + +def test_sealed_ndarray_roundtrip_via_handle(): + ext = SealedNodeExtension() + original = np.ones((100, 100), 
dtype=np.float32) + + # Wrap β€” ndarray has no registered serializer so it becomes a handle + wrapped = ext._wrap_for_transport(original) + assert isinstance(wrapped, RemoteObjectHandle) + assert wrapped.type_name == "ndarray" + + # Resolve β€” should return the exact same object + resolved = ext._resolve_handles(wrapped) + assert resolved is original + assert np.array_equal(original, resolved) diff --git a/tests/test_security_conda.py b/tests/test_security_conda.py new file mode 100644 index 0000000..3b19383 --- /dev/null +++ b/tests/test_security_conda.py @@ -0,0 +1,198 @@ +"""Tests for conda sealed_worker sandbox launch under bwrap (Issue 8 Slice 4).""" + +from __future__ import annotations + +import contextlib +import os +import re +from pathlib import Path +from unittest.mock import MagicMock, patch + +from pyisolate._internal.sandbox_detect import RestrictionModel + + +def _make_extension(): + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_conda", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "conda", + "execution_model": "sealed_worker", + "conda_channels": ["conda-forge"], + "conda_dependencies": ["numpy"], + } + + ext = Extension.__new__(Extension) + ext.name = "test_conda" + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + ext._cuda_ipc_enabled = False + ext._uds_path = None + ext._uds_listener = None + ext._client_sock = None + return ext + + +def _pixi_python_path() -> Path: + if os.name == "nt": + return Path("/fake/venv/.pixi/envs/default/python.exe") + return Path("/fake/venv/.pixi/envs/default/bin/python") + + +def _launch_extension(ext, mock_popen: MagicMock) -> MagicMock: + mock_proc = MagicMock() + mock_proc.pid = 12345 + mock_proc.args = [ + "bwrap", + "--clearenv", + str(_pixi_python_path()), + "-m", + 
"pyisolate._internal.uds_client", + ] + mock_popen.return_value = mock_proc + + transport = MagicMock() + transport.send = MagicMock() + + with ( + patch( + "pyisolate._internal.host._resolve_pixi_python", + return_value=_pixi_python_path(), + ), + patch("pyisolate._internal.host.socket") as mock_socket, + patch("pyisolate._internal.host.tempfile"), + patch("pyisolate._internal.host.detect_sandbox_capability") as mock_detect, + patch("sys.platform", "linux"), + patch("pyisolate._internal.host.JSONSocketTransport", return_value=transport), + patch("pyisolate._internal.host.AsyncRPC"), + ): + mock_detect.return_value = MagicMock( + available=True, + restriction_model=RestrictionModel.NONE, + ) + mock_listener = MagicMock() + mock_listener.accept.return_value = (MagicMock(), None) + mock_socket.socket.return_value = mock_listener + mock_socket.AF_UNIX = 1 + mock_socket.SOCK_STREAM = 1 + + with ( + patch( + "pyisolate._internal.socket_utils.has_af_unix", + return_value=True, + ), + patch( + "pyisolate._internal.socket_utils.ensure_ipc_socket_dir", + return_value=Path("/run"), + ), + patch("pyisolate._internal.host.build_extension_snapshot", return_value={}), + patch("os.chmod"), + contextlib.suppress(Exception), + ): + ext._launch_with_uds() + + return transport + + +def _setenv_map(cmd: list[str]) -> dict[str, str]: + env_map: dict[str, str] = {} + index = 0 + while index < len(cmd): + if cmd[index] == "--setenv": + env_map[cmd[index + 1]] = cmd[index + 2] + index += 3 + continue + index += 1 + return env_map + + +class TestCondaSealedWorkerSandboxLaunch: + @patch("pyisolate._internal.host.subprocess.Popen") + def test_conda_sealed_worker_launches_via_bwrap(self, mock_popen: MagicMock) -> None: + ext = _make_extension() + + _launch_extension(ext, mock_popen) + + cmd = mock_popen.call_args[0][0] + assert cmd[0] == "bwrap" + + @patch("pyisolate._internal.host.subprocess.Popen") + def test_conda_sealed_worker_uses_explicit_env_allowlist(self, mock_popen: MagicMock) -> 
None: + ext = _make_extension() + + with patch.dict( + "os.environ", + { + "PATH": "/usr/bin", + "LANG": "C.UTF-8", + "PYTHONPATH": "/host/leak", + "SECRET_TOKEN": "should_not_leak", + }, + clear=True, + ): + _launch_extension(ext, mock_popen) + + cmd = mock_popen.call_args[0][0] + env_map = _setenv_map(cmd) + + assert "--clearenv" in cmd + assert env_map["PATH"] == "/usr/bin" + assert env_map["LANG"] == "C.UTF-8" + assert env_map["HOME"] == "/tmp" + assert env_map["TMPDIR"] == "/tmp" + assert env_map["PYTHONNOUSERSITE"] == "1" + assert "PYTHONPATH" not in env_map + assert "SECRET_TOKEN" not in env_map + + @patch("pyisolate._internal.host.subprocess.Popen") + def test_conda_sealed_worker_uses_pixi_python_inside_bwrap(self, mock_popen: MagicMock) -> None: + ext = _make_extension() + + _launch_extension(ext, mock_popen) + + cmd = mock_popen.call_args[0][0] + assert str(_pixi_python_path()) in cmd + + @patch("pyisolate._internal.host.subprocess.Popen") + def test_conda_sealed_worker_does_not_inject_credential_like_vars(self, mock_popen: MagicMock) -> None: + credential_pattern = re.compile( + r".*(_TOKEN|_SECRET|_KEY|_PASSWORD|_CREDENTIAL)$", + re.IGNORECASE, + ) + ext = _make_extension() + + with patch.dict( + "os.environ", + { + "PATH": "/usr/bin", + "HOME": "/home/test", + "API_TOKEN": "topsecret", + "GITHUB_SECRET": "still_secret", + }, + clear=True, + ): + _launch_extension(ext, mock_popen) + + cmd = mock_popen.call_args[0][0] + env_map = _setenv_map(cmd) + injected_creds = [key for key in env_map if credential_pattern.match(key)] + assert injected_creds == [] + + +class TestCondaSealedWorkerBootstrapGuards: + @patch("pyisolate._internal.host.subprocess.Popen") + def test_conda_sealed_worker_snapshot_disables_host_sys_path(self, mock_popen: MagicMock) -> None: + ext = _make_extension() + + transport = _launch_extension(ext, mock_popen) + + transport.send.assert_called_once() + bootstrap_data = transport.send.call_args[0][0] + assert 
bootstrap_data["snapshot"]["apply_host_sys_path"] is False diff --git a/tests/test_security_sealed_worker.py b/tests/test_security_sealed_worker.py new file mode 100644 index 0000000..5ba169e --- /dev/null +++ b/tests/test_security_sealed_worker.py @@ -0,0 +1,157 @@ +"""Tests for sandboxed sealed_worker runtime security behavior (Issue 8 Slice 2).""" + +from __future__ import annotations + +import contextlib +import re +from pathlib import Path +from unittest.mock import MagicMock, patch + +from pyisolate._internal.sandbox_detect import RestrictionModel + + +def _make_extension(): + from pyisolate._internal.host import Extension + from pyisolate.shared import ExtensionBase + + config = { + "name": "test_sealed", + "module": "test_module", + "dependencies": [], + "share_torch": False, + "share_cuda_ipc": False, + "package_manager": "uv", + "execution_model": "sealed_worker", + } + + ext = Extension.__new__(Extension) + ext.name = "test_sealed" + ext.config = config + ext.venv_path = Path("/fake/venv") + ext.module_path = "/fake/module" + ext.extension_type = ExtensionBase + ext._cuda_ipc_enabled = False + ext._uds_path = None + ext._uds_listener = None + ext._client_sock = None + return ext + + +def _launch_extension(ext, mock_popen: MagicMock) -> MagicMock: + mock_proc = MagicMock() + mock_proc.pid = 12345 + mock_popen.return_value = mock_proc + + transport = MagicMock() + transport.send = MagicMock() + + with ( + patch("pyisolate._internal.host.socket") as mock_socket, + patch("pyisolate._internal.host.tempfile"), + patch("pyisolate._internal.host.detect_sandbox_capability") as mock_detect, + patch("sys.platform", "linux"), + patch("pyisolate._internal.host.JSONSocketTransport", return_value=transport), + patch("pyisolate._internal.host.AsyncRPC"), + ): + mock_detect.return_value = MagicMock(available=True, restriction_model=RestrictionModel.NONE) + mock_listener = MagicMock() + mock_listener.accept.return_value = (MagicMock(), None) + 
mock_socket.socket.return_value = mock_listener + mock_socket.AF_UNIX = 1 + mock_socket.SOCK_STREAM = 1 + + with ( + patch("pyisolate._internal.socket_utils.has_af_unix", return_value=True), + patch("pyisolate._internal.socket_utils.ensure_ipc_socket_dir", return_value=Path("/run")), + patch("pyisolate._internal.host.build_extension_snapshot", return_value={}), + patch("os.chmod"), + contextlib.suppress(Exception), + ): + ext._launch_with_uds() + + return transport + + +def _setenv_map(cmd: list[str]) -> dict[str, str]: + env_map: dict[str, str] = {} + index = 0 + while index < len(cmd): + if cmd[index] == "--setenv": + env_map[cmd[index + 1]] = cmd[index + 2] + index += 3 + continue + index += 1 + return env_map + + +class TestSealedWorkerSandboxLaunch: + @patch("pyisolate._internal.host.subprocess.Popen") + def test_sealed_worker_launches_via_bwrap(self, mock_popen: MagicMock) -> None: + ext = _make_extension() + + _launch_extension(ext, mock_popen) + + cmd = mock_popen.call_args[0][0] + assert cmd[0] == "bwrap" + + @patch("pyisolate._internal.host.subprocess.Popen") + def test_sealed_worker_uses_explicit_env_allowlist(self, mock_popen: MagicMock) -> None: + ext = _make_extension() + + with patch.dict( + "os.environ", + { + "PATH": "/usr/bin", + "LANG": "C.UTF-8", + "PYTHONPATH": "/host/leak", + "SECRET_TOKEN": "should_not_leak", + }, + clear=True, + ): + _launch_extension(ext, mock_popen) + + cmd = mock_popen.call_args[0][0] + env_map = _setenv_map(cmd) + + assert "--clearenv" in cmd + assert env_map["PATH"] == "/usr/bin" + assert env_map["LANG"] == "C.UTF-8" + assert env_map["HOME"] == "/tmp" + assert env_map["TMPDIR"] == "/tmp" + assert env_map["PYTHONNOUSERSITE"] == "1" + assert "PYTHONPATH" not in env_map + assert "SECRET_TOKEN" not in env_map + + @patch("pyisolate._internal.host.subprocess.Popen") + def test_sealed_worker_does_not_inject_credential_like_vars(self, mock_popen: MagicMock) -> None: + credential_pattern = 
re.compile(r".*(_TOKEN|_SECRET|_KEY|_PASSWORD|_CREDENTIAL)$", re.IGNORECASE) + ext = _make_extension() + + with patch.dict( + "os.environ", + { + "PATH": "/usr/bin", + "HOME": "/home/test", + "API_TOKEN": "topsecret", + "GITHUB_SECRET": "still_secret", + }, + clear=True, + ): + _launch_extension(ext, mock_popen) + + cmd = mock_popen.call_args[0][0] + env_map = _setenv_map(cmd) + injected_creds = [key for key in env_map if credential_pattern.match(key)] + assert injected_creds == [] + + +class TestSealedWorkerBootstrapGuards: + @patch("pyisolate._internal.host.subprocess.Popen") + def test_sealed_worker_snapshot_disables_host_sys_path(self, mock_popen: MagicMock) -> None: + ext = _make_extension() + + transport = _launch_extension(ext, mock_popen) + + transport.send.assert_called_once() + bootstrap_data = transport.send.call_args[0][0] + assert bootstrap_data["snapshot"]["apply_host_sys_path"] is False diff --git a/tests/test_shared_additional.py b/tests/test_shared_additional.py index be5eef2..3c96b94 100644 --- a/tests/test_shared_additional.py +++ b/tests/test_shared_additional.py @@ -53,7 +53,12 @@ async def test_async_rpc_stop_requires_run(): def test_async_rpc_send_thread_sets_exception_on_send_failure(): + previous_loop = None loop = asyncio.new_event_loop() + try: + previous_loop = asyncio.get_event_loop_policy().get_event_loop() + except RuntimeError: + previous_loop = None asyncio.set_event_loop(loop) class FailingQueue: @@ -76,16 +81,24 @@ def put(self, _): rpc.outbox.put(pending) rpc.outbox.put(None) - rpc._send_thread() - loop.run_until_complete(asyncio.sleep(0)) - assert pending["future"].done() is True - with pytest.raises(RuntimeError): - pending["future"].result() - loop.close() + try: + rpc._send_thread() + loop.run_until_complete(asyncio.sleep(0)) + assert pending["future"].done() is True + with pytest.raises(RuntimeError): + pending["future"].result() + finally: + asyncio.set_event_loop(previous_loop) + loop.close() def 
test_async_rpc_send_thread_callback_failure_sets_exception(): + previous_loop = None loop = asyncio.new_event_loop() + try: + previous_loop = asyncio.get_event_loop_policy().get_event_loop() + except RuntimeError: + previous_loop = None asyncio.set_event_loop(loop) class FailingQueue: @@ -108,12 +121,15 @@ def put(self, _): rpc.outbox.put(pending) rpc.outbox.put(None) - rpc._send_thread() - loop.run_until_complete(asyncio.sleep(0)) - assert pending["future"].done() is True - with pytest.raises(RuntimeError): - pending["future"].result() - loop.close() + try: + rpc._send_thread() + loop.run_until_complete(asyncio.sleep(0)) + assert pending["future"].done() is True + with pytest.raises(RuntimeError): + pending["future"].result() + finally: + asyncio.set_event_loop(previous_loop) + loop.close() def test_singleton_metaclass_inject_guard(): diff --git a/tests/test_tensor_serializer_signal_cleanup.py b/tests/test_tensor_serializer_signal_cleanup.py new file mode 100644 index 0000000..5e99537 --- /dev/null +++ b/tests/test_tensor_serializer_signal_cleanup.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import importlib +import signal + + +def test_signal_cleanup_handler_tolerates_missing_sighup(monkeypatch) -> None: + import pyisolate._internal.tensor_serializer as tensor_serializer + + monkeypatch.setenv("PYISOLATE_SIGNAL_CLEANUP", "1") + monkeypatch.delattr(signal, "SIGHUP", raising=False) + + installed: list[object] = [] + + def fake_signal(sig, _handler): + installed.append(sig) + + monkeypatch.setattr(signal, "signal", fake_signal) + + importlib.reload(tensor_serializer) + + assert installed == [signal.SIGTERM] diff --git a/tests/test_uv_sealed_integration.py b/tests/test_uv_sealed_integration.py new file mode 100644 index 0000000..6a7436f --- /dev/null +++ b/tests/test_uv_sealed_integration.py @@ -0,0 +1,175 @@ +"""Process-level integration tests for the toolkit-owned uv sealed worker.""" + +from __future__ import annotations + +import contextlib +import 
gc +import os +import shutil +import site +import sys +import uuid +from pathlib import Path + +import pytest +import torch # noqa: E402 + +from pyisolate._internal.host import Extension # noqa: E402 +from pyisolate.sealed import SealedNodeExtension # noqa: E402 + +UV_BIN = Path(sys.executable).with_name("uv.exe" if os.name == "nt" else "uv") +UV_AVAILABLE = shutil.which("uv") is not None or UV_BIN.exists() +BWRAP_AVAILABLE = os.name == "nt" or shutil.which("bwrap") is not None + +pytestmark = [ + pytest.mark.skipif(not UV_AVAILABLE, reason="uv not on PATH"), + pytest.mark.skipif(not BWRAP_AVAILABLE, reason="bwrap not on PATH"), +] + + +def _fixture_path() -> Path: + return Path(__file__).resolve().parent / "fixtures" / "uv_sealed_worker" + + +def _shm_snapshot() -> set[str]: + shm_root = Path("/dev/shm") + if os.name == "nt" or not shm_root.exists(): + return set() + return {path.name for path in shm_root.glob("torch_*")} + + +def _host_site_root() -> str: + return str(Path(sys.executable).resolve().parents[1]) + + +def _build_uv_config(fixture_path: Path, run_dir: Path) -> dict: + # Inlined from fixtures/uv_sealed_worker/pyproject.toml β€” no TOML parser needed. 
+ return { + "name": "uv-sealed-worker", + "module_path": str(fixture_path), + "isolated": True, + "dependencies": ["boltons"], + "apis": [], + "env": { + "PYISOLATE_ARTIFACT_DIR": str(run_dir / "artifacts"), + "PYISOLATE_SIGNAL_CLEANUP": "1", + }, + "share_torch": False, + "share_cuda_ipc": False, + "sandbox": {"writable_paths": [str(run_dir / "artifacts")]}, + "package_manager": "uv", + "execution_model": "sealed_worker", + } + + +@pytest.mark.asyncio +async def test_uv_sealed_runtime_uses_toolkit_fixture_without_host_leakage() -> None: + fixture_path = _fixture_path() + run_root = Path(__file__).resolve().parent.parent / ".pytest_artifacts" / "uv_sealed_integration" + run_dir = run_root / uuid.uuid4().hex + (run_dir / "artifacts").mkdir(parents=True, exist_ok=True) + venv_root = run_dir / "venvs" + venv_root.mkdir(parents=True, exist_ok=True) + config = _build_uv_config(fixture_path, run_dir) + + ext = Extension( + module_path=str(fixture_path), + extension_type=SealedNodeExtension, + config=config, + venv_root_path=str(venv_root), + ) + + try: + path_env = os.environ.get("PATH", "") + uv_path = f"{UV_BIN.parent}{os.pathsep}{path_env}" if UV_BIN.exists() else path_env + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setenv("PATH", uv_path) + monkeypatch.setenv("PYISOLATE_ARTIFACT_DIR", str(run_dir / "artifacts")) + try: + ext.ensure_process_started() + except RuntimeError as exc: + if "bubblewrap" in str(exc).lower(): + pytest.skip(f"bwrap unavailable on this platform: {exc}") + raise + proxy = ext.get_proxy() + + nodes = await proxy.list_nodes() + assert nodes == { + "UVSealedRuntimeProbe": "UV Sealed Runtime Probe", + "UVSealedBoltonsSlugify": "UV Sealed Boltons Slugify", + "UVSealedFilesystemBarrier": "UV Sealed Filesystem Barrier", + "UVSealedTensorEcho": "UV Sealed Tensor Echo", + "UVSealedLatentEcho": "UV Sealed Latent Echo", + } + + ( + path_dump, + boltons_origin, + report, + saw_user_site, + ) = await 
proxy.execute_node("UVSealedRuntimeProbe") + print(report) + print(f"child boltons origin: {boltons_origin}") + + assert str(ext.venv_path) in boltons_origin + assert _host_site_root() not in boltons_origin + assert site.getusersitepackages() not in path_dump + assert saw_user_site is False + + slug, slug_origin = await proxy.execute_node( + "UVSealedBoltonsSlugify", text="Sealed Worker Still Works" + ) + assert slug == "sealed_worker_still_works" + assert slug_origin == boltons_origin + + ( + barrier_report, + outside_blocked, + module_mutation_blocked, + artifact_write_ok, + ) = await proxy.execute_node("UVSealedFilesystemBarrier") + print(barrier_report) + assert artifact_write_ok is True + if os.name != "nt": + assert outside_blocked is True + assert module_mutation_blocked is True + + shm_before = _shm_snapshot() + input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + echoed_tensor, saw_json_tensor = await proxy.execute_node("UVSealedTensorEcho", tensor=input_tensor) + shm_after = _shm_snapshot() + + max_abs = float((echoed_tensor - input_tensor).abs().max().item()) + print(f"tensor roundtrip max_abs={max_abs:.8f}") + print(f"launch args={ext.proc.args}") + + assert torch.equal(echoed_tensor, input_tensor) + assert max_abs <= 1e-5 + assert saw_json_tensor is True + launch_args = " ".join(str(part) for part in ext.proc.args) + if os.name != "nt": + assert shm_after == shm_before + assert launch_args.startswith("bwrap ") + assert "pyisolate._internal.uds_client" in launch_args + + artifact_dir = run_dir / "artifacts" + assert (artifact_dir / "child_bootstrap_paths.txt").exists() + # child_import_trace.txt is only written by setup_child_environment, + # which sealed workers skip (no host sys.path application). 
+ assert (artifact_dir / "filesystem_barrier_probe.txt").exists() + finally: + with contextlib.suppress(Exception): + if "proxy" in locals(): + await proxy.flush_transport_state() + with contextlib.suppress(UnboundLocalError): + del echoed_tensor + with contextlib.suppress(UnboundLocalError): + del input_tensor + gc.collect() + if torch.cuda.is_available(): + with contextlib.suppress(Exception): + torch.cuda.synchronize() + ext.stop() + if getattr(ext, "proc", None) is not None: + assert ext.proc.poll() is not None + shutil.rmtree(run_dir, ignore_errors=True)