
Commit 10fcae6

Authored by rparolin, claude, cpcloud, and leofang
Add TMA TensorMapDescriptor support (#1687)
* initial commit
* tma wide
* clean up
* Add comments to prepare_tensor_map_arg explaining allocation and lifetime
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Address Copilot review feedback
  - Remove unused _alloc_device_tensor helper from tests
  - Add test for rank > 5 (6D tensor) to verify upper bound validation
  - Add NULL check for PyMem_Malloc in prepare_tensor_map_arg
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Split TMA example into two focused files. Move the replace_address() demonstration into its own self-contained example (tma_replace_address.py) so each file covers a single concept.
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* pre-commit
* adding stride metadata to gpu allocated memory
* im2col fixes
* Reuse CCCL TMA descriptor construction for tiled TensorMap and keep validated views alive to avoid DLPack-backed pointer lifetime hazards. Add explicit tiled element-stride coverage and acknowledge the DLPack include-layout compatibility follow-up in NVIDIA/cccl#7871.
  Made-with: Cursor
* Skip im2col-wide TensorMap tests when runtime support is unavailable. Probe support in the fixture and skip when cuda.core is built without CUDA 13 im2col-wide support or when the driver/GPU reports CUDA_ERROR_INVALID_VALUE, so unsupported RTXPRO6000 lanes don't block unrelated changes.
  Made-with: Cursor
* Align TensorMap API surface with review feedback and enforce context safety. Expose only TensorMapDescriptor in cuda.core, add StridedMemoryView.as_tensor_map(), remove redundant tensor-map fallback packing, and track/check descriptor context/device compatibility before replacement and kernel launch argument packing.
  Made-with: Cursor
* Restore cu12 feature definitions in cuda_core pixi manifest. Bring back the cu12 feature blocks so pixi can parse the manifest and local test commands no longer fail early with a missing feature error.
  Made-with: Cursor
* Handle TensorMap device validation by DLPack type. Reject CUDA device-local tensors from a different GPU while still allowing CUDA host and managed memory. Add regression tests for descriptor creation, replace_address, and the shared validation helper.
* formatting change
* Update cuda_core/cuda/core/_cpp/tensor_map_cccl.h
  Co-authored-by: Leo Fang <leo80042@gmail.com>
* Update cuda_core/examples/tma_replace_address.py
  Co-authored-by: Leo Fang <leo80042@gmail.com>
* Update cuda_core/cuda/core/__init__.py
  Co-authored-by: Leo Fang <leo80042@gmail.com>
* Align TensorMap creation and launch behavior with the latest review guidance. Keep the public TMA entry point on StridedMemoryView and remove avoidable launch/build overhead so the reviewed API stays smaller without regressing local CUDA builds.
  Made-with: Cursor
* Consolidate the TMA examples around the libcudacxx wrappers. Keep the example surface smaller and closer to CUDA C++ by showing barrier/TMA helpers and replace_address() in one place instead of duplicating raw PTX snippets.
  Made-with: Cursor
* Teach the TMA example where to find libcudacxx headers. Use the toolkit include and optional cccl include roots when compiling the wrapper-based example so NVRTC can resolve cuda/barrier outside the test harness.
  Made-with: Cursor
* Bundle tiled TensorMap options and type retained views. Centralize the tiled descriptor arguments in an options object, keep dtype-like inputs on the public path while using raw driver values internally, and declare StridedMemoryView in a pxd so retained views stay typed without extra helper indirection.
  Made-with: Cursor
* Keep the rebased TensorMap validation helper consistent. Remove the stale Cython-only `_require_view_device` definition left behind while porting the TensorMap fixes onto the PR head branch so the extension builds against the newer managed-memory-aware helper.
  Made-with: Cursor
* Apply the pre-commit fixes for the rebased TensorMap branch. Add the missing SPDX header on the new `_memoryview.pxd` file and keep the test module formatted the way `ruff format` expects so pre-commit.ci can clear on the live PR branch.
  Made-with: Cursor
* Keep the TensorMap multi-GPU tests on the view-based API. Replace the last stale `TensorMapDescriptor.from_tiled()` call sites with `StridedMemoryView.as_tensor_map()` so the multi-device CI coverage exercises the constructor path that actually exists on this branch.
  Made-with: Cursor

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Co-authored-by: Leo Fang <leo80042@gmail.com>
1 parent 43decaa commit 10fcae6

File tree

10 files changed: +2222 −29 lines

cuda_core/cuda/core/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -68,3 +68,4 @@
     Stream,
     StreamOptions,
 )
+from cuda.core._tensor_map import TensorMapDescriptor, TensorMapDescriptorOptions
cuda_core/cuda/core/_cpp/tensor_map_cccl.cpp

Lines changed: 154 additions & 0 deletions (new file)

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// SPDX-License-Identifier: Apache-2.0

#include "tensor_map_cccl.h"

#include <string.h>

#include <algorithm>
#include <exception>

#if defined(__has_include)
// Older CTK releases do not ship <cuda/tma>. When it is unavailable we keep
// the CCCL helper compiled out and fall back to the direct driver path.
#  if __has_include(<cuda/tma>)
#    include <cuda/tma>
#    define CUDA_CORE_HAS_CUDA_TMA 1
#  else
#    define CUDA_CORE_HAS_CUDA_TMA 0
#  endif
#  if __has_include("dlpack.h")
#    include "dlpack.h"
#    define CUDA_CORE_HAS_DLPACK_H 1
#  elif __has_include(<dlpack/dlpack.h>)
#    include <dlpack/dlpack.h>
#    define CUDA_CORE_HAS_DLPACK_H 1
#  else
#    define CUDA_CORE_HAS_DLPACK_H 0
#  endif
#else
#  define CUDA_CORE_HAS_CUDA_TMA 0
#  define CUDA_CORE_HAS_DLPACK_H 0
#endif

static inline void cuda_core_write_err(char* err, size_t cap, const char* msg) noexcept
{
    if (!err || cap == 0)
        return;
    if (!msg)
    {
        err[0] = '\0';
        return;
    }
    size_t n = ::strlen(msg);
    if (n >= cap)
        n = cap - 1;
    ::memcpy(err, msg, n);
    err[n] = '\0';
}

int cuda_core_cccl_make_tma_descriptor_tiled(
    void* out_tensor_map,
    void* data,
    int device_type,
    int device_id,
    int ndim,
    const int64_t* shape,
    const int64_t* strides,
    uint8_t dtype_code,
    uint8_t dtype_bits,
    uint16_t dtype_lanes,
    const int* box_sizes,
    const int* elem_strides,
    int interleave_layout,
    int swizzle,
    int l2_fetch_size,
    int oob_fill,
    char* err,
    size_t err_cap) noexcept
{
#if !(CUDA_CORE_HAS_CUDA_TMA && CUDA_CORE_HAS_DLPACK_H)
    (void)out_tensor_map;
    (void)data;
    (void)device_type;
    (void)device_id;
    (void)ndim;
    (void)shape;
    (void)strides;
    (void)dtype_code;
    (void)dtype_bits;
    (void)dtype_lanes;
    (void)box_sizes;
    (void)elem_strides;
    (void)interleave_layout;
    (void)swizzle;
    (void)l2_fetch_size;
    (void)oob_fill;
    cuda_core_write_err(err, err_cap, "CCCL <cuda/tma> and/or <dlpack/dlpack.h> not available at build time");
    return 1;
#else
    try
    {
        if (!out_tensor_map)
        {
            cuda_core_write_err(err, err_cap, "out_tensor_map is NULL");
            return 1;
        }
        if (!data)
        {
            cuda_core_write_err(err, err_cap, "tensor data pointer is NULL");
            return 1;
        }
        if (!shape || !box_sizes || ndim <= 0)
        {
            cuda_core_write_err(err, err_cap, "invalid rank/shape/box_sizes");
            return 1;
        }

        DLTensor t{};
        t.data = data;
        t.device = {static_cast<DLDeviceType>(device_type), device_id};
        t.ndim = ndim;
        t.dtype.code = dtype_code;
        t.dtype.bits = dtype_bits;
        t.dtype.lanes = dtype_lanes;
        // CCCL promises not to mutate the arrays, but DLPack uses non-const pointers.
        t.shape = const_cast<int64_t*>(shape);
        t.strides = const_cast<int64_t*>(strides);
        t.byte_offset = 0;

        const auto layout = static_cast<cuda::tma_interleave_layout>(interleave_layout);
        const auto swz = static_cast<cuda::tma_swizzle>(swizzle);
        const auto l2 = static_cast<cuda::tma_l2_fetch_size>(l2_fetch_size);
        const auto oob = static_cast<cuda::tma_oob_fill>(oob_fill);

        auto box = cuda::std::span<const int>(box_sizes, static_cast<size_t>(ndim));

        CUtensorMap desc{};
        if (elem_strides)
        {
            auto es = cuda::std::span<const int>(elem_strides, static_cast<size_t>(ndim));
            desc = cuda::make_tma_descriptor(t, box, es, layout, swz, l2, oob);
        }
        else
        {
            desc = cuda::make_tma_descriptor(t, box, layout, swz, l2, oob);
        }

        ::memcpy(out_tensor_map, &desc, sizeof(CUtensorMap));
        cuda_core_write_err(err, err_cap, nullptr);
        return 0;
    }
    catch (const std::exception& e)
    {
        cuda_core_write_err(err, err_cap, e.what());
        return 1;
    }
    catch (...)
    {
        cuda_core_write_err(err, err_cap, "unknown error while building TMA descriptor");
        return 1;
    }
#endif
}
cuda_core/cuda/core/_cpp/tensor_map_cccl.h

Lines changed: 45 additions & 0 deletions (new file)

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// SPDX-License-Identifier: Apache-2.0

#ifndef CUDA_CORE_TENSOR_MAP_CCCL_H_
#define CUDA_CORE_TENSOR_MAP_CCCL_H_

#ifdef __cplusplus
#include <cstddef>
#include <cstdint>
extern "C" {
#else
#include <stddef.h>
#include <stdint.h>
#endif

// Build a tiled CUtensorMap using CCCL's cuda::make_tma_descriptor (from <cuda/tma>).
//
// Returns 0 on success; on failure returns non-zero and writes a best-effort
// human-readable message into (err, err_cap) if provided.
int cuda_core_cccl_make_tma_descriptor_tiled(
    void* out_tensor_map,
    void* data,
    int device_type,
    int device_id,
    int ndim,
    const int64_t* shape,     // length ndim
    const int64_t* strides,   // length ndim, or NULL for contiguous
    uint8_t dtype_code,
    uint8_t dtype_bits,
    uint16_t dtype_lanes,
    const int* box_sizes,     // length ndim
    const int* elem_strides,  // length ndim, or NULL for all-ones overload
    int interleave_layout,
    int swizzle,
    int l2_fetch_size,
    int oob_fill,
    char* err,
    size_t err_cap) noexcept;

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // CUDA_CORE_TENSOR_MAP_CCCL_H_

cuda_core/cuda/core/_kernel_arg_handler.pyx

Lines changed: 19 additions & 0 deletions

@@ -16,6 +16,8 @@ import ctypes
 import numpy

 from cuda.core._memory import Buffer
+from cuda.core._tensor_map import TensorMapDescriptor as _TensorMapDescriptor_py
+from cuda.core._tensor_map cimport TensorMapDescriptor
 from cuda.core._utils.cuda_utils import driver
 from cuda.bindings cimport cydriver

@@ -97,6 +99,9 @@ cdef object numpy_complex64 = numpy.complex64
 cdef object numpy_complex128 = numpy.complex128


+cdef object tensor_map_descriptor_type = _TensorMapDescriptor_py
+
+
 # limitation due to cython/cython#534
 ctypedef void* voidptr

@@ -124,6 +129,17 @@ cdef inline int prepare_arg(
     return 0


+cdef inline int prepare_tensor_map_arg(
+    vector.vector[void*]& data,
+    vector.vector[void*]& data_addresses,
+    TensorMapDescriptor arg,
+    const size_t idx) except -1:
+    # cuLaunchKernel copies argument bytes during launch, so a TensorMap
+    # descriptor can point directly at its internal CUtensorMap storage.
+    data_addresses[idx] = arg._get_data_ptr()
+    return 0
+
+
 cdef inline int prepare_ctypes_arg(
     vector.vector[void*]& data,
     vector.vector[void*]& data_addresses,
@@ -290,6 +306,9 @@ cdef class ParamHolder:
             elif arg_type is complex:
                 prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
                 continue
+            elif arg_type is tensor_map_descriptor_type:
+                prepare_tensor_map_arg(self.data, self.data_addresses, <TensorMapDescriptor>arg, i)
+                continue

             not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
             if not_prepared:
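The handler above fills a `void**` array with one address per kernel argument; the driver dereferences each entry and copies the argument bytes at launch, which is why the TensorMap entry can point straight at the descriptor's internal storage. A hedged, driver-free sketch of that packing shape with `ctypes` (the helper name and the int/float handling are illustrative, not the library's):

```python
import ctypes

def pack_kernel_args(args):
    """Build the void** address array a cuLaunchKernel-style launch expects:
    one pointer per argument, each pointing at that argument's live storage."""
    storage = []  # keep argument storage alive until the launch returns
    addresses = (ctypes.c_void_p * len(args))()
    for i, a in enumerate(args):
        # Toy marshalling: ints -> c_int, floats -> c_double.
        buf = ctypes.c_int(a) if isinstance(a, int) else ctypes.c_double(a)
        storage.append(buf)
        addresses[i] = ctypes.addressof(buf)
    return storage, addresses

storage, addrs = pack_kernel_args([3, 2.5])
# Each entry is an address of live storage holding the argument bytes.
first = ctypes.cast(ctypes.c_void_p(addrs[0]), ctypes.POINTER(ctypes.c_int))
assert first.contents.value == 3
```

Because the launch copies the pointed-to bytes, the only lifetime requirement is that `storage` (or, in the real handler, the descriptor's `CUtensorMap` field) outlives the launch call itself.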
cuda_core/cuda/core/_memoryview.pxd

Lines changed: 28 additions & 0 deletions (new file)

# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from libc.stdint cimport intptr_t

from cuda.core._dlpack cimport DLTensor
from cuda.core._layout cimport _StridedLayout


cdef class StridedMemoryView:
    cdef readonly:
        intptr_t ptr
        int device_id
        bint is_device_accessible
        bint readonly
        object exporting_obj

    cdef:
        object metadata
        DLTensor* dl_tensor
        _StridedLayout _layout
        object _buffer
        object _dtype

    cdef inline _StridedLayout get_layout(self)
    cdef inline object get_buffer(self)
    cdef inline object get_dtype(self)

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 38 additions & 29 deletions

@@ -107,35 +107,6 @@ cdef class StridedMemoryView:
     it will be the Buffer instance passed to the method.

     """
-    cdef readonly:
-        intptr_t ptr
-        int device_id
-        bint is_device_accessible
-        bint readonly
-        object exporting_obj
-
-    cdef:
-        # If using dlpack, this is a strong reference to the result of
-        # obj.__dlpack__() so we can lazily create shape and strides from
-        # it later. If using CAI, this is a reference to the source
-        # `__cuda_array_interface__` object.
-        object metadata
-
-        # The tensor object if has obj has __dlpack__, otherwise must be NULL
-        DLTensor *dl_tensor
-
-        # Memoized properties
-        # Either lazily inferred from dl_tensor/metadata,
-        # or explicitly provided if created with from_buffer().
-        _StridedLayout _layout
-        # Either exporting_obj if it is a Buffer, otherwise a Buffer instance
-        # with owner set to the exporting object.
-        object _buffer
-        # Either lazily inferred from dl_tensor/metadata,
-        # or explicitly provided if created with from_buffer().
-        # In the latter case, it can be None.
-        object _dtype
-
     def __init__(self, obj: object = None, stream_ptr: int | None = None) -> None:
         cdef str clsname = self.__class__.__name__
         if obj is not None:
@@ -316,6 +287,44 @@ cdef class StridedMemoryView:
         view_buffer_strided(view, self.get_buffer(), layout, dtype, self.readonly)
         return view

+    def as_tensor_map(
+        self,
+        box_dim=None,
+        *,
+        options=None,
+        element_strides=None,
+        data_type=None,
+        interleave=None,
+        swizzle=None,
+        l2_promotion=None,
+        oob_fill=None,
+    ):
+        """Create a tiled :obj:`TensorMapDescriptor` from this view.
+
+        This is the public entry point for creating tiled tensor map
+        descriptors in ``cuda.core``. Pass either ``box_dim`` and the
+        individual keyword arguments directly, or provide bundled tiled
+        options via ``options=``.
+        """
+        from cuda.core._tensor_map import TensorMapDescriptor
+
+        kwargs = {}
+        if options is not None:
+            kwargs["options"] = options
+        if element_strides is not None:
+            kwargs["element_strides"] = element_strides
+        if data_type is not None:
+            kwargs["data_type"] = data_type
+        if interleave is not None:
+            kwargs["interleave"] = interleave
+        if swizzle is not None:
+            kwargs["swizzle"] = swizzle
+        if l2_promotion is not None:
+            kwargs["l2_promotion"] = l2_promotion
+        if oob_fill is not None:
+            kwargs["oob_fill"] = oob_fill
+        return TensorMapDescriptor._from_tiled(self, box_dim, **kwargs)
+
     def copy_from(
         self, other : StridedMemoryView, stream : Stream,
         allocator = None,
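`as_tensor_map()` forwards only the keywords the caller explicitly supplied, so anything left unset falls through to the defaults owned by `TensorMapDescriptor._from_tiled()`. The pattern, reduced to a runnable sketch (the `_from_tiled` stand-in and its defaults here are hypothetical, not cuda.core's):

```python
def _from_tiled(view, box_dim, *, swizzle="none", oob_fill="zero"):
    # Hypothetical stand-in: in the library, the real defaults live on
    # the descriptor type, not on the view method.
    return {"view": view, "box_dim": box_dim,
            "swizzle": swizzle, "oob_fill": oob_fill}

def as_tensor_map(view, box_dim=None, *, swizzle=None, oob_fill=None):
    # Forward only keywords the caller actually set; unset ones fall
    # through to _from_tiled's own defaults instead of passing None.
    kwargs = {}
    if swizzle is not None:
        kwargs["swizzle"] = swizzle
    if oob_fill is not None:
        kwargs["oob_fill"] = oob_fill
    return _from_tiled(view, box_dim, **kwargs)

desc = as_tensor_map("view", (32, 32), swizzle="128B")
# oob_fill was not supplied, so the callee's default applies.
```

This keeps the public method's signature stable while letting the descriptor type change its defaults without touching `StridedMemoryView`.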
cuda_core/cuda/core/_tensor_map.pxd

Lines changed: 19 additions & 0 deletions (new file)

# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings cimport cydriver
from libc.stdint cimport intptr_t
from cuda.core._memoryview cimport StridedMemoryView


cdef class TensorMapDescriptor:
    cdef cydriver.CUtensorMap _tensor_map
    cdef int _device_id
    cdef intptr_t _context
    cdef object _source_ref
    cdef StridedMemoryView _view_ref
    cdef object _repr_info

    cdef int _check_context_compat(self) except -1
    cdef void* _get_data_ptr(self)
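The pxd records the device and context the descriptor was built against so `_check_context_compat` can reject cross-context use before `replace_address` or kernel-argument packing, per the commit message. A minimal sketch of that guard, assuming a caller-supplied current device/context pair (the method and error text here are illustrative, not cuda.core's):

```python
class TensorMapDescriptorSketch:
    """Toy model of the context-compat bookkeeping in _tensor_map.pxd."""

    def __init__(self, device_id: int, context: int):
        # Record where the descriptor was built; it is only valid there.
        self._device_id = device_id
        self._context = context

    def check_context_compat(self, current_device: int, current_context: int) -> None:
        # Mirrors the intent of _check_context_compat: raise before the
        # descriptor is used for replacement or launch-argument packing.
        if (current_device, current_context) != (self._device_id, self._context):
            raise RuntimeError(
                "TensorMapDescriptor used outside its creating device/context"
            )

d = TensorMapDescriptorSketch(device_id=0, context=0x1234)
d.check_context_compat(0, 0x1234)  # same device/context: passes silently
```

Checking eagerly at the use sites turns a driver-level `CUDA_ERROR_INVALID_VALUE` (or silent misbehavior) into a clear Python-side error.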
