Skip to content

Commit 3695783

Browse files
authored
Fail fast on unsupported StridedMemoryView strides (#1730)
* Fail fast on unsupported StridedMemoryView strides. Validate CAI and array-interface layout during StridedMemoryView construction so non-itemsize-divisible strides fail immediately. Add regression coverage to assert constructor-time failures for both interfaces. Made-with: Cursor * Expand CAI stride edge-case coverage for StridedMemoryView. Add explicit CAI tests for zero, empty-array, and negative-stride layouts so stride divisibility validation remains covered for uncommon but valid cases. Require `shape` and `strides` as keyword-only args in the synthetic CAI helper to keep test call sites unambiguous. Made-with: Cursor
1 parent 09069a3 commit 3695783

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,8 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
10961096
buf.exporting_obj = obj
10971097
buf.metadata = cai_data
10981098
buf.dl_tensor = NULL
1099+
# Validate shape/strides/typestr eagerly so constructor paths fail fast.
1100+
buf.get_layout()
10991101
buf.ptr, buf.readonly = cai_data["data"]
11001102
buf.is_device_accessible = True
11011103
if buf.ptr != 0:
@@ -1138,6 +1140,8 @@ cpdef StridedMemoryView view_as_array_interface(obj, view=None):
11381140
buf.exporting_obj = obj
11391141
buf.metadata = data
11401142
buf.dl_tensor = NULL
1143+
# Validate shape/strides/typestr eagerly so constructor paths fail fast.
1144+
buf.get_layout()
11411145
buf.ptr, buf.readonly = data["data"]
11421146
buf.is_device_accessible = False
11431147
buf.device_id = handle_return(driver.cuCtxGetDevice())

cuda_core/tests/test_utils.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,10 +582,53 @@ def test_from_array_interface_unsupported_strides(init_cuda):
582582
# Create an array with strides that aren't a multiple of itemsize
583583
x = np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")])
584584
b = x["b"]
585-
smv = StridedMemoryView.from_array_interface(b)
586585
with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
587-
# TODO: ideally this would raise on construction
588-
smv.strides # noqa: B018
586+
StridedMemoryView.from_array_interface(b)
587+
588+
589+
def _make_cuda_array_interface_obj(*, shape, strides, typestr="<f8", data=(0, False), version=3):
590+
return type(
591+
"SyntheticCAI",
592+
(),
593+
{
594+
"__cuda_array_interface__": {
595+
"shape": shape,
596+
"strides": strides,
597+
"typestr": typestr,
598+
"data": data,
599+
"version": version,
600+
}
601+
},
602+
)()
603+
604+
605+
def test_from_cuda_array_interface_unsupported_strides(init_cuda):
606+
cai_obj = _make_cuda_array_interface_obj(shape=(2,), strides=(10,))
607+
with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
608+
StridedMemoryView.from_cuda_array_interface(cai_obj, stream_ptr=-1)
609+
610+
611+
def test_from_cuda_array_interface_zero_strides(init_cuda):
612+
cai_obj = _make_cuda_array_interface_obj(shape=(1, 1), strides=(0, 0))
613+
smv = StridedMemoryView.from_cuda_array_interface(cai_obj, stream_ptr=-1)
614+
assert smv.shape == (1, 1)
615+
assert smv.strides == (0, 0)
616+
617+
618+
@pytest.mark.skipif(cp is None, reason="CuPy is not installed")
619+
def test_from_cuda_array_interface_negative_strides(init_cuda):
620+
x = cp.arange(4, dtype=cp.float64)[::-1]
621+
smv = StridedMemoryView.from_cuda_array_interface(_EnforceCAIView(x), stream_ptr=-1)
622+
assert smv.shape == x.shape
623+
assert smv.strides == (-1,)
624+
625+
626+
def test_from_cuda_array_interface_empty_array(init_cuda):
627+
cai_obj = _make_cuda_array_interface_obj(shape=(0, 3), strides=(24, 8))
628+
smv = StridedMemoryView.from_cuda_array_interface(cai_obj, stream_ptr=-1)
629+
assert smv.size == 0
630+
assert smv.shape == (0, 3)
631+
assert smv.strides == (3, 1)
589632

590633

591634
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)