From 7bc51a1c8f4e49fbaa7ac093bea2e93323f1010d Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Date: Wed, 13 May 2026 15:00:11 +0530 Subject: [PATCH 1/2] Fixing check_shape's else condition Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> --- src/array_api_extra/_lib/_testing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 30e2f1ef..32557a52 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -77,10 +77,11 @@ def _check_ns_shape_dtype( if check_shape: msg = f"shapes do not match: {actual_shape} != f{desired_shape}" assert actual_shape == desired_shape, msg - else: + elif desired.ndim > 0: # Ignore shape, but check flattened size. This is normally done by # np.testing.assert_array_equal etc even when strict=False, but not for # non-materializable arrays. + # This check excludes 0d arrays as they are special case in numpy. actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType] desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType] msg = f"sizes do not match: {actual_size} != f{desired_size}" From 28ffd470d9c100504101e602b95fb09c7ce2de60 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Wed, 13 May 2026 10:55:20 +0100 Subject: [PATCH 2/2] nit --- src/array_api_extra/_lib/_testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 32557a52..a5cd808f 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -81,7 +81,7 @@ def _check_ns_shape_dtype( # Ignore shape, but check flattened size. This is normally done by # np.testing.assert_array_equal etc even when strict=False, but not for # non-materializable arrays. - # This check excludes 0d arrays as they are special case in numpy. + # This check excludes 0d arrays as they are special-cased in NumPy. actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType] desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType] msg = f"sizes do not match: {actual_size} != f{desired_size}"