Skip to content

Commit f02b730

Browse files
Add Device context manager for temporary device switching
Closes #1586. Adds __enter__/__exit__ to Device so it can be used as a context manager that saves the current CUDA context on entry and restores it on exit. Uses cuCtxGetCurrent/cuCtxSetCurrent (not push/pop) for interoperability with the runtime API. Saved contexts are stored on a per-thread stack (_tls._ctx_stack) so nested and reentrant usage works correctly. Also adds teardown to mempool_device_x2/x3 fixtures to clean up residual contexts between tests. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 80463f6 commit f02b730

File tree

3 files changed

+224
-3
lines changed

3 files changed

+224
-3
lines changed

cuda_core/cuda/core/_device.pyx

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

@@ -1188,6 +1188,49 @@ class Device:
11881188
def __reduce__(self):
11891189
return Device, (self.device_id,)
11901190

1191+
def __enter__(self):
    """Set this device as current for the duration of the ``with`` block.

    On exit, the previously current device is restored automatically.
    Nested ``with`` blocks are supported and restore correctly at each
    level.

    Returns
    -------
    Device
        This device instance.

    Examples
    --------
    >>> from cuda.core import Device
    >>> with Device(0) as dev0:
    ...     buf = dev0.allocate(1024)

    See Also
    --------
    set_current : Non-context-manager entry point.
    """
    cdef cydriver.CUcontext prev_ctx
    # Save whatever context is current (possibly NULL) so __exit__ can
    # restore it. cuCtxGetCurrent/cuCtxSetCurrent (not push/pop) are used
    # for interoperability with the runtime API.
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxGetCurrent(&prev_ctx))
    if not hasattr(_tls, '_ctx_stack'):
        _tls._ctx_stack = []
    # Make the device current *before* pushing onto the per-thread stack:
    # if set_current() raises, no stale entry is left behind that a later
    # (or enclosing) __exit__ would wrongly consume.
    self.set_current()
    _tls._ctx_stack.append(<uintptr_t><void*>prev_ctx)
    return self
1222+
def __exit__(self, exc_type, exc_val, exc_tb):
    """Restore the previously current device upon exiting the ``with`` block.

    Exceptions are not suppressed.
    """
    # Pop *before* the driver call so the per-thread stack stays balanced
    # even if cuCtxSetCurrent fails; reading [-1] and popping only after
    # the call would leave a stale entry behind on error.
    cdef uintptr_t prev_ctx_ptr = _tls._ctx_stack.pop()
    cdef cydriver.CUcontext prev_ctx = <cydriver.CUcontext><void*>prev_ctx_ptr
    with nogil:
        HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
    return False
