## Problem

There is no context manager in `cuda.core` for temporarily switching the active CUDA device. Users who need to perform work on a specific device and then restore the previous device must manually manage this state:
```python
from cuda.core import Device

# Save current state, switch, do work, restore
dev0 = Device(0)
dev0.set_current()
# ... do work on device 0 ...

dev1 = Device(1)
dev1.set_current()
# ... do work on device 1 ...

dev0.set_current()  # manually restore -- easy to forget, not exception-safe
```
This pattern is error-prone (the restore call is skipped if an exception is raised) and verbose compared to the idiomatic Python `with` statement. It appears in real-world code such as dask-cuda, which implemented its own workaround for this missing feature.
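The exception-safe version of the manual pattern requires `try`/`finally` boilerplate, which is exactly what a context manager automates. A minimal pure-Python model of that boilerplate (the dict and the `set_current`/`get_current` helpers are illustrative stand-ins so the snippet runs without a GPU, not real `cuda.core` APIs):

```python
# Model of the exception-safe manual pattern; a plain dict stands in for
# the driver's current-device state, so no GPU is required.
_driver = {"device": 0}

def set_current(device_id):
    _driver["device"] = device_id

def get_current():
    return _driver["device"]

prev = get_current()        # save
set_current(1)              # switch
try:
    pass                    # ... do work on device 1 ...
finally:
    set_current(prev)       # restored even if the work raises

print(get_current())  # -> 0
```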
## Proposed Design
Add `__enter__` and `__exit__` methods to `Device`, making it usable as a context manager that temporarily activates the device's primary context and restores the previous state on exit.
### API
```python
from cuda.core import Device

dev0 = Device(0)
dev0.set_current()
# ... do work on device 0 ...

with Device(1) as device:
    # device 1 is now current
    stream = device.create_stream()
    # ...

# device 0 is automatically restored here
```
### Semantics

On `__enter__`:

- Query the current CUDA context via `cuCtxGetCurrent` and save it (on a per-thread stack, not on the `Device` instance; see below).
- Call `self.set_current()` (which activates the primary context for this device via `cuCtxSetCurrent`).
- Return `self`.

On `__exit__`:

- Restore the saved context via `cuCtxSetCurrent`. If the saved context was NULL, set NULL (no active context).
- Do NOT suppress exceptions (return `False`).
### Key design properties
**Stateless restoration (no Python-side device stack).** Each `__enter__` call queries the actual CUDA driver state for the current context rather than consulting a Python-side cache. On `__exit__`, it restores exactly what was saved. This is the critical lesson from CuPy's experience (cupy/cupy#6965, cupy/cupy#7427): libraries that maintain their own stack of previous devices break interoperability with libraries that use the CUDA API directly to check device state. By always querying and restoring the driver-level state, we interoperate correctly with PyTorch, CuPy, and any other library that uses `cudaGetDevice`/`cudaSetDevice` or `cuCtxGetCurrent`/`cuCtxSetCurrent`.
**Reentrant and reusable.** Because `Device` is a thread-local singleton and the saved-context state is stored per-`__enter__` invocation (not on the `Device` object itself), the context manager is both reusable and reentrant:
```python
dev0 = Device(0)
dev1 = Device(1)

with dev0:
    with dev1:
        with dev0:  # reentrant -- works correctly
            ...
        # dev1 restored
    # dev0 restored
# original state restored
```
To achieve reentrancy, the saved context must NOT be stored on `self` (the `Device` singleton). Instead, use a thread-local stack or return a helper object from `__enter__` that holds the saved state. The simplest correct approach: store a per-thread stack of saved contexts on the `Device` class (or at module level), pushing on `__enter__` and popping on `__exit__`.
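To see why storing the saved context on `self` breaks reentrancy, here is a deliberately broken pure-Python model (plain ints stand in for `CUcontext` values and a dict for the driver's current-context slot; all names are illustrative, not cuda.core internals):

```python
# Deliberately BROKEN: the saved context lives on the (singleton) instance,
# so re-entering the same object overwrites the earlier saved value.
driver = {"ctx": None}  # stands in for the driver's current-context slot

class BrokenDevice:
    def __init__(self, dev):
        self.dev = dev
    def __enter__(self):
        self._saved = driver["ctx"]  # BUG: clobbered on re-entry
        driver["ctx"] = self.dev
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        driver["ctx"] = self._saved
        return False

dev0, dev1 = BrokenDevice(0), BrokenDevice(1)
with dev0:                  # dev0._saved = None
    with dev1:              # dev1._saved = 0
        with dev0:          # dev0._saved overwritten to 1
            pass

print(driver["ctx"])  # -> 1, but the original state was None
```

A per-thread stack with one entry per `__enter__` call restores `None` at the outermost exit instead.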
Implementation sketch (in `_device.pyx`):
```python
# Module-level: add a per-thread stack for saved contexts
# (reuse the existing _tls threading.local())

def __enter__(self):
    # Query actual CUDA state -- do NOT use a Python-side device cache
    prev_ctx = handle_return(cuCtxGetCurrent())
    # Store on a per-thread stack so nested `with` works
    if not hasattr(_tls, "_ctx_stack"):
        _tls._ctx_stack = []
    _tls._ctx_stack.append(prev_ctx)
    self.set_current()
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    prev_ctx = _tls._ctx_stack.pop()
    handle_return(cuCtxSetCurrent(prev_ctx))
    return False
```
Note: The stack here is NOT a device stack -- it is a stack of saved `CUcontext` values that `__exit__` restores. Each entry corresponds to exactly one `__enter__` call. This is fundamentally different from CuPy's old broken approach, which tracked a stack of device IDs and queried that stack instead of the CUDA API.
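The difference can be shown in miniature. In this pure-Python simulation, `driver` stands in for the CUDA driver's thread-local state and plain ints for devices; both class names are illustrative. The cached-stack style restores its own stale notion of "previous device", while the query-on-enter style restores whatever the driver actually reported:

```python
driver = {"dev": 0}  # stands in for the driver's current-device state

class CachedStack:
    """Old broken style: restores from a library-side cache of device IDs."""
    cached, stack = 0, []
    def __init__(self, dev):
        self.dev = dev
    def __enter__(self):
        CachedStack.stack.append(CachedStack.cached)  # saves the cache, not reality
        CachedStack.cached = self.dev
        driver["dev"] = self.dev
    def __exit__(self, *exc):
        CachedStack.cached = CachedStack.stack.pop()
        driver["dev"] = CachedStack.cached

class QueriedContext:
    """Proposed style: saves what the driver actually reports on enter.
    (The real design uses a thread-local stack; per-instance is enough here.)"""
    def __init__(self, dev):
        self.dev, self._saved = dev, []
    def __enter__(self):
        self._saved.append(driver["dev"])  # query the driver
        driver["dev"] = self.dev
    def __exit__(self, *exc):
        driver["dev"] = self._saved.pop()

driver["dev"] = 2              # another library switches devices directly
with CachedStack(1):
    pass
after_cached = driver["dev"]   # 0 -- wrong: device 2 was clobbered

driver["dev"] = 2
with QueriedContext(1):
    pass
after_queried = driver["dev"]  # 2 -- correct: actual driver state restored
```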
**Interoperability with other libraries.** Because we use `cuCtxSetCurrent` (driver API), and both PyTorch and CuPy use `cudaGetDevice`/`cudaSetDevice` (runtime API), which manipulate the same underlying driver state, cross-library nesting works:
```python
with torch.cuda.device(1):
    with Device(2):
        # Both torch and cuda.core see device 2
        ...
    # torch sees device 1 again (cuda.core restored the context)
```
Note that correct cross-library nesting depends on each library querying the CUDA API for the current device on context entry and restoring that value on exit, rather than relying on its own cached value. Libraries that follow this pattern (including CuPy v12+ and the CUDA runtime API) will interoperate correctly.
## Alternatives Considered
### 1. Separate `Device.activate()` method returning a context manager

```python
with Device(1).activate():
    ...
```
This avoids adding `__enter__`/`__exit__` to the singleton `Device` object. However, it adds API surface for no practical benefit -- the saved context state can be stored on a thread-local stack rather than on the `Device` instance, making `Device` itself safe to use directly as a reentrant context manager. The `with Device(1):` syntax is also more natural and matches PyTorch's `with torch.cuda.device(1):` pattern.

Rejected because it adds unnecessary indirection.
### 2. Do nothing -- recommend `set_current()` only
Per CuPy's internal policy, context managers for device switching are banned in CuPy's own codebase because they are footguns for library developers. The argument is that `set_current()` is explicit and unambiguous.

However, `cuda.core` targets end users (not just library internals), and the context manager pattern is:
- Exception-safe by default
- Idiomatic Python
- Already provided by PyTorch and CuPy (for end users)
- Requested by downstream users (dask-cuda)
Rejected as the sole approach, but `set_current()` remains the recommended approach for library code that needs precise control.
### 3. Use `cuCtxPushCurrent`/`cuCtxPopCurrent` instead of `cuCtxSetCurrent`
The CUDA driver provides an explicit context stack via push/pop. Using this would make nesting trivially correct. However, `Device.set_current()` currently uses `cuCtxSetCurrent` for primary contexts (not push/pop), and mixing the two models is fragile. The push/pop model also does not interoperate with libraries using `cudaSetDevice` (runtime API). The current approach of save-via-query / restore-via-set is correct and interoperable.
Rejected because it would diverge from the runtime API model that other libraries use.
## Open Questions

- Should `__enter__` call `set_current()` even if this device is already current? Calling `cuCtxSetCurrent` with the already-current context is cheap (effectively a no-op at the driver level) and keeps the implementation simple. The alternative (check-and-skip) adds complexity for negligible performance gain. Recommendation: always call `set_current()`.
- What should `__enter__` do if `set_current()` has never been called on this device? Currently, many `Device` properties require `set_current()` to have been called first (`_check_context_initialized`). The context manager should unconditionally call `set_current()`, initializing the device if needed. This is the natural expectation: `with Device(1):` should make device 1 ready for use.
- Should we document cross-library interop expectations? We should document that `with Device(N):` works correctly for `cuda.core` code, and that cross-library nesting works as long as the other library's context manager correctly queries actual CUDA state rather than relying on a cached value.
## Test Plan

- Basic usage: `with Device(0):` sets device 0 as current and restores the previous device on exit.
- Exception safety: the device is restored even when an exception is raised inside the `with` block.
- Nesting (same device): `with dev0: with dev0:` works without error.
- Nesting (different devices): `with dev0: with dev1:` correctly restores dev0 on exit of the inner block.
- Deep nesting / reentrancy: `with dev0: with dev1: with dev0: with dev1:` restores correctly at each level, and `Device` remains usable after context manager exit (singleton not corrupted).
- Multi-GPU (requires 2+ GPUs): verify that `cudaGetDevice()` (runtime API) reflects the device set by the context manager.
- Thread safety: context manager state is per-thread (thread-local storage), so concurrent threads using different devices must not interfere.