@@ -436,3 +436,179 @@ def test_device_set_membership(init_cuda):
436436 # Same device_id should not add duplicate
437437 device_set .add (dev0_b )
438438 assert len (device_set ) == 1 , "Should not add duplicate device"
439+
440+
441+ # ============================================================================
442+ # Device Context Manager Tests
443+ # ============================================================================
444+
445+
446+ def _get_current_context ():
447+ """Return the current CUcontext as an int (0 means NULL / no context)."""
448+ return int (handle_return (driver .cuCtxGetCurrent ()))
449+
450+
451+ def test_context_manager_basic (deinit_cuda ):
452+ """with Device(0) sets the device as current and restores on exit."""
453+ assert _get_current_context () == 0 , "Should start with no active context"
454+
455+ with Device (0 ):
456+ assert _get_current_context () != 0 , "Device should be current inside the with block"
457+
458+ assert _get_current_context () == 0 , "No context should be current after exiting"
459+
460+
461+ def test_context_manager_restores_previous (deinit_cuda ):
462+ """Context manager restores the previously active context, not NULL."""
463+ dev0 = Device (0 )
464+ dev0 .set_current ()
465+ ctx_before = _get_current_context ()
466+ assert ctx_before != 0
467+
468+ with Device (0 ):
469+ pass
470+
471+ assert _get_current_context () == ctx_before , "Should restore the previous context"
472+
473+
474+ def test_context_manager_exception_safety (deinit_cuda ):
475+ """Device context is restored even when an exception is raised."""
476+ # Start with no active context so restoration is distinguishable
477+ assert _get_current_context () == 0
478+
479+ with pytest .raises (RuntimeError , match = "test error" ), Device (0 ):
480+ assert _get_current_context () != 0 , "Device should be active inside the block"
481+ raise RuntimeError ("test error" )
482+
483+ assert _get_current_context () == 0 , "Must restore NULL context after exception"
484+
485+
486+ def test_context_manager_returns_device (deinit_cuda ):
487+ """__enter__ returns the Device instance for use in 'as' clause."""
488+ device = Device (0 )
489+ with device as dev :
490+ assert dev is device
491+
492+ assert _get_current_context () == 0
493+
494+
495+ def test_context_manager_nesting_same_device (deinit_cuda ):
496+ """Nested with-blocks on the same device work correctly."""
497+ dev0 = Device (0 )
498+
499+ with dev0 :
500+ ctx_outer = _get_current_context ()
501+ with dev0 :
502+ ctx_inner = _get_current_context ()
503+ assert ctx_inner == ctx_outer , "Same device should yield same context"
504+ assert _get_current_context () == ctx_outer , "Outer context restored after inner exit"
505+
506+ assert _get_current_context () == 0
507+
508+
509+ def test_context_manager_deep_nesting (deinit_cuda ):
510+ """Deep nesting and reentrancy restore correctly at each level."""
511+ dev0 = Device (0 )
512+
513+ with dev0 :
514+ ctx_level1 = _get_current_context ()
515+ with dev0 :
516+ ctx_level2 = _get_current_context ()
517+ with dev0 :
518+ assert _get_current_context () != 0
519+ assert _get_current_context () == ctx_level2
520+ assert _get_current_context () == ctx_level1
521+
522+ assert _get_current_context () == 0
523+
524+
525+ def test_context_manager_nesting_different_devices (mempool_device_x2 ):
526+ """Nested with-blocks on different devices restore correctly."""
527+ dev0 , dev1 = mempool_device_x2
528+ ctx_dev0 = _get_current_context ()
529+
530+ with dev1 :
531+ ctx_inside = _get_current_context ()
532+ assert ctx_inside != ctx_dev0 , "Different device should have different context"
533+
534+ assert _get_current_context () == ctx_dev0 , "Original device context should be restored"
535+
536+
537+ def test_context_manager_deep_nesting_multi_gpu (mempool_device_x2 ):
538+ """Deep nesting across multiple devices restores correctly at each level."""
539+ dev0 , dev1 = mempool_device_x2
540+
541+ with dev0 :
542+ ctx_level0 = _get_current_context ()
543+ with dev1 :
544+ ctx_level1 = _get_current_context ()
545+ assert ctx_level1 != ctx_level0
546+ with dev0 :
547+ assert _get_current_context () == ctx_level0 , "Same device should yield same primary context"
548+ with dev1 :
549+ assert _get_current_context () == ctx_level1
550+ assert _get_current_context () == ctx_level0
551+ assert _get_current_context () == ctx_level1
552+ assert _get_current_context () == ctx_level0
553+
554+
555+ def test_context_manager_set_current_inside (mempool_device_x2 ):
556+ """set_current() inside a with block does not affect restoration on exit."""
557+ dev0 , dev1 = mempool_device_x2
558+ ctx_dev0 = _get_current_context () # dev0 is current from fixture
559+
560+ with dev0 :
561+ dev1 .set_current () # change the active device inside the block
562+ assert _get_current_context () != ctx_dev0
563+
564+ assert _get_current_context () == ctx_dev0 , "Must restore the context saved at __enter__"
565+
566+
567+ def test_context_manager_device_usable_after_exit (deinit_cuda ):
568+ """Device singleton is not corrupted after context manager exit."""
569+ device = Device (0 )
570+ with device :
571+ pass
572+
573+ assert _get_current_context () == 0
574+
575+ # Device should still be usable via set_current
576+ device .set_current ()
577+ assert _get_current_context () != 0
578+ stream = device .create_stream ()
579+ assert stream is not None
580+
581+
582+ def test_context_manager_initializes_device (deinit_cuda ):
583+ """with Device(N) should initialize the device, making it ready for use."""
584+ device = Device (0 )
585+ with device :
586+ # allocate requires an active context; should not raise
587+ buf = device .allocate (1024 )
588+ assert buf .handle != 0
589+
590+
591+ def test_context_manager_thread_safety (mempool_device_x3 ):
592+ """Concurrent threads using context managers on different devices don't interfere."""
593+ import concurrent .futures
594+ import threading
595+
596+ devices = mempool_device_x3
597+ barrier = threading .Barrier (len (devices ))
598+ errors = []
599+
600+ def worker (dev ):
601+ try :
602+ ctx_before = _get_current_context ()
603+ with dev :
604+ barrier .wait (timeout = 5 )
605+ buf = dev .allocate (1024 )
606+ assert buf .handle != 0
607+ assert _get_current_context () == ctx_before
608+ except Exception as e :
609+ errors .append (e )
610+
611+ with concurrent .futures .ThreadPoolExecutor (max_workers = len (devices )) as pool :
612+ pool .map (worker , devices )
613+
614+ assert not errors , f"Thread errors: { errors } "
0 commit comments