@@ -18,7 +18,6 @@ from cuda.core.experimental._utils cimport cuda_utils
1818# TODO(leofang): support NumPy structured dtypes
1919
2020
21- @cython.dataclasses.dataclass
2221cdef class StridedMemoryView:
2322 """ A dataclass holding metadata of a strided dense array/tensor.
2423
@@ -51,7 +50,7 @@ cdef class StridedMemoryView:
5150 Pointer to the tensor buffer (as a Python `int`).
5251 shape : tuple
5352 Shape of the tensor.
54- strides : tuple
53+ strides : Optional[ tuple]
5554 Strides of the tensor (in **counts**, not bytes).
5655 dtype: numpy.dtype
5756 Data type of the tensor.
@@ -74,15 +73,27 @@ cdef class StridedMemoryView:
7473 The pointer address (as Python `int`) to the **consumer** stream.
7574 Stream ordering will be properly established unless ``-1`` is passed.
7675 """
77- # TODO: switch to use Cython's cdef typing?
78- ptr: int = None
79- shape: tuple = None
80- strides: tuple = None # in counts, not bytes
81- dtype: numpy.dtype = None
82- device_id: int = None # -1 for CPU
83- is_device_accessible: bool = None
84- readonly: bool = None
85- exporting_obj: Any = None
76+ cdef readonly:
77+ intptr_t ptr
78+ int device_id
79+ bint is_device_accessible
80+ bint readonly
81+ object exporting_obj
82+
83+ # If using dlpack, this is a strong reference to the result of
84+ # obj.__dlpack__() so we can lazily create shape and strides from
85+ # it later. If using CAI, this is a reference to the source
86+ # `__cuda_array_interface__` object.
87+ cdef object metadata
88+
89+ # The tensor object if has obj has __dlpack__, otherwise must be NULL
90+ cdef DLTensor * dl_tensor
91+
92+ # Memoized properties
93+ cdef tuple _shape
94+ cdef tuple _strides
95+ cdef bint _strides_init # Has the strides tuple been init'ed?
96+ cdef object _dtype
8697
8798 def __init__ (self , obj = None , stream_ptr = None ):
8899 if obj is not None :
@@ -92,9 +103,56 @@ cdef class StridedMemoryView:
92103 else :
93104 view_as_cai(obj, stream_ptr, self )
94105 else :
95- # default construct
96106 pass
97107
108+ @property
109+ def shape (self ) -> tuple[int]:
110+ if self._shape is None and self.exporting_obj is not None:
111+ if self.dl_tensor != NULL:
112+ self._shape = cuda_utils.carray_int64_t_to_tuple(
113+ self .dl_tensor.shape,
114+ self .dl_tensor.ndim
115+ )
116+ else:
117+ self._shape = self .metadata[" shape" ]
118+ else:
119+ self._shape = ()
120+ return self._shape
121+
122+ @property
123+ def strides(self ) -> Optional[tuple[int]]:
124+ cdef int itemsize
125+ if self._strides_init is False:
126+ if self.exporting_obj is not None:
127+ if self.dl_tensor != NULL:
128+ if self.dl_tensor.strides:
129+ self._strides = cuda_utils.carray_int64_t_to_tuple(
130+ self .dl_tensor.strides,
131+ self .dl_tensor.ndim
132+ )
133+ else:
134+ strides = self .metadata.get(" strides" )
135+ if strides is not None:
136+ itemsize = self .dtype.itemsize
137+ self._strides = cpython.PyTuple_New(len (strides))
138+ for i in range(len(strides )):
139+ cpython.PyTuple_SET_ITEM(
140+ self ._strides, i, strides[i] // itemsize
141+ )
142+ self ._strides_init = True
143+ return self ._strides
144+
145+ @property
146+ def dtype (self ) -> Optional[numpy.dtype]:
147+ if self._dtype is None:
148+ if self.exporting_obj is not None:
149+ if self.dl_tensor != NULL:
150+ self._dtype = dtype_dlpack_to_numpy(& self .dl_tensor.dtype)
151+ else:
152+ # TODO: this only works for built-in numeric types
153+ self._dtype = numpy.dtype(self .metadata[" typestr" ])
154+ return self._dtype
155+
98156 def __repr__(self ):
99157 return (f" StridedMemoryView(ptr={self.ptr},\n "
100158 + f" shape={self.shape},\n "
@@ -152,7 +210,7 @@ cdef class _StridedMemoryViewProxy:
152210
153211cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view = None ):
154212 cdef int dldevice, device_id, i
155- cdef bint is_device_accessible, versioned, is_readonly
213+ cdef bint is_device_accessible, is_readonly
156214 is_device_accessible = False
157215 dldevice, device_id = obj.__dlpack_device__()
158216 if dldevice == _kDLCPU:
@@ -193,7 +251,6 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
193251 capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
194252 data = cpython.PyCapsule_GetPointer(
195253 capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
196- versioned = True
197254 dlm_tensor_ver = < DLManagedTensorVersioned* > data
198255 dl_tensor = & dlm_tensor_ver.dl_tensor
199256 is_readonly = bool ((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0 )
@@ -202,32 +259,24 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
202259 capsule, DLPACK_TENSOR_UNUSED_NAME):
203260 data = cpython.PyCapsule_GetPointer(
204261 capsule, DLPACK_TENSOR_UNUSED_NAME)
205- versioned = False
206262 dlm_tensor = < DLManagedTensor* > data
207263 dl_tensor = & dlm_tensor.dl_tensor
208264 is_readonly = False
209265 used_name = DLPACK_TENSOR_USED_NAME
210266 else :
211267 assert False
212268
269+ cpython.PyCapsule_SetName(capsule, used_name)
270+
213271 cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
272+ buf.dl_tensor = dl_tensor
273+ buf.metadata = capsule
214274 buf.ptr = < intptr_t> (dl_tensor.data)
215-
216- buf.shape = cuda_utils.carray_int64_t_to_tuple(dl_tensor.shape, dl_tensor.ndim)
217- if dl_tensor.strides:
218- buf.strides = cuda_utils.carray_int64_t_to_tuple(dl_tensor.strides, dl_tensor.ndim)
219- else :
220- # C-order
221- buf.strides = None
222-
223- buf.dtype = dtype_dlpack_to_numpy(& dl_tensor.dtype)
224275 buf.device_id = device_id
225276 buf.is_device_accessible = is_device_accessible
226277 buf.readonly = is_readonly
227278 buf.exporting_obj = obj
228279
229- cpython.PyCapsule_SetName(capsule, used_name)
230-
231280 return buf
232281
233282
@@ -291,7 +340,8 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
291340 return numpy.dtype(np_dtype)
292341
293342
294- cdef StridedMemoryView view_as_cai(obj, stream_ptr, view = None ):
343+ # Also generate for Python so we can test this code path
344+ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view = None ):
295345 cdef dict cai_data = obj.__cuda_array_interface__
296346 if cai_data[" version" ] < 3 :
297347 raise BufferError(" only CUDA Array Interface v3 or above is supported" )
@@ -302,33 +352,30 @@ cdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
302352
303353 cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
304354 buf.exporting_obj = obj
355+ buf.metadata = cai_data
356+ buf.dl_tensor = NULL
305357 buf.ptr, buf.readonly = cai_data[" data" ]
306- buf.shape = cai_data[" shape" ]
307- # TODO: this only works for built-in numeric types
308- buf.dtype = numpy.dtype(cai_data[" typestr" ])
309- buf.strides = cai_data.get(" strides" )
310- if buf.strides is not None :
311- # convert to counts
312- buf.strides = tuple (s // buf.dtype.itemsize for s in buf.strides)
313358 buf.is_device_accessible = True
314359 buf.device_id = handle_return(
315360 driver.cuPointerGetAttribute(
316361 driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
317362 buf.ptr))
318363
319364 cdef intptr_t producer_s, consumer_s
320- stream = cai_data.get(" stream" )
321- if stream is not None :
322- producer_s = < intptr_t> (stream)
323- consumer_s = < intptr_t> (stream_ptr)
324- assert producer_s > 0
325- # establish stream order
326- if producer_s != consumer_s:
327- e = handle_return(driver.cuEventCreate(
328- driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
329- handle_return(driver.cuEventRecord(e, producer_s))
330- handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0 ))
331- handle_return(driver.cuEventDestroy(e))
365+ stream_ptr = int (stream_ptr)
366+ if stream_ptr != - 1 :
367+ stream = cai_data.get(" stream" )
368+ if stream is not None :
369+ producer_s = < intptr_t> (stream)
370+ consumer_s = < intptr_t> (stream_ptr)
371+ assert producer_s > 0
372+ # establish stream order
373+ if producer_s != consumer_s:
374+ e = handle_return(driver.cuEventCreate(
375+ driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
376+ handle_return(driver.cuEventRecord(e, producer_s))
377+ handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0 ))
378+ handle_return(driver.cuEventDestroy(e))
332379
333380 return buf
334381
0 commit comments