From 5d0a185001b8ff48c2afe6b3153a0b508601b782 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 9 May 2026 16:19:37 +0200 Subject: [PATCH] ENH: make __setitem__ raise on a device mismatch --- array_api_strict/_array_object.py | 4 ++++ array_api_strict/tests/test_array_object.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 629af98..fab8811 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -960,6 +960,10 @@ def __setitem__( other = value if isinstance(value, (bool, int, float, complex)): other = self._promote_scalar(value) + else: + if value.device != self.device: + raise ValueError(f"mismatched devices: {self.device = } != {value.device =}.") + dt = _result_type(self.dtype, other.dtype) if dt != self.dtype: raise TypeError(f"mismatched dtypes: {self.dtype = } and {other.dtype = }") diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index f580585..46d8029 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -220,6 +220,14 @@ def test_setitem_invalid_promotions(): a[0] = asarray(3.5j, dtype=complex128) +def test_setitem_device_transfer(): + a = arange(3) + b = arange(4, 1, -1, device=Device('device1')) + + with pytest.raises(ValueError): + a[:] = b[:] + + def test_promoted_scalar_inherits_device(): device1 = Device("device1") x = asarray([1., 2, 3], device=device1)