From 3c6e454a1a03624bd87690e7f05e5fd3cc4a4c5c Mon Sep 17 00:00:00 2001
From: David Gornshtein
Date: Tue, 5 May 2026 23:45:16 +0200
Subject: [PATCH] Fix mx.array DLPack dispatch

---
 python/src/CMakeLists.txt               |   2 +
 python/src/convert.cpp                  |  15 +-
 python/src/dlpack_consumer.cpp          | 302 ++++++++++++++++++++++++
 python/src/dlpack_consumer.h            |  71 ++++++
 python/src/dlpack_consumer_metal.cpp    |  65 ++++++
 python/src/dlpack_consumer_no_metal.cpp |  13 ++
 python/src/dlpack_format.h              |  39 ++++
 python/tests/test_dlpack_consumer.py    | 168 +++++++++++++
 8 files changed, 674 insertions(+), 1 deletion(-)
 create mode 100644 python/src/dlpack_consumer.cpp
 create mode 100644 python/src/dlpack_consumer.h
 create mode 100644 python/src/dlpack_consumer_metal.cpp
 create mode 100644 python/src/dlpack_consumer_no_metal.cpp
 create mode 100644 python/src/dlpack_format.h
 create mode 100644 python/tests/test_dlpack_consumer.py

diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt
index 447271500b..58cd31c0c7 100644
--- a/python/src/CMakeLists.txt
+++ b/python/src/CMakeLists.txt
@@ -10,6 +10,8 @@ nanobind_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/dlpack_consumer.cpp
+  $<IF:$<BOOL:${MLX_BUILD_METAL}>,${CMAKE_CURRENT_SOURCE_DIR}/dlpack_consumer_metal.cpp,${CMAKE_CURRENT_SOURCE_DIR}/dlpack_consumer_no_metal.cpp>
   ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
diff --git a/python/src/convert.cpp b/python/src/convert.cpp
index 471ff24a99..bef773db0b 100644
--- a/python/src/convert.cpp
+++ b/python/src/convert.cpp
@@ -6,6 +6,7 @@
 #include <numeric>
 
 #include "python/src/convert.h"
+#include "python/src/dlpack_consumer.h"
 #include "python/src/utils.h"
 
 #include "mlx/utils.h"
@@ -494,7 +495,13 @@ mx::array create_array(nb::object v, std::optional<mx::Dtype> t) {
   } else if (nb::isinstance<mx::array>(v)) {
     auto arr = nb::cast<mx::array>(v);
     return mx::astype(arr, t.value_or(arr.dtype()));
-  } else if (nb::ndarray_check(v)) {
+  }
+
+  const bool has_mlx_array = nb::hasattr(v, "__mlx_array__");
+  const bool is_dlpack =
+      PyCapsule_CheckExact(v.ptr()) || nb::hasattr(v, "__dlpack__");
+
+  if (!has_mlx_array && nb::ndarray_check(v)) {
     using ContigArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
     ContigArray nd;
     std::optional<nb::dlpack::dtype> nb_dtype;
@@ -507,6 +514,12 @@ mx::array create_array(nb::object v, std::optional<mx::Dtype> t) {
       nd = nb::cast<ContigArray>(v);
     }
     return nd_array_to_mlx(nd, t, nb_dtype);
+  } else if (has_mlx_array) {
+    auto arr = nb::cast<mx::array>(v.attr("__mlx_array__")());
+    return mx::astype(arr, t.value_or(arr.dtype()));
+  } else if (is_dlpack) {
+    auto arr = dlpack_to_mlx(v);
+    return mx::astype(arr, t.value_or(arr.dtype()));
   } else {
     auto arr = to_array_with_accessor(v);
     return mx::astype(arr, t.value_or(arr.dtype()));
diff --git a/python/src/dlpack_consumer.cpp b/python/src/dlpack_consumer.cpp
new file mode 100644
index 0000000000..63aa5f2222
--- /dev/null
+++ b/python/src/dlpack_consumer.cpp
@@ -0,0 +1,302 @@
+// Copyright © 2026 Apple Inc.
+
+#include "python/src/dlpack_consumer.h"
+
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+
+#include "mlx/allocator.h"
+#include "mlx/dtype.h"
+#include "python/src/convert.h"
+#include "python/src/dlpack_format.h"
+
+mx::Dtype dlpack_to_mlx_dtype(const nb::dlpack::dtype& dt) {
+  if (dt.lanes != 1) {
+    std::ostringstream msg;
+    msg << "[array] DLPack tensors with lanes != 1 are not supported "
+        << "(got lanes=" << dt.lanes << ").";
+    throw std::invalid_argument(msg.str());
+  }
+  using Code = nb::dlpack::dtype_code;
+  switch (static_cast<Code>(dt.code)) {
+    case Code::Bool:
+      if (dt.bits == 8)
+        return mx::bool_;
+      break;
+    case Code::Int:
+      switch (dt.bits) {
+        case 8:
+          return mx::int8;
+        case 16:
+          return mx::int16;
+        case 32:
+          return mx::int32;
+        case 64:
+          return mx::int64;
+      }
+      break;
+    case Code::UInt:
+      switch (dt.bits) {
+        case 8:
+          return mx::uint8;
+        case 16:
+          return mx::uint16;
+        case 32:
+          return mx::uint32;
+        case 64:
+          return mx::uint64;
+      }
+      break;
+    case Code::Float:
+      switch (dt.bits) {
+        case 16:
+          return mx::float16;
+        case 32:
+          return mx::float32;
+        case 64:
+          return mx::float64;
+      }
+      break;
+    case Code::Bfloat:
+      if (dt.bits == 16)
+        return mx::bfloat16;
+      break;
+    case Code::Complex:
+      if (dt.bits == 64)
+        return mx::complex64;
+      break;
+    default:
+      break;
+  }
+  std::ostringstream msg;
+  msg << "[array] Unsupported DLPack dtype: code=" << int(dt.code)
+      << ", bits=" << int(dt.bits) << ".";
+  throw std::invalid_argument(msg.str());
+}
+
+mx::Shape validate_and_extract_shape(const nb::dlpack::dltensor& t) {
+  if (t.ndim < 0) {
+    throw std::invalid_argument("[array] ndim must be non-negative.");
+  }
+  if (t.ndim > 0 && t.shape == nullptr) {
+    throw std::invalid_argument(
+        "[array] shape must not be null when ndim > 0.");
+  }
+  mx::Shape shape;
+  shape.reserve(t.ndim);
+  for (int i = 0; i < t.ndim; ++i) {
+    if (t.shape[i] < 0) {
+      throw std::invalid_argument("[array] shape dims must be non-negative.");
+    }
+    if (t.shape[i] > std::numeric_limits<int32_t>::max()) {
+      throw std::invalid_argument("[array] shape dim exceeds int32 range.");
+    }
+    shape.push_back(static_cast<int32_t>(t.shape[i]));
+  }
+  return shape;
+}
+
+bool is_row_contiguous(const mx::Shape& shape, const int64_t* strides) {
+  if (strides == nullptr) {
+    return true;
+  }
+  int64_t expected = 1;
+  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
+    if (strides[i] != expected) {
+      return false;
+    }
+    if (shape[i] != 0 &&
+        expected > std::numeric_limits<int64_t>::max() / shape[i]) {
+      return false;
+    }
+    expected *= shape[i];
+  }
+  return true;
+}
+
+size_t checked_num_bytes(const mx::Shape& shape, mx::Dtype dtype) {
+  size_t nelems = 1;
+  for (auto dim : shape) {
+    if (dim != 0 &&
+        nelems >
+            std::numeric_limits<size_t>::max() / static_cast<size_t>(dim)) {
+      throw std::invalid_argument(
+          "[array] shape element count overflows size_t.");
+    }
+    nelems *= static_cast<size_t>(dim);
+  }
+  if (dtype.size() != 0 &&
+      nelems > std::numeric_limits<size_t>::max() / dtype.size()) {
+    throw std::invalid_argument("[array] tensor byte size overflows.");
+  }
+  return nelems * dtype.size();
+}
+
+namespace {
+
+struct ParsedCapsule {
+  PyObject* capsule = nullptr;
+  bool versioned = false;
+  nb::dlpack::dltensor* tensor = nullptr;
+  void* managed = nullptr; // typed by `versioned`
+};
+
+ParsedCapsule parse_capsule(PyObject* obj) {
+  ParsedCapsule out;
+  if (PyCapsule_IsValid(obj, "dltensor_versioned")) {
+    out.versioned = true;
+    auto* m = static_cast<dlpack_format::DLManagedTensorVersioned*>(
+        PyCapsule_GetPointer(obj, "dltensor_versioned"));
+    if (m == nullptr) {
+      throw std::invalid_argument(
"[array] dltensor_versioned capsule is null."); + } + out.managed = static_cast(m); + out.tensor = &m->dl_tensor; + } else if (PyCapsule_IsValid(obj, "dltensor")) { + out.versioned = false; + auto* m = static_cast( + PyCapsule_GetPointer(obj, "dltensor")); + if (m == nullptr) { + throw std::invalid_argument("[array] dltensor capsule is null."); + } + out.managed = static_cast(m); + out.tensor = &m->dl_tensor; + } else { + throw std::invalid_argument( + "[array] expected a PyCapsule named 'dltensor' or " + "'dltensor_versioned'."); + } + out.capsule = obj; + return out; +} + +void mark_capsule_consumed(PyObject* capsule, bool versioned) { + const char* used = versioned ? "used_dltensor_versioned" : "used_dltensor"; + if (PyCapsule_SetName(capsule, used) != 0 || + PyCapsule_SetDestructor(capsule, nullptr) != 0) { + PyErr_Clear(); + throw std::runtime_error( + "[array] failed to mark DLPack capsule as consumed."); + } +} + +mx::array build_cpu_array(nb::dlpack::dltensor& t, const mx::Shape& shape) { + if (!is_row_contiguous(shape, t.strides)) { + throw std::invalid_argument( + "[array] non-row-contiguous DLPack strides are not supported " + "for kDLCPU tensors yet."); + } + if (t.byte_offset != 0) { + throw std::invalid_argument( + "[array] kDLCPU capsule with non-zero byte_offset is not " + "supported yet."); + } + auto dtype = dlpack_to_mlx_dtype(t.dtype); + size_t nbytes = checked_num_bytes(shape, dtype); + if (nbytes > 0 && t.data == nullptr) { + throw std::invalid_argument( + "[array] kDLCPU capsule has null data pointer."); + } + + // Allocate a fresh mlx buffer and copy the producer's bytes in. This + // mirrors the semantics of nd_array_to_mlx_contiguous for the kDLCPU + // path. We use the (allocator::Buffer, Shape, Dtype, Deleter) overload to + // get an array whose status() == Status::available immediately. + auto buffer = mx::allocator::malloc(nbytes); + if (nbytes > 0) { + std::memcpy(static_cast(buffer.raw_ptr()), t.data, nbytes); + } + mx::array out(buffer, shape, dtype, mx::allocator::free); + + return out; +} + +} // namespace + +void DLPackOwner::invoke() { + if (!active_ || mt_ == nullptr) + return; + if (versioned_) { + auto* m = static_cast(mt_); + if (m->deleter) + m->deleter(m); + } else { + auto* m = static_cast(mt_); + if (m->deleter) + m->deleter(m); + } + mt_ = nullptr; + active_ = false; +} + +mx::array dlpack_to_mlx(nb::object obj) { + // Accept either: + // * a PyCapsule (raw DLPack output), + // * an object that returns a PyCapsule from __dlpack__(), + // * an object whose __dlpack__() returns *another* object that is itself + // PEP-3118 / DLPack-compliant (e.g. nanobind's nb_ndarray wrapper that + // mlx returns from mx.array.__dlpack__). We unwrap up to N times. 
+  constexpr int kMaxUnwrap = 4;
+  PyObject* raw = obj.ptr();
+  nb::object current = obj; // own a reference for the chain
+
+  for (int i = 0; i < kMaxUnwrap; ++i) {
+    if (PyCapsule_CheckExact(raw)) {
+      break;
+    }
+    if (!nb::hasattr(current, "__dlpack__")) {
+      throw std::invalid_argument(
+          "[array] expected a PyCapsule or an object exposing "
+          "__dlpack__().");
+    }
+    current = current.attr("__dlpack__")();
+    raw = current.ptr();
+  }
+  if (!PyCapsule_CheckExact(raw)) {
+    throw std::invalid_argument(
+        "[array] could not resolve input to a DLPack PyCapsule "
+        "after repeated __dlpack__() calls.");
+  }
+
+  ParsedCapsule p = parse_capsule(raw);
+  auto& t = *p.tensor;
+  auto shape = validate_and_extract_shape(t);
+
+  switch (t.device.device_type) {
+    case dlpack_format::kDLCPU: {
+      auto owner = std::make_shared<DLPackOwner>(p.versioned, p.managed);
+      auto out = build_cpu_array(t, shape);
+      mark_capsule_consumed(p.capsule, p.versioned);
+      owner->activate();
+      owner->invoke();
+      return out;
+    }
+    case dlpack_format::kDLMetal: {
+      auto owner = std::make_shared<DLPackOwner>(p.versioned, p.managed);
+      auto out = build_dlpack_metal_array(t, owner);
+      mark_capsule_consumed(p.capsule, p.versioned);
+      owner->activate();
+      return out;
+    }
+    case dlpack_format::kDLCUDA:
+      throw std::invalid_argument(
+          "[array] kDLCUDA tensors are not supported by MLX. Move the "
+          "tensor to host memory or to a Metal-backed framework first.");
+    default: {
+      std::ostringstream msg;
+      msg << "[array] unsupported DLPack device_type " << t.device.device_type
+          << ".";
+      throw std::invalid_argument(msg.str());
+    }
+  }
+}
diff --git a/python/src/dlpack_consumer.h b/python/src/dlpack_consumer.h
new file mode 100644
index 0000000000..74f3b064dc
--- /dev/null
+++ b/python/src/dlpack_consumer.h
@@ -0,0 +1,71 @@
+// Copyright © 2026 Apple Inc.
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include <nanobind/nanobind.h>
+#include <nanobind/ndarray.h>
+
+#include "mlx/array.h"
+#include "mlx/dtype.h"
+
+namespace mx = mlx::core;
+namespace nb = nanobind;
+
+// Convert a DLPack capsule (or a Python object exposing __dlpack__) into an
+// mx::array.
+//
+// Supported device types:
+//  * kDLCPU (1)   : copies host bytes into a fresh mlx allocation.
+//  * kDLMetal (8) : zero-copy; wraps a foreign MTL::Buffer in shared
+//                   storage mode. Non-shared buffers are rejected.
+//
+// All other device types raise std::invalid_argument.
+//
+// kDLCPU input is copied into a fresh MLX allocation and the capsule deleter
+// is invoked before return. kDLMetal input is wrapped zero-copy, so the
+// returned mx::array keeps the capsule deleter alive until the array and any
+// aliases are destroyed. Rejected capsules are left unconsumed.
+mx::array dlpack_to_mlx(nb::object obj);
+
+mx::Dtype dlpack_to_mlx_dtype(const nb::dlpack::dtype& dt);
+mx::Shape validate_and_extract_shape(const nb::dlpack::dltensor& t);
+bool is_row_contiguous(const mx::Shape& shape, const int64_t* strides);
+size_t checked_num_bytes(const mx::Shape& shape, mx::Dtype dtype);
+
+// A small reference-counted holder that drives the DLPack capsule's deleter
+// exactly once after ownership has been committed. Kept here so the
+// metal-glue translation unit can reach it.
+class DLPackOwner {
+ public:
+  DLPackOwner(bool versioned, void* mt) : versioned_(versioned), mt_(mt) {}
+
+  ~DLPackOwner() {
+    invoke();
+  }
+
+  void activate() {
+    active_ = true;
+  }
+
+  void invoke();
+
+  // Disable copies: only one DLPackOwner may exist per managed tensor.
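+  // (Copying could drive the producer's deleter twice. Declaring the copy
+  // operations also suppresses the implicitly generated move operations.)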
+  DLPackOwner(const DLPackOwner&) = delete;
+  DLPackOwner& operator=(const DLPackOwner&) = delete;
+
+ private:
+  bool versioned_;
+  void* mt_;
+  bool active_ = false;
+};
+
+// Build an mx::array from a DLPack tensor whose data is a foreign MTL::Buffer.
+// Defined in dlpack_consumer_metal.cpp when MLX_BUILD_METAL is on, and in
+// dlpack_consumer_no_metal.cpp otherwise.
+mx::array build_dlpack_metal_array(
+    nb::dlpack::dltensor& t,
+    std::shared_ptr<DLPackOwner> owner);
diff --git a/python/src/dlpack_consumer_metal.cpp b/python/src/dlpack_consumer_metal.cpp
new file mode 100644
index 0000000000..c7e66115cb
--- /dev/null
+++ b/python/src/dlpack_consumer_metal.cpp
@@ -0,0 +1,65 @@
+// Copyright © 2026 Apple Inc.
+
+#include <Metal/Metal.hpp>
+
+#include <memory>
+
+#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/metal/metal.h"
+#include "python/src/dlpack_consumer.h"
+#include "python/src/dlpack_format.h"
+
+mx::array build_dlpack_metal_array(
+    nb::dlpack::dltensor& t,
+    std::shared_ptr<DLPackOwner> owner) {
+  if (!mx::metal::is_available()) {
+    throw std::invalid_argument(
+        "[array] Metal device tensors require an MLX build with Metal "
+        "support enabled and a Metal-capable host.");
+  }
+  if (t.data == nullptr) {
+    throw std::invalid_argument(
+        "[array] kDLMetal capsule has null MTLBuffer pointer.");
+  }
+  // For kDLMetal, DLPack stipulates `data` is an MTL::Buffer*.
+  auto mtl_buffer = static_cast<MTL::Buffer*>(t.data);
+  if (mtl_buffer->storageMode() != MTL::StorageModeShared) {
+    throw std::invalid_argument(
+        "[array] foreign MTLBuffer must use MTLStorageModeShared. MLX "
+        "currently relies on shared-mode buffers for read/write access. "
+        "Allocate the producer-side buffer with MTLResourceStorageModeShared "
+        "before exporting via DLPack.");
+  }
+  if (t.byte_offset != 0) {
+    throw std::invalid_argument(
+        "[array] kDLMetal capsule with non-zero byte_offset is not "
+        "supported yet.");
+  }
+  auto shape = validate_and_extract_shape(t);
+  if (!is_row_contiguous(shape, t.strides)) {
+    throw std::invalid_argument(
+        "[array] non-row-contiguous DLPack strides are not supported. "
+        "Reshape on the producer side before exporting.");
+  }
+
+  auto dtype = dlpack_to_mlx_dtype(t.dtype);
+  size_t nbytes = checked_num_bytes(shape, dtype);
+  if (nbytes > mtl_buffer->length()) {
+    throw std::invalid_argument(
+        "[array] kDLMetal capsule shape/dtype requires more bytes than "
+        "the exported MTLBuffer contains.");
+  }
+
+  // Wrap the foreign MTL::Buffer* directly. The producer retains the
+  // underlying allocation; we drive the capsule's deleter when the wrapping
+  // mx::array (and any aliases) are destroyed.
+  mx::allocator::Buffer wrapped(static_cast<void*>(mtl_buffer));
+  mx::Deleter deleter = [owner](mx::allocator::Buffer) mutable {
+    // Drop our shared_ptr; if this was the last reference, the owner's
+    // destructor invokes the DLPack deleter.
+    owner.reset();
+  };
+
+  return mx::array(wrapped, std::move(shape), dtype, std::move(deleter));
+}
diff --git a/python/src/dlpack_consumer_no_metal.cpp b/python/src/dlpack_consumer_no_metal.cpp
new file mode 100644
index 0000000000..5e49e6be8e
--- /dev/null
+++ b/python/src/dlpack_consumer_no_metal.cpp
@@ -0,0 +1,13 @@
+// Copyright © 2026 Apple Inc.
+
+#include <stdexcept>
+
+#include "python/src/dlpack_consumer.h"
+
+mx::array build_dlpack_metal_array(
+    nb::dlpack::dltensor& /*t*/,
+    std::shared_ptr<DLPackOwner> /*owner*/) {
+  throw std::invalid_argument(
+      "[array] MLX was built without Metal support; cannot consume "
+      "kDLMetal capsules.");
+}
diff --git a/python/src/dlpack_format.h b/python/src/dlpack_format.h
new file mode 100644
index 0000000000..b95f747175
--- /dev/null
+++ b/python/src/dlpack_format.h
@@ -0,0 +1,39 @@
+// Copyright © 2026 Apple Inc.
+#pragma once
+
+#include <cstdint>
+
+#include <nanobind/ndarray.h>
+
+namespace nb = nanobind;
+
+// DLPack ABI structs. We define them locally because nanobind does not expose
+// the DLManagedTensor wrapper, and the patch should not introduce a new
+// third-party dependency. These match the upstream DLPack header at
+// https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h
+namespace dlpack_format {
+
+struct DLManagedTensor {
+  nb::dlpack::dltensor dl_tensor;
+  void* manager_ctx;
+  void (*deleter)(struct DLManagedTensor* self);
+};
+
+struct DLPackVersion {
+  uint32_t major;
+  uint32_t minor;
+};
+
+struct DLManagedTensorVersioned {
+  DLPackVersion version;
+  void* manager_ctx;
+  void (*deleter)(struct DLManagedTensorVersioned* self);
+  uint64_t flags;
+  nb::dlpack::dltensor dl_tensor;
+};
+
+constexpr int32_t kDLCPU = 1;
+constexpr int32_t kDLCUDA = 2;
+constexpr int32_t kDLMetal = 8;
+
+} // namespace dlpack_format
diff --git a/python/tests/test_dlpack_consumer.py b/python/tests/test_dlpack_consumer.py
new file mode 100644
index 0000000000..51ffc8578d
--- /dev/null
+++ b/python/tests/test_dlpack_consumer.py
@@ -0,0 +1,168 @@
+# Copyright © 2026 Apple Inc.
+"""Tests for the DLPack consumer path in ``mx.array``.
+
+These tests cover scenarios that don't depend on a third-party Metal producer:
+
+1. Round trip through NumPy's DLPack exporter (``kDLCPU`` path).
+2. Self round trip via ``mx.array.__dlpack__`` (``kDLMetal`` on Metal hosts,
+   ``kDLCPU`` otherwise).
+3. Negative paths: used capsule, unsupported strided view, etc.
+"""
+
+from __future__ import annotations
+
+import ctypes
+import unittest
+
+import numpy as np
+
+try:
+    import mlx.core as mx
+except ImportError as exc:  # pragma: no cover - import error is environment specific
+    raise unittest.SkipTest(f"mlx.core unavailable: {exc}")
+
+
+class TestArrayDLPackBasic(unittest.TestCase):
+    def test_mx_array_accepts_dlpack_capsule(self):
+        # Pass a raw PyCapsule rather than the producer object.
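+        # NumPy hands out a capsule named "dltensor"; mx.array renames it
+        # to "used_dltensor" once consumed (see TestArrayDLPackErrors).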
+        arr_np = np.arange(8, dtype=np.int32).reshape(2, 4)
+        capsule = arr_np.__dlpack__()
+        arr_mx = mx.array(capsule)
+        self.assertEqual(tuple(arr_mx.shape), (2, 4))
+        self.assertEqual(arr_mx.dtype, mx.int32)
+        self.assertTrue(np.array_equal(np.asarray(arr_mx), arr_np))
+
+    def test_mx_array_accepts_dlpack_producer(self):
+        class DLPackProducer:
+            def __init__(self, array):
+                self.array = array
+
+            def __dlpack__(self):
+                return self.array.__dlpack__()
+
+            def __dlpack_device__(self):
+                return self.array.__dlpack_device__()
+
+        arr_np = np.arange(12, dtype=np.float32).reshape(3, 4)
+        arr_mx = mx.array(DLPackProducer(arr_np))
+        self.assertEqual(tuple(arr_mx.shape), (3, 4))
+        self.assertEqual(arr_mx.dtype, mx.float32)
+        self.assertTrue(np.allclose(np.asarray(arr_mx), arr_np))
+
+    def test_mx_array_accepts_mlx_dlpack_producer(self):
+        class DLPackProducer:
+            def __init__(self, array):
+                self.array = array
+
+            def __dlpack__(self):
+                return self.array.__dlpack__()
+
+            def __dlpack_device__(self):
+                return self.array.__dlpack_device__()
+
+        x = mx.arange(20, dtype=mx.float32).reshape(4, 5)
+        y = mx.array(DLPackProducer(x))
+        self.assertTrue(mx.array_equal(x, y).item())
+
+    def test_mx_array_dlpack_dtype_override(self):
+        arr_np = np.arange(6, dtype=np.int32).reshape(2, 3)
+        arr_mx = mx.array(arr_np.__dlpack__(), dtype=mx.float32)
+        self.assertEqual(arr_mx.dtype, mx.float32)
+        self.assertTrue(np.array_equal(np.asarray(arr_mx), arr_np.astype(np.float32)))
+
+    def test_mx_array_prefers_mlx_array_protocol_over_dlpack(self):
+        class BothProtocols:
+            def __mlx_array__(self):
+                return mx.array([1, 2, 3], dtype=mx.int32)
+
+            def __dlpack__(self):
+                raise AssertionError("__dlpack__ should not be called")
+
+        arr_mx = mx.array(BothProtocols())
+        self.assertEqual(arr_mx.dtype, mx.int32)
+        self.assertTrue(np.array_equal(np.asarray(arr_mx), np.array([1, 2, 3])))
+
+    def test_dtypes(self):
+        cases = [
+            (np.bool_, mx.bool_),
+            (np.int8, mx.int8),
+            (np.int16, mx.int16),
+            (np.int32, mx.int32),
+            (np.int64, mx.int64),
+            (np.uint8, mx.uint8),
+            (np.uint16, mx.uint16),
+            (np.uint32, mx.uint32),
+            (np.uint64, mx.uint64),
+            (np.float16, mx.float16),
+            (np.float32, mx.float32),
+            (np.float64, mx.float64),
+            (np.complex64, mx.complex64),
+        ]
+        for np_dtype, mx_dtype in cases:
+            with self.subTest(np_dtype=np_dtype):
+                arr = np.zeros((2, 3), dtype=np_dtype)
+                if np_dtype is np.bool_:
+                    arr[0, 0] = True
+                else:
+                    arr[0, 0] = 1
+                converted = mx.array(arr.__dlpack__())
+                # The DLPack path keeps the producer's dtype exactly
+                # (e.g. float64 stays float64), per dlpack_to_mlx_dtype.
+                self.assertEqual(converted.dtype, mx_dtype)
+                self.assertEqual(tuple(converted.shape), (2, 3))
+
+
+class TestArrayDLPackErrors(unittest.TestCase):
+    def test_rejects_used_capsule(self):
+        arr_np = np.arange(4, dtype=np.float32)
+        capsule = arr_np.__dlpack__()
+        # First call consumes; second must fail because the capsule was
+        # renamed to "used_dltensor".
+        _ = mx.array(capsule)
+        with self.assertRaises(Exception):
+            mx.array(capsule)
+
+
+class TestArrayDLPackNonContiguous(unittest.TestCase):
+    def test_strided_view_rejected(self):
+        # MLX's first-cut consumer does not support arbitrary DLPack strides.
+        # NumPy exports slices with explicit strides via __dlpack__; other
+        # producers may or may not encode strides depending on contiguity.
+        # We assert that a non-row-contiguous slice is rejected with a clear
+        # error rather than silently misinterpreting the layout.
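+        # For example, big[::2, :] below has element strides (8, 1), whereas
+        # a row-contiguous (2, 4) array would have strides (4, 1).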
+        big = np.arange(16, dtype=np.float32).reshape(4, 4)
+        view = big[::2, :]
+        try:
+            capsule = view.__dlpack__()
+        except (TypeError, BufferError):
+            self.skipTest(
+                "NumPy refused to export a non-contiguous DLPack capsule"
+            )
+        with self.assertRaises(Exception):
+            mx.array(capsule)
+
+    def test_rejected_capsule_is_not_marked_used(self):
+        big = np.arange(16, dtype=np.float32).reshape(4, 4)
+        view = big[::2, :]
+        try:
+            capsule = view.__dlpack__()
+        except (TypeError, BufferError):
+            self.skipTest(
+                "NumPy refused to export a non-contiguous DLPack capsule"
+            )
+
+        with self.assertRaises(Exception):
+            mx.array(capsule)
+
+        get_name = ctypes.pythonapi.PyCapsule_GetName
+        get_name.argtypes = [ctypes.py_object]
+        get_name.restype = ctypes.c_char_p
+        self.assertEqual(get_name(capsule), b"dltensor")
+
+
+if __name__ == "__main__":
+    unittest.main()