diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 88768ae..766f860 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -8,25 +8,28 @@ import argparse import asyncio -import sys import statistics +import sys from pathlib import Path # Add project root to path for pyisolate imports project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from benchmark_harness import BenchmarkHarness -from pyisolate import ProxiedSingleton, ExtensionBase, ExtensionConfig, local_execution +from benchmark_harness import BenchmarkHarness # noqa: E402 + +from pyisolate import ExtensionBase, ExtensionConfig, ProxiedSingleton # noqa: E402 try: - import torch - TORCH_AVAILABLE = True -except ImportError: + from importlib.util import find_spec + + TORCH_AVAILABLE = find_spec("torch") is not None +except ImportError: # pragma: no cover TORCH_AVAILABLE = False - + try: from tabulate import tabulate + TABULATE_AVAILABLE = True except ImportError: TABULATE_AVAILABLE = False @@ -36,8 +39,10 @@ # Host-side Classes # ============================================================================= + class DatabaseSingleton(ProxiedSingleton): """Simple dictionary-based singleton for testing state.""" + def __init__(self): self._db = {} @@ -52,11 +57,12 @@ class BenchmarkExtensionWrapper(ExtensionBase): """ Host-side wrapper that proxies calls to the isolated extension. 
""" + async def on_module_loaded(self, module): """Called when the isolated module is loaded.""" if not getattr(module, "benchmark_entrypoint", None): raise RuntimeError(f"Module {module.__name__} missing 'benchmark_entrypoint'") - + # Instantiate the child-side extension object self.extension = module.benchmark_entrypoint() await self.extension.initialize() @@ -89,7 +95,7 @@ async def get_value(self, key): pass class BenchmarkExtension: """Child-side extension implementation.""" - + async def initialize(self): pass @@ -179,30 +185,29 @@ def __init__(self, mean, stdev, min_time, max_time): class SimpleRunner: """Minimal runner to replace TestRPCBenchmarks.runner.""" + def __init__(self, warmup_runs=5, benchmark_runs=1000): self.warmup_runs = warmup_runs self.benchmark_runs = benchmark_runs async def run_benchmark(self, name, func): import time + times = [] - + # Warmup for _ in range(self.warmup_runs): await func() - + # Benchmark for _ in range(self.benchmark_runs): start = time.perf_counter() await func() end = time.perf_counter() times.append(end - start) - + return BenchmarkResult( - statistics.mean(times), - statistics.stdev(times) if len(times) > 1 else 0, - min(times), - max(times) + statistics.mean(times), statistics.stdev(times) if len(times) > 1 else 0, min(times), max(times) ) @@ -211,13 +216,13 @@ async def run_benchmarks( ): print("PyIsolate RPC Benchmark Suite (Refactored for 1.0)") print("=" * 60) - + harness = BenchmarkHarness() await harness.setup_test_environment("benchmark") - + runner = SimpleRunner( - warmup_runs=2 if quick else 5, - benchmark_runs=100 if quick else 1000 + warmup_runs=2 if quick else 5, + benchmark_runs=100 if quick else 1000, ) try: @@ -230,7 +235,7 @@ async def run_benchmarks( "benchmark_ext", dependencies=["numpy>=1.26.0", "torch>=2.0.0"] if torch_available else ["numpy>=1.26.0"], share_torch=False, - extension_code=BENCHMARK_EXTENSION_CODE + extension_code=BENCHMARK_EXTENSION_CODE, ) extensions_config.append({"name": 
"benchmark_ext", "share": False}) @@ -239,34 +244,35 @@ async def run_benchmarks( "benchmark_ext_shared", dependencies=["numpy>=1.26.0", "torch>=2.0.0"], share_torch=True, - extension_code=BENCHMARK_EXTENSION_CODE + extension_code=BENCHMARK_EXTENSION_CODE, ) extensions_config.append({"name": "benchmark_ext_shared", "share": True}) # Load Extensions using Manager manager = harness.get_manager(BenchmarkExtensionWrapper) - + ext_standard = None ext_shared = None - + for cfg in extensions_config: name = cfg["name"] share_torch = cfg["share"] print(f"Loading extension {name} (share_torch={share_torch})...") - + # Reconstruct minimal deps for config (manager uses this for venv check/install) deps = ["numpy>=1.26.0"] - if torch_available: deps.append("torch>=2.0.0") - + if torch_available: + deps.append("torch>=2.0.0") + config = ExtensionConfig( name=name, module_path=str(harness.test_root / "extensions" / name), isolated=True, dependencies=deps, - apis=[DatabaseSingleton], # Host must allow the singleton - share_torch=share_torch + apis=[DatabaseSingleton], # Host must allow the singleton + share_torch=share_torch, ) - + ext = manager.load_extension(config) if name == "benchmark_ext": ext_standard = ext @@ -274,32 +280,34 @@ async def run_benchmarks( ext_shared = ext print("Extensions loaded.\n") - + # Define Test Data test_data = [ ("small_int", 42), ("small_string", "hello world"), ] - + runner_results = {} - + # --- Run Benchmarks --- # Note: In a full implementation, we'd replicate the comprehensive test suite. # Here we verify core functionality by running the 'do_stuff' generic method. # This confirms RPC, Serialization, and Process Isolation are working. 
- + target_extensions = [] - if ext_standard: target_extensions.append(("Standard", ext_standard)) - if ext_shared: target_extensions.append(("Shared", ext_shared)) - - for name, ext in target_extensions: - print(f"--- Benchmarking {name} Mode ---") + if ext_standard: + target_extensions.append(("Standard", ext_standard)) + if ext_shared: + target_extensions.append(("Shared", ext_shared)) + + for mode_name, ext in target_extensions: + print(f"--- Benchmarking {mode_name} Mode ---") for data_name, data_val in test_data: - bench_name = f"{name}_{data_name}" - - async def func(): - return await ext.do_stuff(data_val) - + bench_name = f"{mode_name}_{data_name}" + + async def func(bound_ext=ext, bound_value=data_val): + return await bound_ext.do_stuff(bound_value) + print(f"Running {bench_name}...") try: res = await runner.run_benchmark(bench_name, func) @@ -311,12 +319,12 @@ async def func(): print("\n" + "=" * 60) print("RESULTS") print("=" * 60) - + headers = ["Test", "Mean (ms)", "Std Dev (ms)"] table_data = [] for name, res in runner_results.items(): - table_data.append([name, f"{res.mean*1000:.3f}", f"{res.stdev*1000:.3f}"]) - + table_data.append([name, f"{res.mean * 1000:.3f}", f"{res.stdev * 1000:.3f}"]) + if TABULATE_AVAILABLE: print(tabulate(table_data, headers=headers)) else: @@ -325,27 +333,26 @@ async def func(): finally: await harness.cleanup() - + return 0 -def main(): +def main() -> int: parser = argparse.ArgumentParser(description="PyIsolate 1.0 Benchmark") parser.add_argument("--quick", action="store_true") parser.add_argument("--no-torch", action="store_true") parser.add_argument("--no-gpu", action="store_true") parser.add_argument("--torch-mode", default="both") - + args = parser.parse_args() - - try: - import numpy - import psutil - except ImportError: + + if find_spec("numpy") is None or find_spec("psutil") is None: print("Please install dependencies: pip install numpy psutil tabulate") return 1 - + asyncio.run(run_benchmarks(args.quick, 
args.no_torch, args.no_gpu, args.torch_mode)) + return 0 + if __name__ == "__main__": - main() + raise SystemExit(main()) diff --git a/benchmarks/benchmark_harness.py b/benchmarks/benchmark_harness.py index 4ec6e34..35a31c5 100644 --- a/benchmarks/benchmark_harness.py +++ b/benchmarks/benchmark_harness.py @@ -1,18 +1,18 @@ +import contextlib import os import sys -import shutil import tempfile -import asyncio from pathlib import Path -from typing import Optional, Any -from contextlib import contextmanager -from pyisolate import ExtensionManagerConfig, ExtensionManager, ExtensionConfig +from pyisolate import ExtensionConfig, ExtensionManager, ExtensionManagerConfig +from pyisolate.config import SandboxMode try: - import torch + import torch.multiprocessing as torch_mp + TORCH_AVAILABLE = True except ImportError: + torch_mp = None TORCH_AVAILABLE = False @@ -41,25 +41,17 @@ async def setup_test_environment(self, name: str) -> None: shared_tmp.mkdir(parents=True, exist_ok=True) # Force host process (and children via inherit) to use this TMPDIR os.environ["TMPDIR"] = str(shared_tmp) - + print(f"Benchmark Harness initialized at {self.test_root}") print(f"IPC Shared Directory: {shared_tmp}") # Ensure proper torch multiprocessing setup - if TORCH_AVAILABLE: - try: - import torch.multiprocessing - torch.multiprocessing.set_sharing_strategy('file_system') - except ImportError: - pass - + if TORCH_AVAILABLE and torch_mp is not None: + with contextlib.suppress(ImportError): + torch_mp.set_sharing_strategy("file_system") def create_extension( - self, - name: str, - dependencies: list[str], - share_torch: bool, - extension_code: str + self, name: str, dependencies: list[str], share_torch: bool, extension_code: str ) -> None: """Create an extension module on disk.""" ext_dir = self.test_root / "extensions" / name @@ -70,37 +62,30 @@ async def load_extensions(self, extension_configs: list[dict], extension_base_cl """Load extensions defined in configs.""" config = 
ExtensionManagerConfig(venv_root_path=str(self.test_root / "extension-venvs")) self.manager = ExtensionManager(extension_base_cls, config) - + loaded_extensions = [] for cfg in extension_configs: name = cfg["name"] - # Config might be passed as simple dict - - # Reconstruct dependencies if not passed mostly for existing pattern in benchmark.py - # But create_extension handles writing to disk. loading needs ExtensionConfig object. - - # This is slightly tricky because creation and loading are split in benchmark.py - # I'll rely on the caller to pass correct params or infer them? - # Actually benchmark.py logic: create_extension then load_extensions loop. - - # Since we know the path structure from create_extension: - module_path = str(self.test_root / "extensions" / name) - - # NOTE: benchmark.py passed deps to create_extension but strangely not to load_extensions - # We must pass them here to ExtensionConfig. - # Ideally load_extensions accepts full config objects or we recreate them. - # I will adapt this to match what benchmark.py expects or refactor benchmark.py to iterate. - - # Simpler approach: Allow caller to just use manager directly if they want, - # or provide a helper that does what benchmark.py did (but correctly). 
- pass - - return loaded_extensions # placeholder, I will implement explicit loading in the script + config = ExtensionConfig( + name=name, + module_path=str(self.test_root / "extensions" / name), + isolated=cfg.get("isolated", True), + dependencies=cfg.get("dependencies", []), + apis=cfg.get("apis", []), + share_torch=cfg.get("share_torch", False), + share_cuda_ipc=cfg.get("share_cuda_ipc", False), + sandbox=cfg.get("sandbox", {}), + sandbox_mode=cfg.get("sandbox_mode", SandboxMode.REQUIRED), + env=cfg.get("env", {}), + ) + loaded_extensions.append(self.manager.load_extension(config)) + + return loaded_extensions def get_manager(self, extension_base_cls): if not self.manager: - config = ExtensionManagerConfig(venv_root_path=str(self.test_root / "extension-venvs")) - self.manager = ExtensionManager(extension_base_cls, config) + config = ExtensionManagerConfig(venv_root_path=str(self.test_root / "extension-venvs")) + self.manager = ExtensionManager(extension_base_cls, config) return self.manager async def cleanup(self): @@ -110,6 +95,6 @@ async def cleanup(self): self.manager.stop_all_extensions() except Exception as e: print(f"Error stopping extensions: {e}") - + if self.temp_dir: self.temp_dir.cleanup() diff --git a/benchmarks/memory_benchmark.py b/benchmarks/memory_benchmark.py index 30a3b78..b131956 100644 --- a/benchmarks/memory_benchmark.py +++ b/benchmarks/memory_benchmark.py @@ -8,13 +8,13 @@ import argparse import asyncio +import contextlib import gc import platform import sys import time -import os from pathlib import Path -from typing import Optional +from shutil import which import psutil @@ -44,15 +44,11 @@ nvml = None NVML_AVAILABLE = False -import contextlib -import tempfile -import shutil - -from memory_extension_base import MemoryBenchmarkExtensionBase -from benchmark_harness import BenchmarkHarness -from tabulate import tabulate +from benchmark_harness import BenchmarkHarness # noqa: E402 +from memory_extension_base import 
MemoryBenchmarkExtensionBase # noqa: E402 +from tabulate import tabulate # noqa: E402 -from pyisolate import ExtensionConfig, ExtensionManager, ExtensionManagerConfig +from pyisolate import ExtensionConfig, ExtensionManager, ExtensionManagerConfig # noqa: E402 class MemoryTracker: @@ -89,14 +85,13 @@ def __init__(self): self.baseline_gpu_memory_mb = baseline print(f"Using nvidia-smi fallback. Initial GPU memory: {baseline:.1f} MB") - def _get_gpu_memory_nvidia_smi(self) -> Optional[float]: + def _get_gpu_memory_nvidia_smi(self) -> float | None: """Get GPU memory usage using nvidia-smi command (Windows fallback).""" try: - import shutil import subprocess # Find nvidia-smi executable - nvidia_smi = shutil.which("nvidia-smi") + nvidia_smi = which("nvidia-smi") if not nvidia_smi: return None @@ -126,10 +121,9 @@ def _get_gpu_memory_windows_fallback(self, memory_info: dict[str, float]) -> dic # Try to get total GPU memory try: - import shutil import subprocess - nvidia_smi = shutil.which("nvidia-smi") + nvidia_smi = which("nvidia-smi") if nvidia_smi: result = subprocess.run( # noqa: S603 [nvidia_smi, "--query-gpu=memory.total", "--format=csv,nounits,noheader"], @@ -331,6 +325,7 @@ def memory_benchmark_entrypoint(): ''' +class MemoryBenchmarkRunner: """Runs memory usage benchmarks with multiple extensions.""" def __init__(self, test_base: BenchmarkHarness): diff --git a/benchmarks/simple_benchmark.py b/benchmarks/simple_benchmark.py index f768c60..566614c 100644 --- a/benchmarks/simple_benchmark.py +++ b/benchmarks/simple_benchmark.py @@ -37,6 +37,7 @@ async def measure_rpc_overhead(include_large_tensors=False): print() import os + if sys.platform == "linux" and os.environ.get("TMPDIR") != "/dev/shm": print("WARNING: TMPDIR is not set to /dev/shm on Linux.") print("If extensions use share_torch=True, execution WILL fail in strict sandboxes.") diff --git a/pyisolate/_internal/cuda_wheels.py b/pyisolate/_internal/cuda_wheels.py new file mode 100644 index 0000000..668bab8 
--- /dev/null +++ b/pyisolate/_internal/cuda_wheels.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +import re +from html.parser import HTMLParser +from typing import TypedDict, cast +from urllib.error import HTTPError, URLError +from urllib.parse import unquote, urljoin, urlparse +from urllib.request import urlopen + +from packaging.markers import default_environment +from packaging.requirements import InvalidRequirement, Requirement +from packaging.tags import sys_tags +from packaging.utils import canonicalize_name, parse_wheel_filename +from packaging.version import Version + +from ..config import CUDAWheelConfig + +_TORCH_VERSION_RE = re.compile(r"^(?P<major>\d+)\.(?P<minor>\d+)") +_CUDA_LOCAL_PATTERNS = ( + re.compile(r"(^|[.-])cu(?P<cuda>\d+)torch(?P<torch>\d+)([.-]|$)"), + re.compile(r"(^|[.-])pt(?P<torch>\d+)cu(?P<cuda>\d+)([.-]|$)"), +) + + +class CUDAWheelRuntime(TypedDict): + torch: str + torch_nodot: str + cuda: str + cuda_nodot: str + python_tags: list[str] + + +class _SimpleIndexParser(HTMLParser): + def __init__(self) -> None: + super().__init__() + self.hrefs: list[str] = [] + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: + if tag != "a": + return + attributes = dict(attrs) + href = attributes.get("href") + if href: + self.hrefs.append(href) + + +class CUDAWheelResolutionError(RuntimeError): + pass + + +def _parse_major_minor(version_text: str, label: str) -> str: + match = _TORCH_VERSION_RE.match(version_text) + if not match: + raise CUDAWheelResolutionError(f"Could not parse {label} major.minor from '{version_text}'") + return f"{match.group('major')}.{match.group('minor')}" + + +def get_cuda_wheel_runtime() -> CUDAWheelRuntime: + try: + import torch + except ImportError as exc: + raise CUDAWheelResolutionError( + "Custom CUDA wheel resolution requires host torch to be installed" + ) from exc + + torch_version = _parse_major_minor(str(torch.__version__), "torch version") + cuda_version = torch.version.cuda # type: ignore[attr-defined] +
if not cuda_version: + raise CUDAWheelResolutionError( + "Custom CUDA wheel resolution requires a CUDA-enabled host torch build" + ) + cuda_major_minor = _parse_major_minor(str(cuda_version), "CUDA version") + return { + "torch": torch_version, + "torch_nodot": torch_version.replace(".", ""), + "cuda": cuda_major_minor, + "cuda_nodot": cuda_major_minor.replace(".", ""), + "python_tags": [str(tag) for tag in sys_tags()], + } + + +def get_cuda_wheel_runtime_descriptor() -> dict[str, object]: + runtime = get_cuda_wheel_runtime() + return { + "torch": runtime["torch"], + "torch_nodot": runtime["torch_nodot"], + "cuda": runtime["cuda"], + "cuda_nodot": runtime["cuda_nodot"], + "python_tags": runtime["python_tags"], + } + + +def _normalize_cuda_wheel_config(config: CUDAWheelConfig) -> CUDAWheelConfig: + index_url = config.get("index_url") + packages = config.get("packages") + package_map = config.get("package_map", {}) + + if not isinstance(index_url, str) or not index_url.strip(): + raise CUDAWheelResolutionError("cuda_wheels.index_url must be a non-empty string") + if not isinstance(packages, list) or not all( + isinstance(package_name, str) and package_name.strip() for package_name in packages + ): + raise CUDAWheelResolutionError("cuda_wheels.packages must be a list of non-empty strings") + if not isinstance(package_map, dict): + raise CUDAWheelResolutionError("cuda_wheels.package_map must be a mapping") + + normalized_map: dict[str, str] = {} + for dependency_name, index_package_name in package_map.items(): + if not isinstance(dependency_name, str) or not dependency_name.strip(): + raise CUDAWheelResolutionError("cuda_wheels.package_map keys must be non-empty strings") + if not isinstance(index_package_name, str) or not index_package_name.strip(): + raise CUDAWheelResolutionError("cuda_wheels.package_map values must be non-empty strings") + normalized_map[canonicalize_name(dependency_name)] = index_package_name.strip() + + return { + "index_url": 
index_url.rstrip("/") + "/", + "packages": [canonicalize_name(package_name) for package_name in packages], + "package_map": normalized_map, + } + + +def _candidate_package_names(dependency_name: str, package_map: dict[str, str]) -> list[str]: + candidates: list[str] = [] + mapped_name = package_map.get(dependency_name) + if mapped_name: + candidates.append(mapped_name.strip()) + candidates.append(mapped_name.replace("-", "_")) + candidates.append(mapped_name.replace("_", "-")) + + candidates.append(dependency_name) + candidates.append(dependency_name.replace("-", "_")) + candidates.append(dependency_name.replace("_", "-")) + + deduped: list[str] = [] + seen: set[str] = set() + for candidate in candidates: + if candidate and candidate not in seen: + deduped.append(candidate) + seen.add(candidate) + return deduped + + +def _fetch_index_html(url: str) -> str | None: + try: + with urlopen(url, timeout=30) as response: # noqa: S310 - URL is explicit extension config + content: bytes = response.read() + return content.decode("utf-8") + except (HTTPError, URLError, FileNotFoundError): + return None + + +def _parse_index_links(page_url: str, html: str) -> list[str]: + parser = _SimpleIndexParser() + parser.feed(html) + return [urljoin(page_url, href) for href in parser.hrefs] + + +def _normalize_wheel_url(raw_url: str) -> str: + parsed_url = urlparse(raw_url) + return parsed_url._replace(path=unquote(parsed_url.path)).geturl() + + +def _matches_runtime(local_version: str | None, runtime: CUDAWheelRuntime) -> bool: + if not local_version: + return False + normalized_local = local_version.lower() + for pattern in _CUDA_LOCAL_PATTERNS: + match = pattern.search(normalized_local) + if not match: + continue + if match.group("torch") == runtime["torch_nodot"] and match.group("cuda") == runtime["cuda_nodot"]: + return True + return False + + +def resolve_cuda_wheel_url( + requirement: Requirement, config: CUDAWheelConfig, runtime: CUDAWheelRuntime | None = None +) -> str: + 
normalized_config = _normalize_cuda_wheel_config(config) + dependency_name = canonicalize_name(requirement.name) + runtime_info = runtime or get_cuda_wheel_runtime() + supported_tag_list = list(sys_tags()) + supported_tags = set(supported_tag_list) + tag_rank = {tag: idx for idx, tag in enumerate(supported_tag_list)} + fetch_attempted = False + candidates: list[tuple[Version, int, str]] = [] + + for package_name in _candidate_package_names(dependency_name, normalized_config.get("package_map", {})): + page_url = urljoin(normalized_config["index_url"], package_name.rstrip("/") + "/") + html = _fetch_index_html(page_url) + if html is None: + continue + fetch_attempted = True + for wheel_url in _parse_index_links(page_url, html): + parsed_url = urlparse(wheel_url) + wheel_filename = unquote(parsed_url.path.rsplit("/", 1)[-1]) + if not wheel_filename.endswith(".whl"): + continue + try: + wheel_name, wheel_version, _, wheel_tags = parse_wheel_filename(wheel_filename) + except ValueError: + continue + if canonicalize_name(wheel_name) != dependency_name: + continue + matching_tags = wheel_tags.intersection(supported_tags) + if not matching_tags: + continue + if not _matches_runtime(getattr(wheel_version, "local", None), runtime_info): + continue + if requirement.specifier and wheel_version not in requirement.specifier: + continue + candidates.append( + ( + wheel_version, + min(tag_rank[tag] for tag in matching_tags), + _normalize_wheel_url(wheel_url), + ) + ) + + if not fetch_attempted: + raise CUDAWheelResolutionError( + f"No CUDA wheel index page found for '{requirement.name}' under {normalized_config['index_url']}" + ) + if not candidates: + raise CUDAWheelResolutionError( + "No compatible CUDA wheel found for " + f"'{requirement}' (torch {runtime_info['torch']}, CUDA {runtime_info['cuda']})" + ) + + candidates.sort(key=lambda item: (item[0], -item[1])) + return candidates[-1][2] + + +def resolve_cuda_wheel_requirements(requirements: list[str], config: CUDAWheelConfig) 
-> list[str]: + normalized_config = _normalize_cuda_wheel_config(config) + configured_packages = set(normalized_config["packages"]) + environment = cast(dict[str, str], default_environment()) + runtime = get_cuda_wheel_runtime() + resolved_requirements: list[str] = [] + + for dependency in requirements: + stripped = dependency.strip() + if not stripped or stripped == "-e" or stripped.startswith("-e "): + resolved_requirements.append(dependency) + continue + if stripped.startswith(("/", "./", "../", "file://")): + resolved_requirements.append(dependency) + continue + + try: + requirement = Requirement(stripped) + except InvalidRequirement: + resolved_requirements.append(dependency) + continue + + dependency_name = canonicalize_name(requirement.name) + if dependency_name not in configured_packages: + resolved_requirements.append(dependency) + continue + if requirement.url: + raise CUDAWheelResolutionError( + f"cuda_wheels dependency '{requirement.name}' must not already use a direct URL" + ) + if requirement.extras: + raise CUDAWheelResolutionError(f"cuda_wheels dependency '{requirement.name}' must not use extras") + if requirement.marker and not requirement.marker.evaluate(environment): + resolved_requirements.append(dependency) + continue + + resolved_requirements.append(resolve_cuda_wheel_url(requirement, normalized_config, runtime)) + + return resolved_requirements diff --git a/pyisolate/_internal/environment.py b/pyisolate/_internal/environment.py index 5db720d..634723f 100644 --- a/pyisolate/_internal/environment.py +++ b/pyisolate/_internal/environment.py @@ -12,9 +12,14 @@ from importlib import metadata as importlib_metadata from pathlib import Path from typing import Any +from urllib.parse import urlparse from ..config import ExtensionConfig from ..path_helpers import serialize_host_snapshot +from .cuda_wheels import ( + get_cuda_wheel_runtime_descriptor, + resolve_cuda_wheel_requirements, +) from .torch_utils import get_torch_ecosystem_packages logger = 
logging.getLogger(__name__) @@ -306,6 +311,26 @@ def install_dependencies(venv_path: Path, config: ExtensionConfig, name: str) -> if not safe_deps: return + cuda_wheels_config = config.get("cuda_wheels") + cuda_wheel_runtime: dict[str, object] | None = None + if cuda_wheels_config: + from packaging.requirements import InvalidRequirement, Requirement + from packaging.utils import canonicalize_name + + cuda_pkg_names = {canonicalize_name(p) for p in cuda_wheels_config.get("packages", [])} + needs_cuda_probe = False + for dep in safe_deps: + if dep.startswith("-e"): + continue + try: + if canonicalize_name(Requirement(dep).name) in cuda_pkg_names: + needs_cuda_probe = True + break + except InvalidRequirement: + continue + if needs_cuda_probe: + cuda_wheel_runtime = get_cuda_wheel_runtime_descriptor() + # uv handles hardlink vs copy automatically based on filesystem support cmd_prefix: list[str] = [uv_path, "pip", "install", "--python", str(python_exe)] cache_dir_override = os.environ.get("PYISOLATE_UV_CACHE_DIR") @@ -335,6 +360,8 @@ def install_dependencies(venv_path: Path, config: ExtensionConfig, name: str) -> "dependencies": safe_deps, "share_torch": config["share_torch"], "torch_spec": torch_spec, + "cuda_wheels": cuda_wheels_config, + "cuda_wheel_runtime": cuda_wheel_runtime, "pyisolate": pyisolate_version, "python": sys.version, } @@ -349,18 +376,32 @@ def install_dependencies(venv_path: Path, config: ExtensionConfig, name: str) -> except Exception as exc: logger.debug("Dependency cache read failed: %s", exc) + resolved_deps = safe_deps + if cuda_wheels_config: + resolved_deps = resolve_cuda_wheel_requirements(safe_deps, cuda_wheels_config) + for original_dep, resolved_dep in zip(safe_deps, resolved_deps, strict=True): + if original_dep != resolved_dep: + parsed = urlparse(resolved_dep) + redacted = f"{parsed.netloc}/{Path(parsed.path).name}" if parsed.scheme else resolved_dep + logger.info( + "][ CUDA_WHEEL_RESOLVED ext=%s dep=%s wheel=%s", + name, + 
original_dep, + redacted, + ) + install_targets: list[str] = [] i = 0 - while i < len(safe_deps): - dep = safe_deps[i] + while i < len(resolved_deps): + dep = resolved_deps[i] dep_stripped = dep.strip() # Support split editable args from existing callers: # ["-e", "/path/to/pkg"]. if dep_stripped == "-e": - if i + 1 >= len(safe_deps): + if i + 1 >= len(resolved_deps): raise ValueError("Editable dependency '-e' must include a path or URL") - editable_target = safe_deps[i + 1].strip() + editable_target = resolved_deps[i + 1].strip() if not editable_target: raise ValueError("Editable dependency '-e' must include a path or URL") install_targets.extend(["-e", editable_target]) @@ -376,6 +417,17 @@ def install_dependencies(venv_path: Path, config: ExtensionConfig, name: str) -> install_targets.append(dep) i += 1 + if cuda_wheels_config: + redacted_targets = [ + f"{urlparse(t).netloc}/{Path(urlparse(t).path).name}" if "://" in t else t + for t in install_targets + ] + logger.info( + "][ CUDA_WHEEL_INSTALL ext=%s targets=%s", + name, + redacted_targets, + ) + cmd = cmd_prefix + install_targets + common_args with subprocess.Popen( # noqa: S603 # Trusted: validated pip/uv install cmd @@ -394,6 +446,8 @@ def install_dependencies(venv_path: Path, config: ExtensionConfig, name: str) -> # for users debugging their own extension dependencies. 
if "pyisolate==" not in clean and "pyisolate @" not in clean: output_lines.append(clean) + if cuda_wheels_config and clean: + logger.info("][ CUDA_WHEEL_UV ext=%s %s", name, clean) return_code = proc.wait() if return_code != 0: diff --git a/pyisolate/_internal/model_serialization.py b/pyisolate/_internal/model_serialization.py index 09e0846..6f7c2ef 100644 --- a/pyisolate/_internal/model_serialization.py +++ b/pyisolate/_internal/model_serialization.py @@ -87,6 +87,8 @@ async def deserialize_from_isolation(data: Any, extension: Any = None, _nested: if isinstance(data, RemoteObjectHandle): if _nested or extension is None: return data + if registry.has_handler(data.type_name): + return data try: resolved = await extension.get_remote_object(data.object_id) with contextlib.suppress(Exception): @@ -95,11 +97,13 @@ async def deserialize_from_isolation(data: Any, extension: Any = None, _nested: except Exception: return data - # Check for adapter-registered deserializers by type name (e.g., NodeOutput) - if registry.has_handler(type_name): + # Check for adapter-registered deserializers by type name (e.g., NodeOutput). + # Only apply to dicts (serialized form). Objects already deserialized by the + # JSON transport layer (e.g., PLY reconstructed via _json_object_hook) are + # passed through as-is. 
+ if isinstance(data, dict) and registry.has_handler(type_name): deserializer = registry.get_deserializer(type_name) if deserializer: - # For async deserializers, we need special handling result = deserializer(data) if hasattr(result, "__await__"): return await result diff --git a/pyisolate/_internal/rpc_protocol.py b/pyisolate/_internal/rpc_protocol.py index 06670f9..1fe4188 100644 --- a/pyisolate/_internal/rpc_protocol.py +++ b/pyisolate/_internal/rpc_protocol.py @@ -336,7 +336,8 @@ async def dispatch_request(self, request: RPCRequest | RPCCallback) -> None: # Try to send response; if serialization fails, send error response instead try: - self._transport.send(_prepare_for_rpc(response)) + prepared = _prepare_for_rpc(response) + self._transport.send(prepared) except (TypeError, ValueError) as serialize_exc: # FAIL LOUD: Log and propagate serialization failures logger.error( diff --git a/pyisolate/_internal/rpc_transports.py b/pyisolate/_internal/rpc_transports.py index 7855257..31f09fe 100644 --- a/pyisolate/_internal/rpc_transports.py +++ b/pyisolate/_internal/rpc_transports.py @@ -145,8 +145,13 @@ def recv(self) -> Any: if not raw_len or len(raw_len) < 4: raise ConnectionError("Socket closed or incomplete length header") msg_len = struct.unpack(">I", raw_len)[0] - if msg_len > 100 * 1024 * 1024: # 100MB sanity limit + if msg_len > 2 * 1024 * 1024 * 1024: # 2GB hard limit raise ValueError(f"Message too large: {msg_len} bytes") + if msg_len > 100 * 1024 * 1024: # 100MB — flag large payloads + logger.warning( + "Large RPC message: %.1fMB — consider SHM-backed transfer for this type", + msg_len / (1024 * 1024), + ) data = self._recvall(msg_len) if len(data) < msg_len: raise ConnectionError(f"Incomplete message: got {len(data)}/{msg_len} bytes") diff --git a/pyisolate/_internal/serialization_registry.py b/pyisolate/_internal/serialization_registry.py index e8d9b96..d4f067b 100644 --- a/pyisolate/_internal/serialization_registry.py +++ 
b/pyisolate/_internal/serialization_registry.py @@ -22,6 +22,7 @@ class SerializerRegistry: def __init__(self) -> None: self._serializers: dict[str, Callable[[Any], Any]] = {} self._deserializers: dict[str, Callable[[Any], Any]] = {} + self._data_types: set[str] = set() @classmethod def get_instance(cls) -> SerializerRegistry: @@ -35,15 +36,26 @@ def register( type_name: str, serializer: Callable[[Any], Any], deserializer: Callable[[Any], Any] | None = None, + *, + data_type: bool = False, ) -> None: - """Register serializer (and optional deserializer) for a type.""" + """Register serializer (and optional deserializer) for a type. + + Args: + data_type: When True, marks this type as a pure-data payload that + can be serialized in any process context (child or host). + When False (default), the serializer is only used during RPC + transport preparation, not during child-side output wrapping. + """ if type_name in self._serializers: logger.debug("Overwriting existing serializer for %s", type_name) self._serializers[type_name] = serializer if deserializer: self._deserializers[type_name] = deserializer - logger.debug("Registered serializer for type: %s", type_name) + if data_type: + self._data_types.add(type_name) + logger.debug("Registered serializer for type: %s (data_type=%s)", type_name, data_type) def get_serializer(self, type_name: str) -> Callable[[Any], Any] | None: """Return serializer for *type_name*, or None if not registered.""" @@ -57,7 +69,16 @@ def has_handler(self, type_name: str) -> bool: """Return True if *type_name* has a registered serializer.""" return type_name in self._serializers + def is_data_type(self, type_name: str) -> bool: + """Return True if *type_name* is a data payload type. + + Data types are serialized directly during child-side output wrapping + rather than being wrapped as RemoteObjectHandles. 
+ """ + return type_name in self._data_types + def clear(self) -> None: """Remove all registered handlers (useful for tests).""" self._serializers.clear() self._deserializers.clear() + self._data_types.clear() diff --git a/pyisolate/_internal/tensor_serializer.py b/pyisolate/_internal/tensor_serializer.py index 8247c29..fdf89ee 100644 --- a/pyisolate/_internal/tensor_serializer.py +++ b/pyisolate/_internal/tensor_serializer.py @@ -1,6 +1,7 @@ import atexit import base64 import collections +import contextlib import logging import os import signal @@ -195,12 +196,15 @@ def purge_orphan_sender_shm_files(min_age_seconds: float = 1.0, force: bool = Fa def _flush_tensor_keeper_on_exit() -> None: - try: + # purge_orphan_sender_shm_files is intentionally NOT called here. + # Calling it with force=True races with the C++ RefcountedMapAllocator + # destructor: SIGKILL'd child peers leave the refcount file at count > 1, + # so the host's C++ close() decrements but does NOT unlink (refcount still > 0). + # purge_orphan would then unlink the file while the C++ mmap is still open, + # causing a double-unlink → ENOENT → SIGABRT at process exit. + # Files are cleaned up by the C++ destructor when the last consumer closes. + with contextlib.suppress(Exception): flush_tensor_keeper() - purge_orphan_sender_shm_files(min_age_seconds=0.0, force=True) - except Exception: - # Best-effort shutdown cleanup. 
- pass atexit.register(_flush_tensor_keeper_on_exit) diff --git a/pyisolate/config.py b/pyisolate/config.py index 1ea2334..0005dd2 100644 --- a/pyisolate/config.py +++ b/pyisolate/config.py @@ -1,8 +1,14 @@ from __future__ import annotations +import sys from enum import Enum from typing import TYPE_CHECKING, Any, TypedDict +if sys.version_info >= (3, 11): + from typing import NotRequired +else: + from typing_extensions import NotRequired + if TYPE_CHECKING: from ._internal.rpc_protocol import ProxiedSingleton @@ -36,6 +42,19 @@ class SandboxConfig(TypedDict, total=False): network: bool +class CUDAWheelConfig(TypedDict): + """Configuration for custom CUDA wheel resolution.""" + + index_url: str + """Base URL containing per-package simple index directories.""" + + packages: list[str] + """Canonicalized dependency names that must resolve via the custom index.""" + + package_map: NotRequired[dict[str, str]] + """Optional canonical dependency-name to index-package-name overrides.""" + + class ExtensionConfig(TypedDict): """Configuration for a single extension managed by PyIsolate.""" @@ -69,3 +88,6 @@ class ExtensionConfig(TypedDict): env: dict[str, str] """Environment variable overrides for the child process.""" + + cuda_wheels: NotRequired[CUDAWheelConfig] + """Optional custom CUDA wheel resolution configuration for selected dependencies.""" diff --git a/pyisolate/interfaces.py b/pyisolate/interfaces.py index 4da9bd8..e256c19 100644 --- a/pyisolate/interfaces.py +++ b/pyisolate/interfaces.py @@ -22,6 +22,8 @@ def register( type_name: str, serializer: Callable[[Any], Any], deserializer: Callable[[Any], Any] | None = None, + *, + data_type: bool = False, ) -> None: """Register serializer/deserializer pair for a type.""" @@ -34,6 +36,9 @@ def get_deserializer(self, type_name: str) -> Callable[[Any], Any] | None: def has_handler(self, type_name: str) -> bool: """Return True if a serializer exists for *type_name*.""" + def is_data_type(self, type_name: str) -> bool: + 
"""Return True if *type_name* is a data payload type.""" + @runtime_checkable class IsolationAdapter(Protocol): diff --git a/pyproject.toml b/pyproject.toml index 0db466e..b22170d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pyisolate" -version = "0.9.1" +version = "0.9.2" description = "A Python library for dividing execution across multiple virtual environments" readme = "README.md" requires-python = ">=3.10" @@ -14,7 +14,8 @@ authors = [ {name = "Jacob Segal", email = "jacob.e.segal@gmail.com"}, ] maintainers = [ - {name = "Jacob Segal", email = "jacob.e.segal@gmail.com"}, + {name = "Jacob Segal", email = "jacob.e.segal@gmail.com"}, + {name = "John Pollock", email = "pollockjj@gmail.com"}, ] classifiers = [ "Development Status :: 3 - Alpha", @@ -28,6 +29,8 @@ keywords = ["virtual environment", "venv", "development"] dependencies = [ "uv>=0.1.0", "tomli>=2.0.1; python_version < '3.11'", + "typing_extensions>=4.0; python_version < '3.11'", + "packaging>=23.0", ] [project.optional-dependencies] diff --git a/tests/test_cuda_wheels.py b/tests/test_cuda_wheels.py new file mode 100644 index 0000000..40d78dc --- /dev/null +++ b/tests/test_cuda_wheels.py @@ -0,0 +1,317 @@ +"""Synthetic/unit coverage for CUDA wheel resolution. + +These tests intentionally use monkeypatches and fake indexes. They do not +perform a real wheel download or a real install. 
+""" + +import builtins +import io +import sys +from types import SimpleNamespace + +import pytest +from packaging.tags import sys_tags + +from pyisolate._internal import environment +from pyisolate._internal.cuda_wheels import ( + CUDAWheelResolutionError, + get_cuda_wheel_runtime, + resolve_cuda_wheel_requirements, +) + + +def _runtime() -> dict[str, object]: + return { + "torch": "2.8", + "torch_nodot": "28", + "cuda": "12.8", + "cuda_nodot": "128", + "python_tags": [str(tag) for tag in sys_tags()], + } + + +def _wheel_filename(distribution: str, version: str) -> str: + tag = next(iter(sys_tags())) + return f"{distribution}-{version}-{tag.interpreter}-{tag.abi}-{tag.platform}.whl" + + +def _simple_index_html(*filenames: str) -> str: + links = [f'{filename}' for filename in filenames] + return "" + "".join(links) + "" + + +def test_resolve_cuda_wheel_requirement_to_direct_url(monkeypatch): + runtime = _runtime() + wheel = _wheel_filename("flash_attn", "1.1.0+cu128torch28") + page_url = "https://example.invalid/cuda-wheels/flash-attn/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(wheel) if url == page_url else None, + ) + + resolved = resolve_cuda_wheel_requirements( + ["flash-attn>=1.0"], + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn"], + "package_map": {}, + }, + ) + + assert resolved == [page_url + wheel] + + +def test_resolve_cuda_wheel_requirement_supports_underscore_index(monkeypatch): + runtime = _runtime() + wheel = _wheel_filename("torch_generic_nms", "0.2.0+cu128torch28") + page_url = "https://example.invalid/cuda-wheels/torch_generic_nms/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(wheel) 
if url == page_url else None, + ) + + resolved = resolve_cuda_wheel_requirements( + ["torch-generic-nms>=0.1"], + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["torch-generic-nms"], + "package_map": {}, + }, + ) + + assert resolved == [page_url + wheel] + + +def test_resolve_cuda_wheel_requirement_supports_percent_encoded_links(monkeypatch): + runtime = _runtime() + wheel = _wheel_filename("torch_generic_nms", "0.1+cu128torch28") + encoded_wheel = wheel.replace("+", "%2B") + page_url = "https://example.invalid/cuda-wheels/torch-generic-nms/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(encoded_wheel) if url == page_url else None, + ) + + resolved = resolve_cuda_wheel_requirements( + ["torch-generic-nms==0.1"], + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["torch-generic-nms"], + "package_map": {}, + }, + ) + + assert resolved == [page_url + wheel] + + +def test_resolve_cuda_wheel_requirement_honors_package_map(monkeypatch): + runtime = _runtime() + wheel = _wheel_filename("flash_attn", "1.2.0+cu128torch28") + page_url = "https://example.invalid/cuda-wheels/flash_attn_special/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(wheel) if url == page_url else None, + ) + + resolved = resolve_cuda_wheel_requirements( + ["flash-attn>=1.0"], + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn"], + "package_map": {"flash-attn": "flash_attn_special"}, + }, + ) + + assert resolved == [page_url + wheel] + + +def test_resolve_cuda_wheel_requirement_picks_highest_matching_version(monkeypatch): + runtime = _runtime() + compatible_old = _wheel_filename("flash_attn", 
"1.1.0+cu128torch28") + compatible_new = _wheel_filename("flash_attn", "1.3.0+pt28cu128") + incompatible_cuda = _wheel_filename("flash_attn", "1.4.0+cu127torch28") + out_of_range = _wheel_filename("flash_attn", "2.0.0+cu128torch28") + page_url = "https://example.invalid/cuda-wheels/flash-attn/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: ( + _simple_index_html( + compatible_old, + compatible_new, + incompatible_cuda, + out_of_range, + ) + if url == page_url + else None + ), + ) + + resolved = resolve_cuda_wheel_requirements( + ["flash-attn>=1.0,<2.0"], + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn"], + "package_map": {}, + }, + ) + + assert resolved == [page_url + compatible_new] + + +def test_resolve_cuda_wheel_requirement_prefers_better_supported_tag(monkeypatch): + all_tags = list(sys_tags()) + manylinux_tags = [t for t in all_tags if "manylinux" in t.platform and "x86_64" in t.platform] + linux_tags = [t for t in all_tags if t.platform == "linux_x86_64"] + if not manylinux_tags or not linux_tags: + pytest.skip("manylinux/linux_x86_64 tags not available on this platform") + ml_tag = manylinux_tags[0] + lx_tag = linux_tags[0] + runtime = _runtime() + hyphen_page = "https://example.invalid/cuda-wheels/torch-generic-nms/" + underscore_page = "https://example.invalid/cuda-wheels/torch_generic_nms/" + preferred = f"torch_generic_nms-0.1+cu128torch28-{ml_tag.interpreter}-{ml_tag.abi}-{ml_tag.platform}.whl" + fallback = f"torch_generic_nms-0.1+cu128torch28-{lx_tag.interpreter}-{lx_tag.abi}-{lx_tag.platform}.whl" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: ( + _simple_index_html(preferred) + if url == hyphen_page + else 
_simple_index_html(fallback) + if url == underscore_page + else None + ), + ) + + resolved = resolve_cuda_wheel_requirements( + ["torch-generic-nms==0.1"], + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["torch-generic-nms"], + "package_map": {}, + }, + ) + + assert resolved == [hyphen_page + preferred] + + +def test_resolve_cuda_wheel_requirement_raises_when_no_match(monkeypatch): + runtime = _runtime() + wheel = _wheel_filename("flash_attn", "1.1.0+cu127torch28") + page_url = "https://example.invalid/cuda-wheels/flash-attn/" + + monkeypatch.setattr("pyisolate._internal.cuda_wheels.get_cuda_wheel_runtime", lambda: runtime) + monkeypatch.setattr( + "pyisolate._internal.cuda_wheels._fetch_index_html", + lambda url: _simple_index_html(wheel) if url == page_url else None, + ) + + with pytest.raises(CUDAWheelResolutionError, match="No compatible CUDA wheel found"): + resolve_cuda_wheel_requirements( + ["flash-attn>=1.0"], + { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn"], + "package_map": {}, + }, + ) + + +def test_get_cuda_wheel_runtime_raises_without_torch(monkeypatch): + real_import = builtins.__import__ + + def missing_torch(name, *args, **kwargs): + if name == "torch": + raise ImportError("missing torch") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", missing_torch) + + with pytest.raises(CUDAWheelResolutionError, match="host torch"): + get_cuda_wheel_runtime() + + +def test_get_cuda_wheel_runtime_raises_without_cuda(monkeypatch): + fake_torch = SimpleNamespace( + __version__="2.8.1", + version=SimpleNamespace(cuda=None), + ) + monkeypatch.setitem(sys.modules, "torch", fake_torch) + + with pytest.raises(CUDAWheelResolutionError, match="CUDA-enabled host torch"): + get_cuda_wheel_runtime() + + +def test_install_dependencies_cache_invalidation_tracks_cuda_runtime(monkeypatch, tmp_path): + import os + + venv_path = tmp_path / "venv" + python_exe = venv_path / 
"Scripts" / "python.exe" if os.name == "nt" else venv_path / "bin" / "python" + python_exe.parent.mkdir(parents=True, exist_ok=True) + python_exe.write_text("#!/usr/bin/env python\n", encoding="utf-8") + + monkeypatch.setattr(environment.shutil, "which", lambda binary: "/usr/bin/uv") + monkeypatch.setattr( + environment, + "exclude_satisfied_requirements", + lambda config, requirements, python_exe: requirements, + ) + monkeypatch.setattr( + environment, + "resolve_cuda_wheel_requirements", + lambda requirements, config: ["https://example.invalid/flash_attn.whl"], + ) + + current_runtime = {"value": {"torch": "2.8", "cuda": "12.8", "python_tags": ["cp312"]}} + monkeypatch.setattr( + environment, + "get_cuda_wheel_runtime_descriptor", + lambda: current_runtime["value"], + ) + + popen_calls: list[list[str]] = [] + + class MockPopen: + def __init__(self, cmd, **kwargs): + popen_calls.append(cmd) + self.stdout = io.StringIO("installed\n") + + def wait(self): + return 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(environment.subprocess, "Popen", MockPopen) + + config = { + "dependencies": ["flash-attn>=1.0"], + "share_torch": True, + "cuda_wheels": { + "index_url": "https://example.invalid/cuda-wheels/", + "packages": ["flash-attn"], + "package_map": {}, + }, + } + + environment.install_dependencies(venv_path, config, "demo") + environment.install_dependencies(venv_path, config, "demo") + + current_runtime["value"] = {"torch": "2.8", "cuda": "12.9", "python_tags": ["cp312"]} + environment.install_dependencies(venv_path, config, "demo") + + assert len(popen_calls) == 2 diff --git a/tests/test_model_serialization.py b/tests/test_model_serialization.py new file mode 100644 index 0000000..ba9fc6f --- /dev/null +++ b/tests/test_model_serialization.py @@ -0,0 +1,121 @@ +"""Tests for model_serialization.py — dict guard on deserialize_from_isolation. 
+ +The isinstance(data, dict) guard at line 104 prevents adapter-registered +deserializers from being applied to non-dict objects (e.g., already- +materialized PLY/File3D instances reconstructed by _json_object_hook). +Without the guard, passing a materialized object to its own deserializer +(which expects a dict) causes a double-deserialization bug. +""" + +import pytest + +from pyisolate._internal.model_serialization import deserialize_from_isolation +from pyisolate._internal.serialization_registry import SerializerRegistry + + +@pytest.fixture(autouse=True) +def clean_registry() -> None: + SerializerRegistry.get_instance().clear() + + +class TestDictGuard: + async def test_dict_with_registered_handler_calls_deserializer(self) -> None: + # dict + matching handler + deserializer → deserializer IS invoked + registry = SerializerRegistry.get_instance() + sentinel = object() + registry.register("dict", lambda x: x, lambda x: sentinel) + result = await deserialize_from_isolation({}) + assert result is sentinel + + async def test_non_dict_object_skips_deserializer(self) -> None: + # Non-dict + matching handler → guard blocks deserializer, object passes through + class Foo: + pass + + called = False + + def bad_deserializer(x: object) -> object: + nonlocal called + called = True + return x + + registry = SerializerRegistry.get_instance() + registry.register("Foo", lambda x: x, bad_deserializer) + + foo = Foo() + result = await deserialize_from_isolation(foo) + assert result is foo + assert not called + + async def test_already_materialized_passthrough(self) -> None: + # Core bug scenario: a PLY-like object already reconstructed by + # _json_object_hook has a registered handler. The dict guard must + # prevent re-deserialization — the object passes through unchanged. 
+ class PLY: + def __init__(self, raw_data: bytes) -> None: + self.raw_data = raw_data + + def ply_deserializer(d: object) -> PLY: + # If called on a PLY instance (not a dict), this would raise or corrupt + raise AssertionError("deserializer must not be called on already-materialized PLY") + + registry = SerializerRegistry.get_instance() + registry.register("PLY", lambda x: x, ply_deserializer) + + materialized = PLY(raw_data=b"\x70\x6c\x79") # already reconstructed + result = await deserialize_from_isolation(materialized) + assert result is materialized + assert result.raw_data == b"\x70\x6c\x79" + + +class TestRefTypeDeserialization: + async def test_dict_ref_type_uses_registered_deserializer(self) -> None: + # {"__type__": "MyRef"} with registered handler → deserializer called + registry = SerializerRegistry.get_instance() + sentinel = object() + registry.register("MyRef", lambda x: x, lambda x: sentinel) + result = await deserialize_from_isolation({"__type__": "MyRef", "id": "abc"}) + assert result is sentinel + + async def test_dict_ref_type_unknown_returns_dict(self) -> None: + # Unknown __type__ with no handler → dict returned as-is (recursively deserialized) + result = await deserialize_from_isolation({"__type__": "Unknown", "val": 42}) + assert isinstance(result, dict) + assert result["val"] == 42 + + async def test_nested_dict_ref_deserialization(self) -> None: + # {"a": {"__type__": "MyRef"}} — inner ref must be recursively deserialized + registry = SerializerRegistry.get_instance() + sentinel = object() + registry.register("MyRef", lambda x: x, lambda x: sentinel) + result = await deserialize_from_isolation({"a": {"__type__": "MyRef", "id": "xyz"}}) + assert isinstance(result, dict) + assert result["a"] is sentinel + + +class TestContainerPassthrough: + async def test_list_items_deserialized(self) -> None: + registry = SerializerRegistry.get_instance() + sentinel = object() + registry.register("MyRef", lambda x: x, lambda x: sentinel) + result = await 
deserialize_from_isolation([{"__type__": "MyRef"}, 1, "str"]) + assert result[0] is sentinel + assert result[1] == 1 + assert result[2] == "str" + + async def test_tuple_preserved_as_tuple(self) -> None: + result = await deserialize_from_isolation((1, 2, 3)) + assert isinstance(result, tuple) + assert result == (1, 2, 3) + + async def test_string_passthrough(self) -> None: + result = await deserialize_from_isolation("hello") + assert result == "hello" + + async def test_int_passthrough(self) -> None: + result = await deserialize_from_isolation(42) + assert result == 42 + + async def test_none_passthrough(self) -> None: + result = await deserialize_from_isolation(None) + assert result is None diff --git a/tests/test_rpc_transports.py b/tests/test_rpc_transports.py new file mode 100644 index 0000000..e9d93f9 --- /dev/null +++ b/tests/test_rpc_transports.py @@ -0,0 +1,178 @@ +"""Tests for JSONSocketTransport message size limits and recv behavior. + +Strategy for size-limit tests: mock `_recvall` to inject crafted length +headers without allocating multi-GB buffers. Real socketpair() used for +roundtrip and connection-error tests. 
+""" + +import logging +import socket +import struct +from collections.abc import Iterator +from unittest.mock import patch + +import pytest + +from pyisolate._internal.rpc_transports import JSONSocketTransport + +MB = 1024 * 1024 +GB = 1024 * MB + +TRANSPORT_LOGGER = "pyisolate._internal.rpc_transports" + + +def _make_transport() -> JSONSocketTransport: + a, b = socket.socketpair() + b.close() + return JSONSocketTransport(a) + + +def _header_then_empty(msg_len: int): # type: ignore[no-untyped-def] + """Return a _recvall side_effect: serve header bytes then empty (incomplete body).""" + header = struct.pack(">I", msg_len & 0xFFFFFFFF) + call_count = 0 + + def fake_recvall(n: int) -> bytes: + nonlocal call_count + call_count += 1 + return header if call_count == 1 else b"" + + return fake_recvall + + +@pytest.fixture() +def socket_pair() -> Iterator[tuple[JSONSocketTransport, JSONSocketTransport]]: + a, b = socket.socketpair() + transport_a = JSONSocketTransport(a) + transport_b = JSONSocketTransport(b) + try: + yield transport_a, transport_b + finally: + transport_a.close() + transport_b.close() + + +class TestSendRecvRoundtrip: + def test_small_message_roundtrip( + self, socket_pair: tuple[JSONSocketTransport, JSONSocketTransport] + ) -> None: + sender, receiver = socket_pair + payload = {"kind": "call", "method": "test", "args": [1, 2, 3]} + sender.send(payload) + result = receiver.recv() + assert result["kind"] == "call" + assert result["method"] == "test" + assert result["args"] == [1, 2, 3] + + def test_send_does_not_enforce_2gb_limit( + self, socket_pair: tuple[JSONSocketTransport, JSONSocketTransport] + ) -> None: + # send() uses struct.pack(">I") — no explicit 2GB check; limit is recv-only + sender, _ = socket_pair + payload = {"data": "x" * 1000} + sender.send(payload) # must not raise + + +class TestRecvHardLimit: + def test_2gb_minus_1_not_rejected(self) -> None: + transport = _make_transport() + with ( + patch.object(transport, "_recvall", 
side_effect=_header_then_empty(2 * GB - 1)), + pytest.raises(ConnectionError), + ): + transport.recv() + + def test_2gb_exact_not_rejected(self) -> None: + transport = _make_transport() + with ( + patch.object(transport, "_recvall", side_effect=_header_then_empty(2 * GB)), + pytest.raises(ConnectionError), + ): + transport.recv() + + def test_2gb_plus_1_raises_value_error(self) -> None: + # 2GB+1 = 2147483649 fits in uint32 and is the exact guard boundary + transport = _make_transport() + with ( + patch.object(transport, "_recvall", side_effect=_header_then_empty(2 * GB + 1)), + pytest.raises(ValueError, match="Message too large"), + ): + transport.recv() + + def test_3gb_raises_value_error(self) -> None: + # 3GB fits in an unsigned 32-bit header and exceeds the 2GB hard limit + transport = _make_transport() + with ( + patch.object(transport, "_recvall", side_effect=_header_then_empty(3 * GB)), + pytest.raises(ValueError, match="Message too large"), + ): + transport.recv() + + +class TestRecvWarningThreshold: + def test_100mb_minus_1_no_warning(self, caplog: pytest.LogCaptureFixture) -> None: + transport = _make_transport() + with ( + caplog.at_level(logging.WARNING, logger=TRANSPORT_LOGGER), + patch.object(transport, "_recvall", side_effect=_header_then_empty(100 * MB - 1)), + pytest.raises(ConnectionError), + ): + transport.recv() + assert not any("Large RPC message" in r.message for r in caplog.records) + + def test_100mb_exact_no_warning(self, caplog: pytest.LogCaptureFixture) -> None: + # Threshold is strictly >100MB; exactly 100MB must NOT trigger the warning + transport = _make_transport() + with ( + caplog.at_level(logging.WARNING, logger=TRANSPORT_LOGGER), + patch.object(transport, "_recvall", side_effect=_header_then_empty(100 * MB)), + pytest.raises(ConnectionError), + ): + transport.recv() + assert not any("Large RPC message" in r.message for r in caplog.records) + + def test_100mb_plus_1_triggers_warning(self, caplog: pytest.LogCaptureFixture) -> None: 
+ transport = _make_transport() + with ( + caplog.at_level(logging.WARNING, logger=TRANSPORT_LOGGER), + patch.object(transport, "_recvall", side_effect=_header_then_empty(100 * MB + 1)), + pytest.raises(ConnectionError), + ): + transport.recv() + assert any("Large RPC message" in r.message for r in caplog.records) + + +class TestConnectionErrors: + def test_incomplete_length_header_raises(self) -> None: + transport = _make_transport() + + def fake_recvall(n: int) -> bytes: + return b"\x00\x00" # only 2 bytes instead of 4 + + with ( + patch.object(transport, "_recvall", side_effect=fake_recvall), + pytest.raises(ConnectionError, match="incomplete length header"), + ): + transport.recv() + + def test_incomplete_message_body_raises(self) -> None: + transport = _make_transport() + call_count = 0 + + def fake_recvall(n: int) -> bytes: + nonlocal call_count + call_count += 1 + return struct.pack(">I", 100) if call_count == 1 else b"short" + + with ( + patch.object(transport, "_recvall", side_effect=fake_recvall), + pytest.raises(ConnectionError, match="Incomplete message"), + ): + transport.recv() + + def test_socket_closed_mid_header_raises(self) -> None: + a, b = socket.socketpair() + transport = JSONSocketTransport(a) + b.close() + with pytest.raises((ConnectionError, OSError)): + transport.recv() diff --git a/tests/test_serialization_registry.py b/tests/test_serialization_registry.py index 50377f5..76d5d39 100644 --- a/tests/test_serialization_registry.py +++ b/tests/test_serialization_registry.py @@ -1,16 +1,24 @@ +import inspect + +import pytest + from pyisolate._internal.serialization_registry import SerializerRegistry +from pyisolate.interfaces import SerializerRegistryProtocol + +@pytest.fixture(autouse=True) +def clean_registry() -> None: + SerializerRegistry.get_instance().clear() -def test_singleton_identity(): + +def test_singleton_identity() -> None: r1 = SerializerRegistry.get_instance() r2 = SerializerRegistry.get_instance() assert r1 is r2 -def 
test_register_and_lookup(): +def test_register_and_lookup() -> None: registry = SerializerRegistry.get_instance() - registry.clear() - registry.register("Foo", lambda x: {"v": x}, lambda x: x["v"]) assert registry.has_handler("Foo") @@ -22,10 +30,90 @@ def test_register_and_lookup(): assert deserializer(payload) == 123 if deserializer else False -def test_clear_resets_handlers(): +def test_clear_resets_handlers() -> None: registry = SerializerRegistry.get_instance() registry.register("Bar", lambda x: x) assert registry.has_handler("Bar") registry.clear() assert not registry.has_handler("Bar") + + +class TestDataTypeFlag: + def test_register_with_data_type_flag(self) -> None: + registry = SerializerRegistry.get_instance() + registry.register("MyData", lambda x: x, data_type=True) + assert registry.is_data_type("MyData") + + def test_register_without_data_type_flag(self) -> None: + registry = SerializerRegistry.get_instance() + registry.register("MyData", lambda x: x) + assert not registry.is_data_type("MyData") + + def test_is_data_type_unregistered(self) -> None: + registry = SerializerRegistry.get_instance() + assert not registry.is_data_type("NeverRegistered") + + def test_clear_resets_data_types(self) -> None: + registry = SerializerRegistry.get_instance() + registry.register("MyData", lambda x: x, data_type=True) + assert registry.is_data_type("MyData") + registry.clear() + assert not registry.is_data_type("MyData") + + def test_overwrite_preserves_data_type(self) -> None: + # Register with data_type=True, then re-register without — set is additive + registry = SerializerRegistry.get_instance() + registry.register("MyData", lambda x: x, data_type=True) + registry.register("MyData", lambda x: x) # no data_type kwarg + assert registry.is_data_type("MyData") + + def test_overwrite_adds_data_type(self) -> None: + # Register without, then re-register with data_type=True + registry = SerializerRegistry.get_instance() + registry.register("MyData", lambda x: x) + 
assert not registry.is_data_type("MyData") + registry.register("MyData", lambda x: x, data_type=True) + assert registry.is_data_type("MyData") + + def test_data_type_idempotent(self) -> None: + # Repeated register with data_type=True has no side effects + registry = SerializerRegistry.get_instance() + for _ in range(3): + registry.register("MyData", lambda x: x, data_type=True) + assert registry.is_data_type("MyData") + registry.clear() + assert not registry.is_data_type("MyData") + + def test_data_type_cross_type_isolation(self) -> None: + # Setting type A as data_type does not affect type B + registry = SerializerRegistry.get_instance() + registry.register("TypeA", lambda x: x, data_type=True) + registry.register("TypeB", lambda x: x) + assert registry.is_data_type("TypeA") + assert not registry.is_data_type("TypeB") + + +class TestProtocolCompliance: + def test_protocol_isinstance_check(self) -> None: + # SerializerRegistryProtocol is @runtime_checkable + registry = SerializerRegistry.get_instance() + assert isinstance(registry, SerializerRegistryProtocol) + + def test_protocol_is_data_type_callable_with_correct_signature(self) -> None: + # Verify is_data_type exists on the protocol and has the expected signature + sig = inspect.signature(SerializerRegistryProtocol.is_data_type) + params = list(sig.parameters) + assert "self" in params + assert "type_name" in params + # Invoke via protocol-typed reference to confirm structural contract + registry: SerializerRegistryProtocol = SerializerRegistry.get_instance() + registry.register("SigTest", lambda x: x, data_type=True) + result = registry.is_data_type("SigTest") + assert result is True + + def test_protocol_register_accepts_data_type_kwarg(self) -> None: + registry: SerializerRegistryProtocol = SerializerRegistry.get_instance() + # Must not raise — data_type is a keyword-only arg on the protocol + registry.register("ProtoTest", lambda x: x, data_type=True) + assert registry.is_data_type("ProtoTest")