Skip to content

Commit e702b5e

Browse files
committed
fix error handling
1 parent 96ce480 commit e702b5e

File tree

4 files changed

+39
-34
lines changed

4 files changed

+39
-34
lines changed

cuda_core/cuda/core/experimental/_event.pyx

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ from libc.stdint cimport uintptr_t
1010
from cuda.bindings cimport cydriver
1111

1212
from cuda.core.experimental._utils.cuda_utils cimport (
13-
_check_driver_error as raise_if_driver_error,
1413
check_or_create_options,
14+
HANDLE_RETURN
1515
)
1616

1717
from dataclasses import dataclass
@@ -110,8 +110,7 @@ cdef class Event:
110110
self._busy_waited = True
111111
if opts.support_ipc:
112112
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103")
113-
# TODO: use HANDLE_RETURN
114-
err = cydriver.cuEventCreate(&self._handle, flags)
113+
HANDLE_RETURN(cydriver.cuEventCreate(&self._handle, flags))
115114
self._device_id = device_id
116115
self._ctx_handle = ctx_handle
117116
return self
@@ -120,8 +119,7 @@ cdef class Event:
120119
if is_shutting_down and is_shutting_down():
121120
return
122121
if self._handle != NULL:
123-
# TODO: use HANDLE_RETURN
124-
err = cydriver.cuEventDestroy(self._handle)
122+
HANDLE_RETURN(cydriver.cuEventDestroy(self._handle))
125123
self._handle = <cydriver.CUevent>(NULL)
126124

127125
cpdef close(self):
@@ -190,8 +188,7 @@ cdef class Event:
190188
has been completed.
191189
192190
"""
193-
# TODO: use HANDLE_RETURN
194-
err = cydriver.cuEventSynchronize(self._handle)
191+
HANDLE_RETURN(cydriver.cuEventSynchronize(self._handle))
195192

196193
@property
197194
def is_done(self) -> bool:
@@ -201,7 +198,7 @@ cdef class Event:
201198
return True
202199
if result == cydriver.CUresult.CUDA_ERROR_NOT_READY:
203200
return False
204-
# TODO: use HANDLE_RETURN
201+
HANDLE_RETURN(result)
205202

206203
@property
207204
def handle(self) -> cuda.bindings.driver.CUevent:

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ from libc.stdint cimport uintptr_t
1010
from cuda.bindings cimport cydriver
1111

1212
from cuda.core.experimental._utils.cuda_utils cimport (
13-
_check_driver_error as raise_if_driver_error,
1413
check_or_create_options,
14+
HANDLE_RETURN,
1515
)
1616

1717
import sys
@@ -178,18 +178,16 @@ cdef class Stream:
178178
priority = opts.priority
179179

180180
flags = cydriver.CUstream_flags.CU_STREAM_NON_BLOCKING if nonblocking else cydriver.CUstream_flags.CU_STREAM_DEFAULT
181-
# TODO: use HANDLE_RETURN
182181
cdef int high, low
183-
err = cydriver.cuCtxGetStreamPriorityRange(&high, &low)
182+
HANDLE_RETURN(cydriver.cuCtxGetStreamPriorityRange(&high, &low))
184183
if priority is not None:
185184
if not (low <= priority <= high):
186185
raise ValueError(f"{priority=} is out of range {[low, high]}")
187186
else:
188187
priority = high
189188

190189
cdef cydriver.CUstream s
191-
# TODO: add HANDLE_RETURN macro to check driver error code?
192-
err = cydriver.cuStreamCreateWithPriority(&s, flags, priority)
190+
HANDLE_RETURN(cydriver.cuStreamCreateWithPriority(&s, flags, priority))
193191
self._handle = s
194192
self._owner = None
195193
self._nonblocking = nonblocking
@@ -207,8 +205,7 @@ cdef class Stream:
207205

208206
if self._owner is None:
209207
if self._handle and not self._builtin:
210-
# TODO: use HANDLE_RETURN
211-
err = cydriver.cuStreamDestroy(self._handle)
208+
HANDLE_RETURN(cydriver.cuStreamDestroy(self._handle))
212209
else:
213210
self._owner = None
214211
self._handle = <cydriver.CUstream>(NULL)
@@ -242,8 +239,7 @@ cdef class Stream:
242239
"""Return True if this is a nonblocking stream, otherwise False."""
243240
cdef unsigned int flags
244241
if self._nonblocking is None:
245-
# TODO: switch to HANDLE_RETURN
246-
err = cydriver.cuStreamGetFlags(self._handle, &flags)
242+
HANDLE_RETURN(cydriver.cuStreamGetFlags(self._handle, &flags))
247243
if flags & cydriver.CUstream_flags.CU_STREAM_NON_BLOCKING:
248244
self._nonblocking = True
249245
else:
@@ -255,15 +251,13 @@ cdef class Stream:
255251
"""Return the stream priority."""
256252
cdef int prio
257253
if self._priority is None:
258-
# TODO: switch to HANDLE_RETURN
259-
err = cydriver.cuStreamGetPriority(self._handle, &prio)
254+
HANDLE_RETURN(cydriver.cuStreamGetPriority(self._handle, &prio))
260255
self._priority = prio
261256
return self._priority
262257

263258
def sync(self):
264259
"""Synchronize the stream."""
265-
# TODO: switch to HANDLE_RETURN
266-
err = cydriver.cuStreamSynchronize(self._handle)
260+
HANDLE_RETURN(cydriver.cuStreamSynchronize(self._handle))
267261

268262
def record(self, event: Event = None, options: EventOptions = None) -> Event:
269263
"""Record an event onto the stream.
@@ -290,9 +284,8 @@ cdef class Stream:
290284
if event is None:
291285
self._get_device_and_context()
292286
event = Event._init(self._device_id, self._ctx_handle, options)
293-
# TODO: switch to HANDLE_RETURN
294287
# TODO: revisit after Event is cythonized
295-
err = cydriver.cuEventRecord(<cydriver.CUevent><uintptr_t>(event.handle), self._handle)
288+
HANDLE_RETURN(cydriver.cuEventRecord(<cydriver.CUevent><uintptr_t>(event.handle), self._handle))
296289
return event
297290

