@@ -10,8 +10,8 @@ from libc.stdint cimport uintptr_t
1010from cuda.bindings cimport cydriver
1111
1212from cuda.core.experimental._utils.cuda_utils cimport (
13- _check_driver_error as raise_if_driver_error,
1413 check_or_create_options,
14+ HANDLE_RETURN,
1515)
1616
1717import sys
@@ -178,18 +178,16 @@ cdef class Stream:
178178 priority = opts.priority
179179
180180 flags = cydriver.CUstream_flags.CU_STREAM_NON_BLOCKING if nonblocking else cydriver.CUstream_flags.CU_STREAM_DEFAULT
181- # TODO: use HANDLE_RETURN
182181 cdef int high, low
183- err = cydriver.cuCtxGetStreamPriorityRange(& high, & low)
182+ HANDLE_RETURN( cydriver.cuCtxGetStreamPriorityRange(& high, & low) )
184183 if priority is not None :
185184 if not (low <= priority <= high):
186185 raise ValueError (f" {priority=} is out of range {[low, high]}" )
187186 else :
188187 priority = high
189188
190189 cdef cydriver.CUstream s
191- # TODO: add HANDLE_RETURN macro to check driver error code?
192- err = cydriver.cuStreamCreateWithPriority(& s, flags, priority)
190+ HANDLE_RETURN(cydriver.cuStreamCreateWithPriority(& s, flags, priority))
193191 self ._handle = s
194192 self ._owner = None
195193 self ._nonblocking = nonblocking
@@ -207,8 +205,7 @@ cdef class Stream:
207205
208206 if self ._owner is None :
209207 if self ._handle and not self ._builtin:
210- # TODO: use HANDLE_RETURN
211- err = cydriver.cuStreamDestroy(self ._handle)
208+ HANDLE_RETURN(cydriver.cuStreamDestroy(self ._handle))
212209 else :
213210 self ._owner = None
214211 self ._handle = < cydriver.CUstream> (NULL )
@@ -242,8 +239,7 @@ cdef class Stream:
242239 """Return True if this is a nonblocking stream , otherwise False."""
243240 cdef unsigned int flags
244241 if self._nonblocking is None:
245- # TODO: switch to HANDLE_RETURN
246- err = cydriver.cuStreamGetFlags(self ._handle, & flags)
242+ HANDLE_RETURN(cydriver.cuStreamGetFlags(self._handle , &flags ))
247243 if flags & cydriver.CUstream_flags.CU_STREAM_NON_BLOCKING:
248244 self._nonblocking = True
249245 else:
@@ -255,15 +251,13 @@ cdef class Stream:
255251 """Return the stream priority."""
256252 cdef int prio
257253 if self._priority is None:
258- # TODO: switch to HANDLE_RETURN
259- err = cydriver.cuStreamGetPriority(self ._handle, & prio)
254+ HANDLE_RETURN(cydriver.cuStreamGetPriority(self._handle , &prio ))
260255 self._priority = prio
261256 return self._priority
262257
263258 def sync(self ):
264259 """ Synchronize the stream."""
265- # TODO: switch to HANDLE_RETURN
266- err = cydriver.cuStreamSynchronize(self ._handle)
260+ HANDLE_RETURN(cydriver.cuStreamSynchronize(self ._handle))
267261
268262 def record (self , event: Event = None , options: EventOptions = None ) -> Event:
269263 """Record an event onto the stream.
@@ -290,9 +284,8 @@ cdef class Stream:
290284 if event is None:
291285 self._get_device_and_context()
292286 event = Event._init(self ._device_id, self ._ctx_handle, options)
293- # TODO: switch to HANDLE_RETURN
294287 # TODO: revisit after Event is cythonized
295- err = cydriver.cuEventRecord(< cydriver.CUevent>< uintptr_t> (event.handle), self ._handle)
288+ HANDLE_RETURN( cydriver.cuEventRecord(<cydriver.CUevent><uintptr_t>(event.handle ), self._handle ) )
296289 return event
297290
298291 def wait(self , event_or_stream: Union[Event , Stream]):
@@ -324,16 +317,14 @@ cdef class Stream:
324317 f" got {type(event_or_stream)}"
325318 ) from e
326319 stream = < cydriver.CUstream>< uintptr_t> (s.handle)
327- # TODO: switch to HANDLE_RETURN
328- err = cydriver.cuEventCreate(& event, cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING)
329- err = cydriver.cuEventRecord(event, stream)
320+ HANDLE_RETURN(cydriver.cuEventCreate(& event, cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
321+ HANDLE_RETURN(cydriver.cuEventRecord(event, stream))
330322 discard_event = True
331323
332324 # TODO: support flags other than 0?
333- # TODO: switch to HANDLE_RETURN
334- err = cydriver.cuStreamWaitEvent(self ._handle, event, 0 )
325+ HANDLE_RETURN(cydriver.cuStreamWaitEvent(self ._handle, event, 0 ))
335326 if discard_event:
336- err = cydriver.cuEventDestroy(event)
327+ HANDLE_RETURN( cydriver.cuEventDestroy(event) )
337328
338329 @property
339330 def device (self ) -> Device:
@@ -354,8 +345,7 @@ cdef class Stream:
354345 # TODO: consider making self._ctx_handle typed?
355346 cdef cydriver.CUcontext ctx
356347 if self._ctx_handle is None:
357- # TODO: switch to HANDLE_RETURN
358- err = cydriver.cuStreamGetCtx(self ._handle, & ctx)
348+ HANDLE_RETURN(cydriver.cuStreamGetCtx(self._handle , &ctx ))
359349 self._ctx_handle = driver.CUcontext(< uintptr_t> ctx)
360350 return 0
361351
0 commit comments