Skip to content

Commit fbd3fb9

Browse files
authored
fix: ensure that there are no module global lookups when calling default_stream (#1107)
1 parent 4754292 commit fbd3fb9

File tree

5 files changed

+26
-14
lines changed

5 files changed

+26
-14
lines changed

cuda_core/cuda/core/experimental/_device.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ from cuda.core.experimental._context import Context, ContextOptions
1616
from cuda.core.experimental._event import Event, EventOptions
1717
from cuda.core.experimental._graph import GraphBuilder
1818
from cuda.core.experimental._memory import Buffer, DeviceMemoryResource, MemoryResource, _SynchronousMemoryResource
19-
from cuda.core.experimental._stream import IsStreamT, Stream, StreamOptions, default_stream
19+
from cuda.core.experimental._stream import IsStreamT, Stream, StreamOptions
2020
from cuda.core.experimental._utils.clear_error_support import assert_type
2121
from cuda.core.experimental._utils.cuda_utils import (
2222
ComputeCapability,
@@ -25,6 +25,7 @@ from cuda.core.experimental._utils.cuda_utils import (
2525
handle_return,
2626
runtime,
2727
)
28+
from cuda.core.experimental._stream cimport default_stream
2829

2930

3031
# TODO: I prefer to type these as "cdef object" and avoid accessing them from within Python,

cuda_core/cuda/core/experimental/_memory.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ from libc.string cimport memset, memcpy
1212
from cuda.bindings cimport cydriver
1313

1414
from cuda.core.experimental._stream cimport Stream as cyStream
15+
from cuda.core.experimental._stream cimport default_stream
1516
from cuda.core.experimental._utils.cuda_utils cimport (
1617
_check_driver_error as raise_if_driver_error,
1718
check_or_create_options,
@@ -30,7 +31,7 @@ import platform
3031
import weakref
3132

3233
from cuda.core.experimental._dlpack import DLDeviceType, make_py_capsule
33-
from cuda.core.experimental._stream import Stream, default_stream
34+
from cuda.core.experimental._stream import Stream
3435
from cuda.core.experimental._utils.cuda_utils import ( driver, Transaction, get_binding_version )
3536

3637
if platform.system() == "Linux":

cuda_core/cuda/core/experimental/_stream.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,6 @@ cdef class Stream:
2222
cpdef close(self)
2323
cdef int _get_context(self) except?-1 nogil
2424
cdef int _get_device_and_context(self) except?-1
25+
26+
27+
cdef Stream default_stream()

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
from libc.stdint cimport uintptr_t, INT32_MIN
8+
from libc.stdlib cimport strtol, getenv
89

910
from cuda.bindings cimport cydriver
1011

@@ -388,11 +389,16 @@ cdef class Stream:
388389
return GraphBuilder._init(stream=self, is_stream_owner=False)
389390

390391

391-
LEGACY_DEFAULT_STREAM = Stream._legacy_default()
392-
PER_THREAD_DEFAULT_STREAM = Stream._per_thread_default()
392+
# c-only python objects, not public
393+
cdef Stream C_LEGACY_DEFAULT_STREAM = Stream._legacy_default()
394+
cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream._per_thread_default()
393395

396+
# standard python objects, public
397+
LEGACY_DEFAULT_STREAM = C_LEGACY_DEFAULT_STREAM
398+
PER_THREAD_DEFAULT_STREAM = C_PER_THREAD_DEFAULT_STREAM
394399

395-
def default_stream():
400+
401+
cdef Stream default_stream():
396402
"""Return the default CUDA :obj:`~_stream.Stream`.
397403
398404
The type of default stream returned depends on if the environment
@@ -403,8 +409,14 @@ def default_stream():
403409
404410
"""
405411
# TODO: flip the default
406-
use_ptds = int(os.environ.get("CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM", 0))
412+
cdef const char* use_ptds_raw = getenv("CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM")
413+
414+
cdef int use_ptds = 0
415+
if use_ptds_raw != NULL:
416+
use_ptds = strtol(use_ptds_raw, NULL, 10)
417+
418+
# value is non-zero, including for weird stuff like 123foo
407419
if use_ptds:
408-
return PER_THREAD_DEFAULT_STREAM
420+
return C_PER_THREAD_DEFAULT_STREAM
409421
else:
410-
return LEGACY_DEFAULT_STREAM
422+
return C_LEGACY_DEFAULT_STREAM

cuda_core/tests/test_stream.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
from cuda.core.experimental import Device, Stream, StreamOptions
66
from cuda.core.experimental._event import Event
7-
from cuda.core.experimental._stream import LEGACY_DEFAULT_STREAM, PER_THREAD_DEFAULT_STREAM, default_stream
7+
from cuda.core.experimental._stream import LEGACY_DEFAULT_STREAM, PER_THREAD_DEFAULT_STREAM
88
from cuda.core.experimental._utils.cuda_utils import driver
99

1010

@@ -107,11 +107,6 @@ def test_per_thread_default_stream():
107107
assert isinstance(PER_THREAD_DEFAULT_STREAM, Stream)
108108

109109

110-
def test_default_stream():
111-
stream = default_stream()
112-
assert isinstance(stream, Stream)
113-
114-
115110
def test_stream_subclassing(init_cuda):
116111
class MyStream(Stream):
117112
pass

0 commit comments

Comments
 (0)