From 3cbca39934c4982596eb047c1a6baa1b15404b1e Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 5 May 2026 01:08:59 -0700 Subject: [PATCH] Fix indexing bug in slice update with op --- mlx/backend/metal/kernels/indexing/scatter.h | 4 +++- python/tests/test_array.py | 24 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/indexing/scatter.h b/mlx/backend/metal/kernels/indexing/scatter.h index b7f304c9ad..3604f20f61 100644 --- a/mlx/backend/metal/kernels/indexing/scatter.h +++ b/mlx/backend/metal/kernels/indexing/scatter.h @@ -80,7 +80,9 @@ template < uint3 gsize [[threads_per_grid]]) { Op op; - IdxT idx = IdxT(gid.z) * gsize.y + gid.y * gsize.x + gid.x * NWORK; + IdxT idx = + (IdxT(gid.z) * IdxT(gsize.y) + IdxT(gid.y)) * IdxT(gsize.x) * NWORK + + IdxT(gid.x) * NWORK; IdxT out_idx; IdxT update_idx; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index a35460c922..c133659c83 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1547,6 +1547,30 @@ def test_array_at_slice_update_extensive(self): a = a.at[:, :, :].add(update) self.assertEqualArray(a, update) + def test_slice_update_contiguous_2d(self): + for shape in [(32, 32), (64, 64), (17, 33), (128, 128)]: + for upd_shape in [(16, 16), (8, 8), (4, 4)]: + if upd_shape[0] > shape[0] or upd_shape[1] > shape[1]: + continue + y = mx.zeros(shape) + x = mx.random.normal(upd_shape) + z = y.at[: upd_shape[0], : upd_shape[1]].add(x) + diff = z[: upd_shape[0], : upd_shape[1]] - x + self.assertTrue(mx.allclose(diff, mx.zeros_like(diff))) + + # Test non-zero offset slice update + y = mx.zeros((32, 32)) + x = mx.random.normal((16, 16)) + z = y.at[16:, 16:].add(x) + self.assertTrue(mx.allclose(z[16:, 16:] - x, mx.zeros_like(x))) + + # Test with size divisible by 4, 2, and odd + for cols in [32, 18, 15]: + y = mx.zeros((32, cols)) + x = mx.random.normal((16, cols)) + z = y.at[:16, :].add(x) + self.assertTrue(mx.allclose(z[:16, :] - x, mx.zeros_like(x))) + def test_slice_negative_step(self): a_np = np.arange(20) a_mx = mx.array(a_np)