Skip to content

Commit 44f812e

Browse files
Add tests for library/kernel lifetime vs context lifetime
Experiment to detect segfaults on CI platforms where cuLibraryUnload after context destruction may crash. LibraryBox does not currently store a ContextHandle, so nothing prevents the context from being destroyed before the library. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 4597e03 commit 44f812e

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
3+
4+
"""Tests for library/kernel lifetime vs context lifetime.
5+
6+
These tests exercise scenarios where a CUDA library (CUlibrary) or kernel
7+
(CUkernel) outlives the context in which it was created. This can happen
8+
when:
9+
- The primary context is reset (e.g. numba-cuda's test teardown)
10+
- A non-primary context is explicitly destroyed
11+
12+
Currently, LibraryBox in resource_handles.cpp does NOT store a
13+
ContextHandle, so nothing prevents the context from being destroyed
14+
before the library. The library's RAII deleter then calls cuLibraryUnload
15+
on a library whose owning context is gone, which may segfault on some
16+
driver versions.
17+
"""
18+
19+
import gc
20+
21+
import pytest
22+
from cuda.bindings import driver
23+
from cuda.core import Device, Program
24+
from cuda.core._utils.cuda_utils import handle_return
25+
26+
KERNEL_SOURCE = 'extern "C" __global__ void test_kernel() {}'
27+
28+
29+
def _compile_and_get_kernel():
30+
"""Compile a trivial kernel and return (ObjectCode, Kernel)."""
31+
prog = Program(KERNEL_SOURCE, "c++")
32+
obj = prog.compile("cubin")
33+
kernel = obj.get_kernel("test_kernel")
34+
return obj, kernel
35+
36+
37+
def _create_nonprimary_context(dev=0):
38+
"""Create a non-primary context, handling CUDA 12.x vs 13.x API differences."""
39+
try:
40+
return handle_return(driver.cuCtxCreate(None, 0, dev))
41+
except TypeError:
42+
return handle_return(driver.cuCtxCreate(0, dev))
43+
44+
45+
@pytest.fixture(autouse=True)
46+
def _restore_primary_context():
47+
"""Re-establish a valid primary context after each test."""
48+
yield
49+
ctx = handle_return(driver.cuDevicePrimaryCtxRetain(0))
50+
handle_return(driver.cuCtxSetCurrent(ctx))
51+
52+
53+
class TestPrimaryContextReset:
54+
"""Library/kernel destroyed after primary context reset."""
55+
56+
def test_objectcode_outlives_primary_context_reset(self):
57+
dev = Device(0)
58+
dev.set_current()
59+
60+
obj, kernel = _compile_and_get_kernel()
61+
del kernel
62+
63+
handle_return(driver.cuDevicePrimaryCtxReset(0))
64+
65+
del obj
66+
gc.collect()
67+
68+
def test_kernel_outlives_primary_context_reset(self):
69+
dev = Device(0)
70+
dev.set_current()
71+
72+
obj, kernel = _compile_and_get_kernel()
73+
del obj
74+
75+
handle_return(driver.cuDevicePrimaryCtxReset(0))
76+
77+
del kernel
78+
gc.collect()
79+
80+
def test_objectcode_outlives_primary_context_reset_and_reretain(self):
81+
"""The numba-cuda pattern: reset, then re-init for the next test."""
82+
dev = Device(0)
83+
dev.set_current()
84+
85+
obj, kernel = _compile_and_get_kernel()
86+
del kernel
87+
88+
handle_return(driver.cuDevicePrimaryCtxReset(0))
89+
90+
ctx = handle_return(driver.cuDevicePrimaryCtxRetain(0))
91+
handle_return(driver.cuCtxSetCurrent(ctx))
92+
93+
del obj
94+
gc.collect()
95+
96+
97+
class TestNonPrimaryContextDestroy:
98+
"""Library/kernel destroyed after non-primary context is destroyed."""
99+
100+
def test_objectcode_outlives_nonprimary_context(self):
101+
ctx = _create_nonprimary_context()
102+
103+
obj, kernel = _compile_and_get_kernel()
104+
del kernel
105+
106+
handle_return(driver.cuCtxDestroy(ctx))
107+
108+
del obj
109+
gc.collect()
110+
111+
def test_kernel_outlives_nonprimary_context(self):
112+
ctx = _create_nonprimary_context()
113+
114+
obj, kernel = _compile_and_get_kernel()
115+
del obj
116+
117+
handle_return(driver.cuCtxDestroy(ctx))
118+
119+
del kernel
120+
gc.collect()
121+
122+
123+
class TestKernelOutlivesObjectCode:
124+
"""Kernel transitively holds library; both outlive context."""
125+
126+
def test_kernel_outlives_objectcode_and_primary_context(self):
127+
dev = Device(0)
128+
dev.set_current()
129+
130+
obj, kernel = _compile_and_get_kernel()
131+
del obj
132+
gc.collect()
133+
134+
handle_return(driver.cuDevicePrimaryCtxReset(0))
135+
136+
del kernel
137+
gc.collect()
138+
139+
def test_kernel_outlives_objectcode_and_nonprimary_context(self):
140+
ctx = _create_nonprimary_context()
141+
142+
obj, kernel = _compile_and_get_kernel()
143+
del obj
144+
gc.collect()
145+
146+
handle_return(driver.cuCtxDestroy(ctx))
147+
148+
del kernel
149+
gc.collect()

0 commit comments

Comments
 (0)