298291
def wait(self, event_or_stream: Union[Event, Stream]):
@@ -324,16 +317,14 @@ cdef class Stream:
324317
f" got {type(event_or_stream)}"
325318
) from e
326319
stream = <cydriver.CUstream><uintptr_t>(s.handle)
327-
# TODO: switch to HANDLE_RETURN
328-
err = cydriver.cuEventCreate(&event, cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING)
329-
err = cydriver.cuEventRecord(event, stream)
320+
HANDLE_RETURN(cydriver.cuEventCreate(&event, cydriver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
321+
HANDLE_RETURN(cydriver.cuEventRecord(event, stream))
330322
discard_event = True
331323

332324
# TODO: support flags other than 0?
333-
# TODO: switch to HANDLE_RETURN
334-
err = cydriver.cuStreamWaitEvent(self._handle, event, 0)
325+
HANDLE_RETURN(cydriver.cuStreamWaitEvent(self._handle, event, 0))
335326
if discard_event:
336-
err = cydriver.cuEventDestroy(event)
327+
HANDLE_RETURN(cydriver.cuEventDestroy(event))
337328

338329
@property
339330
def device(self) -> Device:
@@ -354,8 +345,7 @@ cdef class Stream:
354345
# TODO: consider making self._ctx_handle typed?
355346
cdef cydriver.CUcontext ctx
356347
if self._ctx_handle is None:
357-
# TODO: switch to HANDLE_RETURN
358-
err = cydriver.cuStreamGetCtx(self._handle, &ctx)
348+
HANDLE_RETURN(cydriver.cuStreamGetCtx(self._handle, &ctx))
359349
self._ctx_handle = driver.CUcontext(<uintptr_t>ctx)
360350
return 0
361351

cuda_core/cuda/core/experimental/_utils/cuda_utils.pxd

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,30 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
65
cimport cpython
7-
cimport libc.stdint
6+
from libc.stdint cimport int64_t
7+
8+
# TODO: how about cuda.bindings < 12.6.2?
9+
from cuda.bindings cimport cydriver
10+
11+
12+
ctypedef fused supported_error_type:
13+
cydriver.CUresult
814

915

16+
cdef int HANDLE_RETURN(supported_error_type err) except?-1
17+
18+
19+
# TODO: stop exposing these within the codebase?
1020
cpdef int _check_driver_error(error) except?-1
1121
cpdef int _check_runtime_error(error) except?-1
1222
cpdef int _check_nvrtc_error(error) except?-1
23+
24+
1325
cpdef check_or_create_options(type cls, options, str options_description=*, bint keep_none=*)
1426

1527

16-
cdef inline tuple carray_int64_t_to_tuple(libc.stdint.int64_t *ptr, int length):
28+
cdef inline tuple carray_int64_t_to_tuple(int64_t *ptr, int length):
1729
# Construct shape and strides tuples using the Python/C API for speed
1830
result = cpython.PyTuple_New(length)
1931
for i in range(length):

cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ def _reduce_3_tuple(t: tuple):
5252
return t[0] * t[1] * t[2]
5353

5454

55+
cdef int HANDLE_RETURN(supported_error_type err) except?-1:
56+
if supported_error_type is cydriver.CUresult:
57+
if err != cydriver.CUresult.CUDA_SUCCESS:
58+
return _check_driver_error(err)
59+
60+
5561
cdef object _DRIVER_SUCCESS = driver.CUresult.CUDA_SUCCESS
5662
cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess
5763
cdef object _NVRTC_SUCCESS = nvrtc.nvrtcResult.NVRTC_SUCCESS

0 commit comments

Comments
 (0)