Skip to content

Commit 998de2f

Browse files
committed
Checking for context initalization
1 parent e1265b4 commit 998de2f

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,22 +127,32 @@ cdef class Stream:
127127
def _legacy_default(cls):
128128
cdef Stream self = Stream.__new__(cls)
129129
cdef cydriver.CUcontext ctx
130+
cdef cydriver.CUresult err
130131
self._handle = <cydriver.CUstream>(cydriver.CU_STREAM_LEGACY)
131132
self._builtin = True
132133
with nogil:
133-
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
134-
self._ctx_handle = ctx
134+
err = cydriver.cuCtxGetCurrent(&ctx)
135+
if err == cydriver.CUresult.CUDA_SUCCESS:
136+
self._ctx_handle = ctx
137+
else:
138+
# CUDA not initialized yet, will be lazily initialized later
139+
self._ctx_handle = CU_CONTEXT_INVALID
135140
return self
136141

137142
@classmethod
138143
def _per_thread_default(cls):
139144
cdef Stream self = Stream.__new__(cls)
140145
cdef cydriver.CUcontext ctx
146+
cdef cydriver.CUresult err
141147
self._handle = <cydriver.CUstream>(cydriver.CU_STREAM_PER_THREAD)
142148
self._builtin = True
143149
with nogil:
144-
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
145-
self._ctx_handle = ctx
150+
err = cydriver.cuCtxGetCurrent(&ctx)
151+
if err == cydriver.CUresult.CUDA_SUCCESS:
152+
self._ctx_handle = ctx
153+
else:
154+
# CUDA not initialized yet, will be lazily initialized later
155+
self._ctx_handle = CU_CONTEXT_INVALID
146156
return self
147157

148158
@classmethod

0 commit comments

Comments
 (0)