55from __future__ import annotations
66
77from libc.stdint cimport uintptr_t, INT32_MIN
8+ from libc.stdlib cimport strtol, getenv
89
910from cuda.bindings cimport cydriver
1011
@@ -388,11 +389,16 @@ cdef class Stream:
388389 return GraphBuilder._init(stream = self , is_stream_owner = False )
389390
390391
391- LEGACY_DEFAULT_STREAM = Stream._legacy_default()
392- PER_THREAD_DEFAULT_STREAM = Stream._per_thread_default()
392+ # c-only python objects , not public
393+ cdef Stream C_LEGACY_DEFAULT_STREAM = Stream._legacy_default()
394+ cdef Stream C_PER_THREAD_DEFAULT_STREAM = Stream._per_thread_default()
393395
396+ # standard python objects , public
397+ LEGACY_DEFAULT_STREAM = C_LEGACY_DEFAULT_STREAM
398+ PER_THREAD_DEFAULT_STREAM = C_PER_THREAD_DEFAULT_STREAM
394399
395- def default_stream():
400+
401+ cdef Stream default_stream():
396402 """ Return the default CUDA :obj:`~_stream.Stream`.
397403
398404 The type of default stream returned depends on if the environment
@@ -403,8 +409,14 @@ def default_stream():
403409
404410 """
405411 # TODO: flip the default
406- use_ptds = int (os.environ.get(" CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM" , 0 ))
412+ cdef const char * use_ptds_raw = getenv(" CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM" )
413+
414+ cdef int use_ptds = 0
415+ if use_ptds_raw != NULL :
416+ use_ptds = strtol(use_ptds_raw, NULL , 10 )
417+
418+ # value is non-zero, including for weird stuff like 123foo
407419 if use_ptds:
408- return PER_THREAD_DEFAULT_STREAM
420+ return C_PER_THREAD_DEFAULT_STREAM
409421 else :
410- return LEGACY_DEFAULT_STREAM
422+ return C_LEGACY_DEFAULT_STREAM
0 commit comments