11911234
def set_current(self, ctx: Context = None) -> Context | None:
11921235
"""Set device to be used for GPU executions.
11931236

cuda_core/tests/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,15 @@ def _mempool_device_impl(num):
192192
@pytest.fixture
def mempool_device_x2():
    """Fixture that provides two devices if available, otherwise skips test."""
    devices = _mempool_device_impl(2)
    yield devices
    # Teardown: drop any context left current so later tests start clean.
    _device_unset_current()

197198

198199
@pytest.fixture
def mempool_device_x3():
    """Fixture that provides three devices if available, otherwise skips test."""
    devices = _mempool_device_impl(3)
    yield devices
    # Teardown: drop any context left current so later tests start clean.
    _device_unset_current()

203205

204206
@pytest.fixture(

cuda_core/tests/test_device.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,179 @@ def test_device_set_membership(init_cuda):
436436
# Same device_id should not add duplicate
437437
device_set.add(dev0_b)
438438
assert len(device_set) == 1, "Should not add duplicate device"
439+
440+
441+
# ============================================================================
442+
# Device Context Manager Tests
443+
# ============================================================================
444+
445+
446+
def _get_current_context():
    """Return the current CUcontext as an int (0 means NULL / no context)."""
    ctx = handle_return(driver.cuCtxGetCurrent())
    return int(ctx)
450+
451+
def test_context_manager_basic(deinit_cuda):
    """with Device(0) sets the device as current and restores on exit."""
    assert _get_current_context() == 0, "Should start with no active context"
    device = Device(0)
    with device:
        inside = _get_current_context()
        assert inside != 0, "Device should be current inside the with block"
    assert _get_current_context() == 0, "No context should be current after exiting"
460+
461+
def test_context_manager_restores_previous(deinit_cuda):
    """Context manager restores the previously active context, not NULL."""
    device = Device(0)
    device.set_current()
    saved = _get_current_context()
    assert saved != 0

    with Device(0):
        pass

    assert _get_current_context() == saved, "Should restore the previous context"
473+
474+
def test_context_manager_exception_safety(deinit_cuda):
    """Device context is restored even when an exception is raised."""
    # Start with no active context so restoration is distinguishable
    assert _get_current_context() == 0

    with pytest.raises(RuntimeError, match="test error"):
        with Device(0):
            assert _get_current_context() != 0, "Device should be active inside the block"
            raise RuntimeError("test error")

    assert _get_current_context() == 0, "Must restore NULL context after exception"
485+
486+
def test_context_manager_returns_device(deinit_cuda):
    """__enter__ returns the Device instance for use in 'as' clause."""
    expected = Device(0)
    with expected as entered:
        assert entered is expected

    assert _get_current_context() == 0
494+
495+
def test_context_manager_nesting_same_device(deinit_cuda):
    """Nested with-blocks on the same device work correctly."""
    device = Device(0)

    with device:
        outer = _get_current_context()
        with device:
            assert _get_current_context() == outer, "Same device should yield same context"
        assert _get_current_context() == outer, "Outer context restored after inner exit"

    assert _get_current_context() == 0
508+
509+
def test_context_manager_deep_nesting(deinit_cuda):
    """Deep nesting and reentrancy restore correctly at each level."""
    device = Device(0)

    with device:
        level1 = _get_current_context()
        with device:
            level2 = _get_current_context()
            with device:
                assert _get_current_context() != 0
            assert _get_current_context() == level2
        assert _get_current_context() == level1

    assert _get_current_context() == 0
524+
525+
def test_context_manager_nesting_different_devices(mempool_device_x2):
    """Nested with-blocks on different devices restore correctly."""
    dev0, dev1 = mempool_device_x2
    baseline = _get_current_context()  # dev0's context, set by the fixture

    with dev1:
        assert _get_current_context() != baseline, "Different device should have different context"

    assert _get_current_context() == baseline, "Original device context should be restored"
536+
537+
def test_context_manager_deep_nesting_multi_gpu(mempool_device_x2):
    """Deep nesting across multiple devices restores correctly at each level."""
    dev0, dev1 = mempool_device_x2

    with dev0:
        ctx0 = _get_current_context()
        with dev1:
            ctx1 = _get_current_context()
            assert ctx1 != ctx0
            with dev0:
                assert _get_current_context() == ctx0, "Same device should yield same primary context"
                with dev1:
                    assert _get_current_context() == ctx1
                assert _get_current_context() == ctx0
            assert _get_current_context() == ctx1
        assert _get_current_context() == ctx0
554+
555+
def test_context_manager_set_current_inside(mempool_device_x2):
    """set_current() inside a with block does not affect restoration on exit."""
    dev0, dev1 = mempool_device_x2
    saved = _get_current_context()  # dev0 is current from fixture

    with dev0:
        dev1.set_current()  # change the active device inside the block
        assert _get_current_context() != saved

    assert _get_current_context() == saved, "Must restore the context saved at __enter__"
566+
567+
def test_context_manager_device_usable_after_exit(deinit_cuda):
    """Device singleton is not corrupted after context manager exit."""
    dev = Device(0)
    with dev:
        pass

    assert _get_current_context() == 0

    # Device should still be usable via set_current
    dev.set_current()
    assert _get_current_context() != 0
    new_stream = dev.create_stream()
    assert new_stream is not None
581+
582+
def test_context_manager_initializes_device(deinit_cuda):
    """with Device(N) should initialize the device, making it ready for use."""
    dev = Device(0)
    with dev:
        # allocate requires an active context; should not raise
        buffer = dev.allocate(1024)
        assert buffer.handle != 0
590+
591+
def test_context_manager_thread_safety(mempool_device_x3):
    """Concurrent threads using context managers on different devices don't interfere."""
    import concurrent.futures
    import threading

    devices = mempool_device_x3
    # All workers enter their context manager before any of them exits,
    # so per-thread save/restore must not leak across threads.
    barrier = threading.Barrier(len(devices))
    errors = []

    def run_on(dev):
        try:
            before = _get_current_context()
            with dev:
                barrier.wait(timeout=5)
                buf = dev.allocate(1024)
                assert buf.handle != 0
            assert _get_current_context() == before
        except Exception as e:  # collected and re-raised by the main thread's assert
            errors.append(e)

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool:
        pool.map(run_on, devices)

    assert not errors, f"Thread errors: {errors}"

0 commit comments

Comments
 (0)