From 7695e9180994dbc155f5fb02654d756d6bb71d3d Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Mon, 29 Jun 2026 17:46:49 -0700 Subject: [PATCH] test: patch pytest-run-parallel worker context setup --- .../tests/test_pytest_run_parallel_patch.py | 121 ++++++++++ cuda_bindings/tests/conftest.py | 54 +++++ cuda_core/tests/conftest.py | 84 ++++++- .../pytest_run_parallel.py | 209 ++++++++++++++++++ 4 files changed, 456 insertions(+), 12 deletions(-) create mode 100644 ci/tools/tests/test_pytest_run_parallel_patch.py create mode 100644 cuda_python_test_helpers/cuda_python_test_helpers/pytest_run_parallel.py diff --git a/ci/tools/tests/test_pytest_run_parallel_patch.py b/ci/tools/tests/test_pytest_run_parallel_patch.py new file mode 100644 index 00000000000..0eff629c210 --- /dev/null +++ b/ci/tools/tests/test_pytest_run_parallel_patch.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import sys +import threading +import types +from contextlib import contextmanager +from pathlib import Path + +import pytest + +_TEST_HELPERS_ROOT = Path(__file__).resolve().parents[3] / "cuda_python_test_helpers" +sys.path.insert(0, str(_TEST_HELPERS_ROOT)) + +from cuda_python_test_helpers.pytest_run_parallel import ( + install_run_parallel_worker_context_patch, + mark_item_for_worker_context, +) + + +def _install_fake_pytest_run_parallel(monkeypatch): + package = types.ModuleType("pytest_run_parallel") + package.__path__ = [] + plugin = types.ModuleType("pytest_run_parallel.plugin") + + def wrap_function_parallel(fn, n_workers, n_iterations): + raise AssertionError("unpatched fake wrapper should not be called") + + plugin.wrap_function_parallel = wrap_function_parallel + monkeypatch.setitem(sys.modules, "pytest_run_parallel", package) + monkeypatch.setitem(sys.modules, "pytest_run_parallel.plugin", plugin) + return plugin + + +@pytest.mark.agent_authored(model="gpt-5") +def test_install_run_parallel_worker_context_patch_is_idempotent(monkeypatch): + plugin = _install_fake_pytest_run_parallel(monkeypatch) + + assert install_run_parallel_worker_context_patch() is True + patched = plugin.wrap_function_parallel + assert patched._cuda_python_patched_run_parallel_worker_context + + assert install_run_parallel_worker_context_patch() is True + assert plugin.wrap_function_parallel is patched + + +@pytest.mark.agent_authored(model="gpt-5") +def test_patched_wrapper_runs_context_with_isolated_kwargs(monkeypatch): + plugin = _install_fake_pytest_run_parallel(monkeypatch) + install_run_parallel_worker_context_patch() + + lock = threading.Lock() + context_events = [] + calls = [] + + @contextmanager + def worker_context(*, thread_index, iteration_index, kwargs): + token = object() + kwargs["token"] = (thread_index, iteration_index, id(token)) + kwargs["context_kwargs_id"] = id(kwargs) + with lock: + context_events.append(("enter", thread_index, iteration_index, id(kwargs))) + try: + yield + finally: + with lock: + context_events.append(("exit", thread_index, iteration_index, id(kwargs))) + + def test_body(*, thread_index, iteration_index, token, context_kwargs_id, static_value): + with lock: + calls.append( + { + "thread_index": thread_index, + "iteration_index": iteration_index, + "token": token, + "context_kwargs_id": context_kwargs_id, + "static_value": static_value, + } + ) + + item = types.SimpleNamespace(obj=test_body) + assert mark_item_for_worker_context(item, worker_context) is True + + wrapped = plugin.wrap_function_parallel(item.obj, n_workers=3, n_iterations=2) + wrapped(thread_index=-1, iteration_index=-1, static_value="fixture-value") + + expected_pairs = {(thread_index, iteration_index) for thread_index in range(3) for iteration_index in range(2)} + actual_pairs = {(call["thread_index"], call["iteration_index"]) for call in calls} + assert actual_pairs == expected_pairs + assert {call["token"][:2] for call in calls} == expected_pairs + assert {call["static_value"] for call in calls} == {"fixture-value"} + + kwargs_ids = {call["context_kwargs_id"] for call in calls} + assert len(kwargs_ids) == 6 + assert len(context_events) == 12 + assert {event[3] for event in context_events} == kwargs_ids + + +@pytest.mark.agent_authored(model="gpt-5") +def test_mark_item_for_worker_context_wraps_callables_without_attrs(): + class CallableWithoutAttrs: + __slots__ = ("calls",) + + def __init__(self): + self.calls = [] + + def __call__(self, **kwargs): + self.calls.append(kwargs) + + @contextmanager + def worker_context(*, thread_index, iteration_index, kwargs): + kwargs["patched"] = True + yield + + original = CallableWithoutAttrs() + item = types.SimpleNamespace(obj=original) + + assert mark_item_for_worker_context(item, worker_context) is True + assert item.obj is not original + item.obj() + assert original.calls == [{}] diff --git a/cuda_bindings/tests/conftest.py b/cuda_bindings/tests/conftest.py index f30500c1342..8d7f6593d52 100644 --- a/cuda_bindings/tests/conftest.py +++ b/cuda_bindings/tests/conftest.py @@ -3,6 +3,7 @@ import pathlib import sys +from contextlib import contextmanager from importlib.metadata import PackageNotFoundError, distribution import pytest @@ -25,6 +26,59 @@ sys.path.insert(0, test_helpers_root) +from cuda_python_test_helpers.pytest_run_parallel import ( + install_run_parallel_worker_context_patch, + mark_item_for_worker_context, +) + + +def pytest_configure(config): + install_run_parallel_worker_context_patch() + + +@contextmanager +def _thread_context(): + (err,) = cuda.cuInit(0) + assert err == cuda.CUresult.CUDA_SUCCESS + err, device = cuda.cuDeviceGet(0) + assert err == cuda.CUresult.CUDA_SUCCESS + err, ctx = cuda.cuCtxCreate(None, 0, device) + assert err == cuda.CUresult.CUDA_SUCCESS + try: + yield device, ctx + finally: + (err,) = cuda.cuCtxDestroy(ctx) + assert err == cuda.CUresult.CUDA_SUCCESS + + +@contextmanager +def _cuda_bindings_worker_context(*, thread_index, iteration_index, kwargs): + with _thread_context() as (device, ctx): + if "device" in kwargs: + kwargs["device"] = device + if "ctx" in kwargs: + kwargs["ctx"] = ctx + yield + + +def _is_cudla_item(item): + nodeid = item.nodeid.replace("\\", "/") + return nodeid.startswith("tests/cudla/") or "cuda_bindings/tests/cudla/" in nodeid + + +def _item_needs_thread_ctx(item): + if _is_cudla_item(item): + return False + fixturenames = set(getattr(item, "fixturenames", ())) + return bool(fixturenames & {"device", "ctx", "driver", "cufile_env_json"}) + + +def pytest_collection_modifyitems(config, items): + for item in items: + if _item_needs_thread_ctx(item): + mark_item_for_worker_context(item, _cuda_bindings_worker_context) + + @pytest.fixture(scope="module") def cuda_driver(): (err,) = cuda.cuInit(0) diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index 611f83c3a0e..b38930c09b7 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -91,6 +91,76 @@ def xfail_if_mempool_oom(err_or_exc, api_name=None, device=0): sys.path.insert(0, test_helpers_root) +from cuda_python_test_helpers.pytest_run_parallel import ( + install_run_parallel_worker_context_patch, + mark_item_for_worker_context, +) + + +def pytest_configure(config): + install_run_parallel_worker_context_patch() + + +@contextmanager +def _init_cuda_context(): + # TODO: rename this to e.g. init_context + device = Device(0) + device.set_current() + + # Set option to avoid spin-waiting on synchronization. + if int(os.environ.get("CUDA_CORE_TEST_BLOCKING_SYNC", 0)) != 0: + handle_return( + driver.cuDevicePrimaryCtxSetFlags(device.device_id, driver.CUctx_flags.CU_CTX_SCHED_BLOCKING_SYNC) + ) + + try: + yield device + finally: + _ = _device_unset_current() + + +@contextmanager +def _cuda_core_worker_context(*, thread_index, iteration_index, kwargs): + with _init_cuda_context() as device: + if "init_cuda" in kwargs: + kwargs["init_cuda"] = device + if "mempool_device" in kwargs: + kwargs["mempool_device"] = device + if "ipc_device" in kwargs: + kwargs["ipc_device"] = device + if "mempool_device_x2" in kwargs: + kwargs["mempool_device_x2"] = _mempool_device_impl(2) + if "mempool_device_x3" in kwargs: + kwargs["mempool_device_x3"] = _mempool_device_impl(3) + if "ipc_mempool_device_x2" in kwargs: + kwargs["ipc_mempool_device_x2"] = _require_ipc_mempool_devices(_mempool_device_impl(2)) + yield + + +_CUDA_CONTEXT_FIXTURES = frozenset( + { + "init_cuda", + "ipc_device", + "ipc_memory_resource", + "mempool_device", + "mempool_device_x2", + "mempool_device_x3", + "ipc_mempool_device_x2", + "memory_resource_factory", + } +) + + +def _item_needs_thread_ctx(item): + return bool(_CUDA_CONTEXT_FIXTURES & set(getattr(item, "fixturenames", ()))) + + +def pytest_collection_modifyitems(config, items): + for item in items: + if _item_needs_thread_ctx(item): + mark_item_for_worker_context(item, _cuda_core_worker_context) + + def skip_if_pinned_memory_unsupported(device): try: if not device.properties.host_memory_pools_supported: @@ -194,18 +264,8 @@ def session_setup(): @pytest.fixture def init_cuda(): - # TODO: rename this to e.g. init_context - device = Device(0) - device.set_current() - - # Set option to avoid spin-waiting on synchronization. - if int(os.environ.get("CUDA_CORE_TEST_BLOCKING_SYNC", 0)) != 0: - handle_return( - driver.cuDevicePrimaryCtxSetFlags(device.device_id, driver.CUctx_flags.CU_CTX_SCHED_BLOCKING_SYNC) - ) - - yield device - _ = _device_unset_current() + with _init_cuda_context() as device: + yield device def _device_unset_current() -> bool: diff --git a/cuda_python_test_helpers/cuda_python_test_helpers/pytest_run_parallel.py b/cuda_python_test_helpers/cuda_python_test_helpers/pytest_run_parallel.py new file mode 100644 index 00000000000..9a1fe32a84f --- /dev/null +++ b/cuda_python_test_helpers/cuda_python_test_helpers/pytest_run_parallel.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import functools +import importlib +import inspect +import sys +import threading +from contextlib import ExitStack, contextmanager, nullcontext +from typing import Any, Callable + +import _pytest.outcomes +import pytest + +_WORKER_CONTEXT_ATTR = "_cuda_python_run_parallel_worker_context" +_PATCHED_ATTR = "_cuda_python_patched_run_parallel_worker_context" +_ORIGINAL_ATTR = "_cuda_python_original_wrap_function_parallel" + + +def install_run_parallel_worker_context_patch() -> bool: + """Patch pytest-run-parallel to run optional per-worker context managers. + + Returns True when pytest-run-parallel is importable and patched, and False + when the plugin is not installed in the active environment. + """ + try: + plugin = importlib.import_module("pytest_run_parallel.plugin") + except ModuleNotFoundError as exc: + if exc.name == "pytest_run_parallel" or exc.name.startswith("pytest_run_parallel."): + return False + raise + + wrap_function_parallel = getattr(plugin, "wrap_function_parallel", None) + if wrap_function_parallel is None: + raise RuntimeError("pytest-run-parallel does not expose wrap_function_parallel") + + if getattr(wrap_function_parallel, _PATCHED_ATTR, False): + return True + + _validate_wrap_function_parallel(wrap_function_parallel) + patched = _make_patched_wrap_function_parallel() + setattr(patched, _PATCHED_ATTR, True) + setattr(patched, _ORIGINAL_ATTR, wrap_function_parallel) + plugin.wrap_function_parallel = patched + return True + + +def mark_item_for_worker_context(item: Any, context_factory: Callable[..., Any]) -> bool: + """Attach a worker context factory to a pytest item function. + + The factory is called in each pytest-run-parallel worker thread as: + ``factory(thread_index=..., iteration_index=..., kwargs=...)``. + It may mutate ``kwargs`` before yielding. + """ + obj = getattr(item, "obj", None) + if obj is None: + return False + + try: + set_worker_context_factory(obj, context_factory) + except (AttributeError, TypeError): + original = obj + + @functools.wraps(original) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return original(*args, **kwargs) + + set_worker_context_factory(wrapper, context_factory) + item.obj = wrapper + + return True + + +def set_worker_context_factory(func: Callable[..., Any], context_factory: Callable[..., Any]) -> Callable[..., Any]: + """Attach or compose a pytest-run-parallel worker context factory.""" + existing = getattr(func, _WORKER_CONTEXT_ATTR, None) + if existing is None: + setattr(func, _WORKER_CONTEXT_ATTR, context_factory) + elif existing is not context_factory: + setattr(func, _WORKER_CONTEXT_ATTR, _compose_context_factories(existing, context_factory)) + return func + + +def _validate_wrap_function_parallel(wrap_function_parallel: Callable[..., Any]) -> None: + parameters = tuple(inspect.signature(wrap_function_parallel).parameters) + expected = ("fn", "n_workers", "n_iterations") + if parameters != expected: + raise RuntimeError( + f"Unsupported pytest-run-parallel wrap_function_parallel signature: expected {expected}, got {parameters}" + ) + + +def _make_patched_wrap_function_parallel() -> Callable[..., Any]: + def wrap_function_parallel(fn: Callable[..., Any], n_workers: int, n_iterations: int) -> Callable[..., Any]: + context_factory = getattr(fn, _WORKER_CONTEXT_ATTR, None) + + @functools.wraps(fn) + def inner(*args: Any, **kwargs: Any) -> None: + errors = [] + skip = None + failed = None + barrier = threading.Barrier(n_workers) + original_switch = sys.getswitchinterval() + new_switch = 1e-6 + for _ in range(3): + try: + sys.setswitchinterval(new_switch) + break + except ValueError: + new_switch *= 10 + else: + sys.setswitchinterval(original_switch) + + try: + + def closure(*args: Any, **kwargs: Any) -> None: + nonlocal skip, failed + + thread_index, args = args[0], args[1:] + worker_kwargs = dict(kwargs) + if n_workers > 1: + if "thread_index" in worker_kwargs: + worker_kwargs["thread_index"] = thread_index + if "tmp_path" in worker_kwargs: + worker_kwargs["tmp_path"] = worker_kwargs["tmp_path"] / f"thread_{thread_index!s}" + worker_kwargs["tmp_path"].mkdir(exist_ok=True) + if "tmpdir" in worker_kwargs: + worker_kwargs["tmpdir"] = worker_kwargs["tmpdir"].ensure( + f"thread_{thread_index!s}", dir=True + ) + + for i in range(n_iterations): + call_kwargs = dict(worker_kwargs) + if "iteration_index" in call_kwargs: + call_kwargs["iteration_index"] = i + + barrier.wait() + try: + with _worker_context(context_factory, thread_index, i, call_kwargs): + fn(*args, **call_kwargs) + except Warning: + pass + except Exception as e: + errors.append(e) + except _pytest.outcomes.Skipped as s: + skip = s.msg + except _pytest.outcomes.Failed as f: + failed = f + + workers = [] + for i in range(n_workers): + workers.append(threading.Thread(target=closure, args=(i, *args), kwargs=kwargs)) + + num_completed = 0 + try: + for worker in workers: + worker.start() + num_completed += 1 + finally: + if num_completed < len(workers): + barrier.abort() + + for worker in workers: + worker.join() + + finally: + sys.setswitchinterval(original_switch) + + if skip is not None: + pytest.skip(skip) + elif failed is not None: + raise failed + elif errors: + raise errors[0] + + return inner + + return wrap_function_parallel + + +@contextmanager +def _worker_context(context_factory, thread_index: int, iteration_index: int, kwargs: dict): + if context_factory is None: + with nullcontext(): + yield + return + + context = context_factory(thread_index=thread_index, iteration_index=iteration_index, kwargs=kwargs) + if context is None: + context = nullcontext() + with context: + yield + + +def _compose_context_factories(first: Callable[..., Any], second: Callable[..., Any]) -> Callable[..., Any]: + @contextmanager + def combined(**kwargs: Any): + with ExitStack() as stack: + first_context = first(**kwargs) + if first_context is not None: + stack.enter_context(first_context) + second_context = second(**kwargs) + if second_context is not None: + stack.enter_context(second_context) + yield + + return combined