Skip to content

Commit 8df301e

Browse files
committed
Fail fast on unsupported StridedMemoryView strides.
Validate CAI and array-interface layout during StridedMemoryView construction so non-itemsize-divisible strides fail immediately. Add regression coverage to assert constructor-time failures for both interfaces. Made-with: Cursor
1 parent e46fcac commit 8df301e

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,8 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
10961096
buf.exporting_obj = obj
10971097
buf.metadata = cai_data
10981098
buf.dl_tensor = NULL
1099+
# Validate shape/strides/typestr eagerly so constructor paths fail fast.
1100+
buf.get_layout()
10991101
buf.ptr, buf.readonly = cai_data["data"]
11001102
buf.is_device_accessible = True
11011103
if buf.ptr != 0:
@@ -1138,6 +1140,8 @@ cpdef StridedMemoryView view_as_array_interface(obj, view=None):
11381140
buf.exporting_obj = obj
11391141
buf.metadata = data
11401142
buf.dl_tensor = NULL
1143+
# Validate shape/strides/typestr eagerly so constructor paths fail fast.
1144+
buf.get_layout()
11411145
buf.ptr, buf.readonly = data["data"]
11421146
buf.is_device_accessible = False
11431147
buf.device_id = handle_return(driver.cuCtxGetDevice())

cuda_core/tests/test_utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,10 +582,26 @@ def test_from_array_interface_unsupported_strides(init_cuda):
582582
# Create an array with strides that aren't a multiple of itemsize
583583
x = np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")])
584584
b = x["b"]
585-
smv = StridedMemoryView.from_array_interface(b)
586585
with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
587-
# TODO: ideally this would raise on construction
588-
smv.strides # noqa: B018
586+
StridedMemoryView.from_array_interface(b)
587+
588+
589+
def test_from_cuda_array_interface_unsupported_strides(init_cuda):
590+
cai_obj = type(
591+
"UnsupportedStridesCAI",
592+
(),
593+
{
594+
"__cuda_array_interface__": {
595+
"shape": (2,),
596+
"strides": (10,),
597+
"typestr": "<f8",
598+
"data": (0, False),
599+
"version": 3,
600+
}
601+
},
602+
)()
603+
with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
604+
StridedMemoryView.from_cuda_array_interface(cai_obj, stream_ptr=-1)
589605

590606

591607
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)