Skip to content

Commit e923636

Browse files
cpcloudcursoragentleofang
authored
Add StridedMemoryView DLPack export and C exchange API support (#1630)
* Add StridedMemoryView DLPack export tests Add coverage for StridedMemoryView DLPack producer behavior and the C exchange API capsule so the implementation can be landed in follow-up commits. Co-authored-by: Cursor <cursoragent@cursor.com> * Sync vendored dlpack.h to v1.3. Align cuda-core's bundled DLPack header with upstream v1.3 so protocol constants and C API declarations match current spec. Co-authored-by: Cursor <cursoragent@cursor.com> * Implement StridedMemoryView Python DLPack producer methods Add __dlpack__ and __dlpack_device__ support to StridedMemoryView, including owned DLManagedTensor capsule creation and lifecycle handling for shape, strides, dtype, and device metadata. Co-authored-by: Cursor <cursoragent@cursor.com> * Add StridedMemoryView DLPack C exchange API exposure. Expose the C exchange API capsules for StridedMemoryView and align exported tensor metadata with current DLPack semantics and shared type declarations. Co-authored-by: Cursor <cursoragent@cursor.com> * Add cuda-nvrtc-dev to cuda-core host dependencies. Ensure nvrtc headers are present when building cuda-core in the test environment so pixi test runs do not fail on missing nvrtc.h. Co-authored-by: Cursor <cursoragent@cursor.com> * Apply pre-commit import ordering in dlpack tests. Normalize test_utils import order after pre-commit autofixes so local and CI hook runs stay clean. Co-authored-by: Cursor <cursoragent@cursor.com> * Document explicit stride requirement in DLPack export helper. Clarify why setup_dl_tensor_layout always sets a non-NULL stride pointer for non-scalar exports under DLPack v1.2+ semantics. Co-authored-by: Cursor <cursoragent@cursor.com> * Update cuda_core/cuda/core/_dlpack.pyx Co-authored-by: Leo Fang <leof@nvidia.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Leo Fang <leof@nvidia.com>
1 parent a2cd9e8 commit e923636

File tree

6 files changed

+768
-9
lines changed

6 files changed

+768
-9
lines changed

cuda_core/cuda/core/_dlpack.pxd

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ cdef extern from "_include/dlpack.h" nogil:
2626
_kDLCUDA "kDLCUDA"
2727
_kDLCUDAHost "kDLCUDAHost"
2828
_kDLCUDAManaged "kDLCUDAManaged"
29+
_kDLTrn "kDLTrn"
2930

3031
ctypedef struct DLDevice:
3132
_DLDeviceType device_type
@@ -72,8 +73,52 @@ cdef extern from "_include/dlpack.h" nogil:
7273
int DLPACK_MAJOR_VERSION
7374
int DLPACK_MINOR_VERSION
7475
int DLPACK_FLAG_BITMASK_READ_ONLY
76+
int DLPACK_FLAG_BITMASK_IS_COPIED
77+
int DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED
7578

7679
const char* DLPACK_TENSOR_UNUSED_NAME
7780
const char* DLPACK_VERSIONED_TENSOR_UNUSED_NAME
7881
const char* DLPACK_TENSOR_USED_NAME
7982
const char* DLPACK_VERSIONED_TENSOR_USED_NAME
83+
84+
85+
cdef extern from "_include/dlpack.h":
86+
ctypedef int (*DLPackManagedTensorAllocator)(
87+
DLTensor* prototype,
88+
DLManagedTensorVersioned** out,
89+
void* error_ctx,
90+
void (*SetError)(void* error_ctx, const char* kind, const char* message) noexcept
91+
)
92+
93+
ctypedef int (*DLPackManagedTensorFromPyObjectNoSync)(
94+
void* py_object,
95+
DLManagedTensorVersioned** out
96+
)
97+
98+
ctypedef int (*DLPackManagedTensorToPyObjectNoSync)(
99+
DLManagedTensorVersioned* tensor,
100+
void** out_py_object
101+
)
102+
103+
ctypedef int (*DLPackDLTensorFromPyObjectNoSync)(
104+
void* py_object,
105+
DLTensor* out
106+
)
107+
108+
ctypedef int (*DLPackCurrentWorkStream)(
109+
_DLDeviceType device_type,
110+
int32_t device_id,
111+
void** out_current_stream
112+
)
113+
114+
ctypedef struct DLPackExchangeAPIHeader:
115+
DLPackVersion version
116+
DLPackExchangeAPIHeader* prev_api
117+
118+
ctypedef struct DLPackExchangeAPI:
119+
DLPackExchangeAPIHeader header
120+
DLPackManagedTensorAllocator managed_tensor_allocator
121+
DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync
122+
DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync
123+
DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync
124+
DLPackCurrentWorkStream current_work_stream

cuda_core/cuda/core/_dlpack.pyx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,13 @@ cdef inline int setup_dl_tensor_layout(DLTensor* dl_tensor, object buf) except -
7777
dl_tensor.ndim = 1
7878
cdef int64_t* shape_strides = \
7979
<int64_t*>stdlib.malloc(sizeof(int64_t) * 2)
80+
if shape_strides == NULL:
81+
raise MemoryError()
82+
# DLPack v1.2+ requires non-NULL strides for ndim != 0.
8083
shape_strides[0] = <int64_t>buf.size
81-
shape_strides[1] = 1 # redundant
84+
shape_strides[1] = 1
8285
dl_tensor.shape = shape_strides
83-
dl_tensor.strides = NULL
86+
dl_tensor.strides = shape_strides + 1
8487
dl_tensor.byte_offset = 0
8588
return 0
8689

cuda_core/cuda/core/_include/dlpack.h

Lines changed: 208 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#define DLPACK_MAJOR_VERSION 1
2020

2121
/*! \brief The current minor version of dlpack */
22-
#define DLPACK_MINOR_VERSION 1
22+
#define DLPACK_MINOR_VERSION 3
2323

2424
/*! \brief DLPACK_DLL prefix for windows */
2525
#ifdef _WIN32
@@ -118,6 +118,8 @@ typedef enum {
118118
kDLHexagon = 16,
119119
/*! \brief Microsoft MAIA devices */
120120
kDLMAIA = 17,
121+
/*! \brief AWS Trainium */
122+
kDLTrn = 18,
121123
} DLDeviceType;
122124

123125
/*!
@@ -252,11 +254,23 @@ typedef struct {
252254
int32_t ndim;
253255
/*! \brief The data type of the pointer*/
254256
DLDataType dtype;
255-
/*! \brief The shape of the tensor */
257+
/*!
258+
* \brief The shape of the tensor
259+
*
260+
* When ndim == 0, shape can be set to NULL.
261+
*/
256262
int64_t* shape;
257263
/*!
258-
* \brief strides of the tensor (in number of elements, not bytes)
259-
* can be NULL, indicating tensor is compact and row-majored.
264+
* \brief strides of the tensor (in number of elements, not bytes),
265+
* cannot be NULL if ndim != 0; must point to
266+
* an array of ndim elements that specifies the strides,
267+
* so the consumer can always rely on strides[dim] being valid for 0 <= dim < ndim.
268+
*
269+
* When ndim == 0, strides can be set to NULL.
270+
*
271+
* \note Before DLPack v1.2, strides can be NULL to indicate contiguous data.
272+
* This is not allowed in DLPack v1.2 and later. The rationale
273+
* is to simplify the consumer handling.
260274
*/
261275
int64_t* strides;
262276
/*! \brief The offset in bytes to the beginning pointer to data */
@@ -324,7 +338,7 @@ typedef struct DLManagedTensor {
324338
*
325339
* \note This is the current standard DLPack exchange data structure.
326340
*/
327-
struct DLManagedTensorVersioned {
341+
typedef struct DLManagedTensorVersioned {
328342
/*!
329343
* \brief The API and ABI version of the current managed Tensor
330344
*/
@@ -358,7 +372,195 @@ struct DLManagedTensorVersioned {
358372
uint64_t flags;
359373
/*! \brief DLTensor which is being memory managed */
360374
DLTensor dl_tensor;
361-
};
375+
} DLManagedTensorVersioned;
376+
377+
//----------------------------------------------------------------------
378+
// DLPack `__dlpack_c_exchange_api__` fast exchange protocol definitions
379+
//----------------------------------------------------------------------
380+
/*!
381+
* \brief Request a producer library to create a new tensor.
382+
*
383+
* Create a new `DLManagedTensorVersioned` within the context of the producer
384+
* library. The allocation is defined via the prototype DLTensor.
385+
*
386+
* This function is exposed by the framework through the DLPackExchangeAPI.
387+
*
388+
* \param prototype The prototype DLTensor. Only the dtype, ndim, shape,
389+
* and device fields are used.
390+
* \param out The output DLManagedTensorVersioned.
391+
* \param error_ctx Context for `SetError`.
392+
* \param SetError The function to set the error.
393+
* \return 0 on success, -1 on failure. SetError is called exactly when
394+
* -1 is returned (the implementer must ensure this).
395+
* \note - As a C function, must not throw C++ exceptions.
396+
* - Error propagation via SetError to avoid any direct need
397+
* of Python API. Due to this `SetError` may have to ensure the GIL is
398+
* held since it will presumably set a Python error.
399+
*
400+
* \sa DLPackExchangeAPI
401+
*/
402+
typedef int (*DLPackManagedTensorAllocator)(
403+
DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
404+
void (*SetError)(void* error_ctx, const char* kind, const char* message));
405+
406+
/*!
407+
* \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned.
408+
*
409+
* This function does not perform any stream synchronization. The consumer should query
410+
* DLPackCurrentWorkStream to get the current work stream and launch kernels on it.
411+
*
412+
* This function is exposed by the framework through the DLPackExchangeAPI.
413+
*
414+
* \param py_object The Python object to convert. Must have the same type
415+
* as the one the `DLPackExchangeAPI` was discovered from.
416+
* \param out The output DLManagedTensorVersioned.
417+
* \return 0 on success, -1 on failure with a Python exception set.
418+
* If the data cannot be described using DLPack this should be a BufferError if possible.
419+
* \note - As a C function, must not throw C++ exceptions.
420+
*
421+
* \sa DLPackExchangeAPI, DLPackCurrentWorkStream
422+
*/
423+
typedef int (*DLPackManagedTensorFromPyObjectNoSync)(
424+
void* py_object, DLManagedTensorVersioned** out);
425+
426+
/*!
427+
* \brief Exports a PyObject* Tensor/NDArray to a provided DLTensor.
428+
*
429+
* This function provides a faster interface for temporary, non-owning, exchange.
430+
* The producer (implementer) still owns the memory of data, strides, shape.
431+
* The liveness of the DLTensor and the data it views is only guaranteed until
432+
* control is returned.
433+
*
434+
* This function currently assumes that the producer (implementer) can fill
435+
* in the DLTensor shape and strides without the need for temporary allocations.
436+
*
437+
* This function does not perform any stream synchronization. The consumer should query
438+
* DLPackCurrentWorkStream to get the current work stream and launch kernels on it.
439+
*
440+
* This function is exposed by the framework through the DLPackExchangeAPI.
441+
*
442+
* \param py_object The Python object to convert. Must have the same type
443+
* as the one the `DLPackExchangeAPI` was discovered from.
444+
* \param out The output DLTensor, whose space is pre-allocated on stack.
445+
* \return 0 on success, -1 on failure with a Python exception set.
446+
* \note - As a C function, must not throw C++ exceptions.
447+
*
448+
* \sa DLPackExchangeAPI, DLPackCurrentWorkStream
449+
*/
450+
typedef int (*DLPackDLTensorFromPyObjectNoSync)(void* py_object, DLTensor* out);
451+
452+
/*!
453+
* \brief Obtain the current work stream of a device.
454+
*
455+
* Obtain the current work stream of a device from the producer framework.
456+
* For example, it should map to torch.cuda.current_stream in PyTorch.
457+
*
458+
* When device_type is kDLCPU, the consumer does not have to query the stream
459+
* and the producer can simply return NULL when queried.
460+
* The consumer does not have to do anything about stream sync or setting.
461+
* So a CPU-only framework can just provide a dummy implementation that
462+
* always sets out_current_stream[0] to NULL.
463+
*
464+
* \param device_type The device type.
465+
* \param device_id The device id.
466+
* \param out_current_stream The output current work stream.
467+
*
468+
* \return 0 on success, -1 on failure with a Python exception set.
469+
* \note - As a C function, must not throw C++ exceptions.
470+
*
471+
* \sa DLPackExchangeAPI
472+
*/
473+
typedef int (*DLPackCurrentWorkStream)(
474+
DLDeviceType device_type, int32_t device_id, void** out_current_stream);
475+
476+
/*!
477+
* \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray.
478+
*
479+
* Convert an owning DLManagedTensorVersioned* to the Python tensor of the
480+
* producer (implementer) library with the correct type.
481+
*
482+
* This function does not perform any stream synchronization.
483+
*
484+
* This function is exposed by the framework through the DLPackExchangeAPI.
485+
*
486+
* \param tensor The DLManagedTensorVersioned to convert; the ownership of the
487+
* tensor is stolen.
488+
* \param out_py_object The output Python object.
489+
* \return 0 on success, -1 on failure with a Python exception set.
490+
*
491+
* \sa DLPackExchangeAPI
492+
*/
493+
typedef int (*DLPackManagedTensorToPyObjectNoSync)(
494+
DLManagedTensorVersioned* tensor, void** out_py_object);
495+
496+
/*!
497+
* \brief DLPackExchangeAPI stable header.
498+
* \sa DLPackExchangeAPI
499+
*/
500+
typedef struct DLPackExchangeAPIHeader {
501+
/*!
502+
* \brief The provided DLPack version. The consumer must check major version
503+
* compatibility before using this struct.
504+
*/
505+
DLPackVersion version;
506+
/*!
507+
* \brief Optional pointer to an older DLPackExchangeAPI in the chain.
508+
*
509+
* It must be NULL if the framework does not support older versions.
510+
* If the current major version is larger than the one supported by the
511+
* consumer, the consumer may walk this to find an earlier supported version.
512+
*
513+
* \sa DLPackExchangeAPI
514+
*/
515+
struct DLPackExchangeAPIHeader* prev_api;
516+
} DLPackExchangeAPIHeader;
517+
518+
/*!
519+
* \brief Framework-specific function pointers table for DLPack exchange.
520+
*
521+
* Additionally to `__dlpack__()` we define a C function table sharable by
522+
* Python implementations via `__dlpack_c_exchange_api__`.
523+
* This attribute must be set on the type as a Python PyCapsule
524+
* with name "dlpack_exchange_api".
525+
*
526+
* Note that this must be defined on the type. The consumer should look up the
527+
* attribute on the type and may cache the result for each unique type.
528+
*
529+
* Array/Tensor libraries should statically create and initialize this structure
530+
* then return a pointer to DLPackExchangeAPI as an int value in Tensor/Array.
531+
* The DLPackExchangeAPI* must stay alive throughout the lifetime of the process.
532+
*/
533+
typedef struct DLPackExchangeAPI {
534+
/*!
535+
* \brief The header that remains stable across versions.
536+
*/
537+
DLPackExchangeAPIHeader header;
538+
/*!
539+
* \brief Producer function pointer for DLPackManagedTensorAllocator.
540+
* This function must not be NULL.
541+
*/
542+
DLPackManagedTensorAllocator managed_tensor_allocator;
543+
/*!
544+
* \brief Producer function pointer for DLPackManagedTensorFromPyObjectNoSync.
545+
* This function must not be NULL.
546+
*/
547+
DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync;
548+
/*!
549+
* \brief Producer function pointer for DLPackManagedTensorToPyObjectNoSync.
550+
* This function must not be NULL.
551+
*/
552+
DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync;
553+
/*!
554+
* \brief Producer function pointer for DLPackDLTensorFromPyObjectNoSync.
555+
* This function can be NULL when the producer does not support this function.
556+
*/
557+
DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync;
558+
/*!
559+
* \brief Producer function pointer for DLPackCurrentWorkStream.
560+
* This function must not be NULL.
561+
*/
562+
DLPackCurrentWorkStream current_work_stream;
563+
} DLPackExchangeAPI;
362564

363565
#ifdef __cplusplus
364566
} // DLPACK_EXTERN_C

0 commit comments

Comments
 (0)