From 33b0efe13fbfecef055274175338280ad51c364b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Nov 2025 12:08:33 +0100 Subject: [PATCH 01/30] ENH: test side=left,right in searchsorted --- array_api_tests/test_searching_functions.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 8df475d8..1283d9fa 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -243,7 +243,6 @@ def test_where(shapes, dtypes, data): @pytest.mark.min_version("2023.12") @given(data=st.data()) def test_searchsorted(data): - # TODO: test side="right" # TODO: Allow different dtypes for x1 and x2 _x1 = data.draw( st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True), @@ -262,10 +261,13 @@ def test_searchsorted(data): ), label="x2", ) + kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"]))) - repro_snippet = ph.format_snippet(f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r})") + repro_snippet = ph.format_snippet( + f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw=}" + ) try: - out = xp.searchsorted(x1, x2, sorter=sorter) + out = xp.searchsorted(x1, x2, sorter=sorter, **kw) ph.assert_dtype( "searchsorted", @@ -273,7 +275,8 @@ def test_searchsorted(data): out_dtype=out.dtype, expected=xp.__array_namespace_info__().default_dtypes()["indexing"], ) - # TODO: shapes and values testing + # TODO: x2.ndim > 1, values testing + ph.assert_shape("searchsorted", out_shape=out.shape, expected=x2.shape) except Exception as exc: exc.add_note(repro_snippet) raise From 4bf7e34c30b2dd5738ff84338684aa961b49b267 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Nov 2025 13:29:41 +0100 Subject: [PATCH 02/30] ENH: test searchsorted with x2.ndim > 1 --- array_api_tests/test_searching_functions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 1283d9fa..af079591 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -255,12 +255,18 @@ def test_searchsorted(data): sorter = None x1 = xp.sort(x1) note(f"{x1=}") + x2 = data.draw( st.lists(st.sampled_from(_x1), unique=True, min_size=1).map( lambda o: xp.asarray(o, dtype=dh.default_float) ), label="x2", ) + # make x2.ndim > 1, if it makes sense + factors = hh._factorize(x2.shape[0]) + if len(factors) > 1: + x2 = xp.reshape(x2, tuple(factors)) + kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"]))) repro_snippet = ph.format_snippet( @@ -275,7 +281,7 @@ def test_searchsorted(data): out_dtype=out.dtype, expected=xp.__array_namespace_info__().default_dtypes()["indexing"], ) - # TODO: x2.ndim > 1, values testing + # TODO: values testing ph.assert_shape("searchsorted", out_shape=out.shape, expected=x2.shape) except Exception as exc: exc.add_note(repro_snippet) From 78c969f9fd48cf4e35270c66dabaaa6d8ade6874 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Nov 2025 14:04:04 +0100 Subject: [PATCH 03/30] ENH: searchsorted: draw x1.dtype --- array_api_tests/test_searching_functions.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index af079591..61c6f436 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -244,11 +244,12 @@ def test_where(shapes, dtypes, data): @given(data=st.data()) def test_searchsorted(data): # TODO: Allow different dtypes for x1 and x2 + x1_dtype = data.draw(st.sampled_from(dh.real_dtypes)) _x1 = data.draw( - st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True), + st.lists(xps.from_dtype(x1_dtype), min_size=1, unique=True), label="_x1", ) - x1 = xp.asarray(_x1, dtype=dh.default_float) + x1 = xp.asarray(_x1, dtype=x1_dtype) if data.draw(st.booleans(), label="use sorter?"): sorter = xp.argsort(x1) else: @@ -258,7 +259,7 @@ def test_searchsorted(data): x2 = data.draw( st.lists(st.sampled_from(_x1), unique=True, min_size=1).map( - lambda o: xp.asarray(o, dtype=dh.default_float) + lambda o: xp.asarray(o, dtype=x1_dtype) ), label="x2", ) From 5bd524be003bdab3224e6827dc8f9716cc3444ac Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Nov 2025 18:42:53 +0100 Subject: [PATCH 04/30] MAINT: searchsorted: restrict inputs to be finite real values --- array_api_tests/test_searching_functions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 61c6f436..fdebd84b 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -246,7 +246,11 @@ def test_searchsorted(data): # TODO: Allow different dtypes for x1 and x2 x1_dtype = data.draw(st.sampled_from(dh.real_dtypes)) _x1 = data.draw( - st.lists(xps.from_dtype(x1_dtype), min_size=1, unique=True), + st.lists( + xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False), + min_size=1, + unique=True + ), label="_x1", ) x1 = xp.asarray(_x1, dtype=x1_dtype) From 9cd8883d224da2ed24112cd5ad7c00fd2e034155 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Nov 2025 16:29:04 +0100 Subject: [PATCH 05/30] ENH: test take_along_axis with indices < 0 --- array_api_tests/test_indexing_functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 7b8c8763..6cea0a66 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -77,7 +77,6 @@ def test_take(x, data): ) def test_take_along_axis(x, data): # TODO - # 2. negative indices # 3. different dtypes for indices # 4. "broadcast-compatible" indices axis = data.draw( @@ -97,7 +96,7 @@ def test_take_along_axis(x, data): hh.arrays( shape=idx_shape, dtype=dh.default_int, - elements={"min_value": 0, "max_value": x.shape[n_axis]-1} + elements={"min_value": -x.shape[n_axis], "max_value": x.shape[n_axis]-1} ), label="indices" ) From 27845161d97f3c93e5c3e009370bae2f9f7d19a2 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Nov 2025 17:18:28 +0100 Subject: [PATCH 06/30] ENH: test take with negative indices --- array_api_tests/test_indexing_functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 6cea0a66..b3510e60 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -17,7 +17,6 @@ ) def test_take(x, data): # TODO: - # * negative indices # * different dtypes for indices # axis is optional but only if x.ndim == 1 @@ -28,7 +27,7 @@ def test_take(x, data): kw = {"axis": data.draw(_axis_st)} axis = kw.get("axis", 0) _indices = data.draw( - st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True), + st.lists(st.integers(-x.shape[axis], x.shape[axis] - 1), min_size=1, unique=True), label="_indices", ) n_axis = axis if axis>=0 else x.ndim + axis From cf3bf263923e09cdc949cdf0abbba8011c9cc3de Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Nov 2025 15:21:59 +0100 Subject: [PATCH 07/30] ENH: add "repro_snippets" to test_utility_functions.py --- array_api_tests/test_utility_functions.py | 122 +++++++++++++--------- 1 file changed, 72 insertions(+), 50 deletions(-) diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index b6e0a4fe..9d136dcb 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -18,23 +18,28 @@ def test_all(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") keepdims = kw.get("keepdims", False) - out = xp.all(x, **kw) - - ph.assert_dtype("all", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "all", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw - ) - scalar_type = dh.get_scalar_type(x.dtype) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): - result = bool(out[out_idx]) - elements = [] - for idx in indices: - s = scalar_type(x[idx]) - elements.append(s) - expected = all(elements) - ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx, - out=result, expected=expected, kw=kw) + repro_snippet = ph.format_snippet(f"xp.all({x!r}, **kw) with {kw = }") + try: + out = xp.all(x, **kw) + + ph.assert_dtype("all", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "all", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + result = bool(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = all(elements) + ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx, + out=result, expected=expected, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -46,23 +51,28 @@ def test_any(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") keepdims = kw.get("keepdims", False) - out = xp.any(x, **kw) - - ph.assert_dtype("any", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) - ph.assert_keepdimable_shape( - "any", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw, - ) - scalar_type = dh.get_scalar_type(x.dtype) - for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): - result = bool(out[out_idx]) - elements = [] - for idx in indices: - s = scalar_type(x[idx]) - elements.append(s) - expected = any(elements) - ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx, - out=result, expected=expected, kw=kw) + repro_snippet = ph.format_snippet(f"xp.any({x!r}, **kw) with {kw = }") + try: + out = xp.any(x, **kw) + + ph.assert_dtype("any", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "any", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw, + ) + scalar_type = dh.get_scalar_type(x.dtype) + for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): + result = bool(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = any(elements) + ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx, + out=result, expected=expected, kw=kw) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -85,19 +95,24 @@ def test_diff(x, data): n = data.draw(st.integers(1, min(x.shape[n_axis], 3))) - out = xp.diff(x, **axis_kw, n=n) + repro_snippet = ph.format_snippet(f"xp.diff({x!r}, **axis_kw, n={n!r}) with {axis_kw = }") + try: + out = xp.diff(x, **axis_kw, n=n) - expected_shape = list(x.shape) - expected_shape[n_axis] -= n + expected_shape = list(x.shape) + expected_shape[n_axis] -= n - assert out.shape == tuple(expected_shape) + assert out.shape == tuple(expected_shape) - # value test - if n == 1: - for idx in sh.ndindex(out.shape): - l = list(idx) - l[n_axis] += 1 - assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }" + # value test + if n == 1: + for idx in sh.ndindex(out.shape): + l = list(idx) + l[n_axis] += 1 + assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }" + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.min_version("2024.12") @@ -130,12 +145,19 @@ def test_diff_append_prepend(x, data): prepend_shape[n_axis] = prepend_axis_len prepend = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(prepend_shape)), label="prepend") - out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend) + repro_snippet = ph.format_snippet( + f"xp.diff({x!r}, **axis_kw, n={n!r}, append={append!r}, prepend={prepend!r}) with {axis_kw = }" + ) + try: + out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend) - in_1 = xp.concat((prepend, x, append), **axis_kw) - out_1 = xp.diff(in_1, **axis_kw, n=n) + in_1 = xp.concat((prepend, x, append), **axis_kw) + out_1 = xp.diff(in_1, **axis_kw, n=n) - assert out.shape == out_1.shape - for idx in sh.ndindex(out.shape): - assert out[idx] == out_1[idx], f"{idx = }" + assert out.shape == out_1.shape + for idx in sh.ndindex(out.shape): + assert out[idx] == out_1[idx], f"{idx = }" + except Exception as exc: + exc.add_note(repro_snippet) + raise From 50895883c162cc8678df299441640b3faa6418ba Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Nov 2025 15:22:19 +0100 Subject: [PATCH 08/30] ENH: add "repro_snippets" to test_indexing_functions.py --- array_api_tests/test_indexing_functions.py | 117 +++++++++++---------- 1 file changed, 64 insertions(+), 53 deletions(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 7b8c8763..4996932d 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -35,39 +35,43 @@ def test_take(x, data): indices = xp.asarray(_indices, dtype=dh.default_int) note(f"{indices=}") - out = xp.take(x, indices, **kw) - - ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape( - "take", - out_shape=out.shape, - expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:], - kw=dict( - x=x, - indices=indices, - axis=axis, - ), - ) - out_indices = sh.ndindex(out.shape) - axis_indices = list(sh.axis_ndindex(x.shape, n_axis)) - for axis_idx in axis_indices: - f_axis_idx = sh.fmt_idx("x", axis_idx) - for i in _indices: - f_take_idx = sh.fmt_idx(f_axis_idx, i) - indexed_x = x[axis_idx][i, ...] - for at_idx in sh.ndindex(indexed_x.shape): - out_idx = next(out_indices) - ph.assert_0d_equals( - "take", - x_repr=sh.fmt_idx(f_take_idx, at_idx), - x_val=indexed_x[at_idx], - out_repr=sh.fmt_idx("out", out_idx), - out_val=out[out_idx], - ) - # sanity check - with pytest.raises(StopIteration): - next(out_indices) + repro_snippet = ph.format_snippet(f"xp.take({x!r}, {indices!r}, **kw) with {kw = }") + try: + out = xp.take(x, indices, **kw) + ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape( + "take", + out_shape=out.shape, + expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:], + kw=dict( + x=x, + indices=indices, + axis=axis, + ), + ) + out_indices = sh.ndindex(out.shape) + axis_indices = list(sh.axis_ndindex(x.shape, n_axis)) + for axis_idx in axis_indices: + f_axis_idx = sh.fmt_idx("x", axis_idx) + for i in _indices: + f_take_idx = sh.fmt_idx(f_axis_idx, i) + indexed_x = x[axis_idx][i, ...] + for at_idx in sh.ndindex(indexed_x.shape): + out_idx = next(out_indices) + ph.assert_0d_equals( + "take", + x_repr=sh.fmt_idx(f_take_idx, at_idx), + x_val=indexed_x[at_idx], + out_repr=sh.fmt_idx("out", out_idx), + out_val=out[out_idx], + ) + # sanity check + with pytest.raises(StopIteration): + next(out_indices) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @pytest.mark.min_version("2024.12") @@ -103,26 +107,33 @@ def test_take_along_axis(x, data): ) note(f"{indices=} {idx_shape=}") - out = xp.take_along_axis(x, indices, **axis_kw) - - ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape( - "take_along_axis", - out_shape=out.shape, - expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:], - kw=dict( - x=x, - indices=indices, - axis=axis, - ), + repro_snippet = ph.format_snippet( + f"xp.take_along_axis({x!r}, {indices!r}, **axis_kw) with {axis_kw = }" ) + try: + out = xp.take_along_axis(x, indices, **axis_kw) + + ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape( + "take_along_axis", + out_shape=out.shape, + expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:], + kw=dict( + x=x, + indices=indices, + axis=axis, + ), + ) - # value test: notation is from `np.take_along_axis` docstring - Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:] - for ii in sh.ndindex(Ni): - for kk in sh.ndindex(Nk): - a_1d = x[ii + (slice(None),) + kk] - i_1d = indices[ii + (slice(None),) + kk] - o_1d = out[ii + (slice(None),) + kk] - for j in range(new_len): - assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}' + # value test: notation is from `np.take_along_axis` docstring + Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:] + for ii in sh.ndindex(Ni): + for kk in sh.ndindex(Nk): + a_1d = x[ii + (slice(None),) + kk] + i_1d = indices[ii + (slice(None),) + kk] + o_1d = out[ii + (slice(None),) + kk] + for j in range(new_len): + assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}' + except Exception as exc: + exc.add_note(repro_snippet) + raise From c9f67727518417f039f879660dd0f5a3866994b9 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Nov 2025 15:22:34 +0100 Subject: [PATCH 09/30] ENH: add "repro_snippets" to test_array_object.py --- array_api_tests/test_array_object.py | 176 +++++++++++++++------------ 1 file changed, 99 insertions(+), 77 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index af575182..ba34716b 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -86,7 +86,6 @@ def test_getitem(shape, dtype, data): key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key") repro_snippet = ph.format_snippet(f"{x!r}[{key!r}]") - try: out = x[key] @@ -109,6 +108,7 @@ def test_getitem(shape, dtype, data): ph.add_note(exc, repro_snippet) raise + @pytest.mark.unvectorized @given( shape=hh.shapes(), @@ -133,28 +133,34 @@ def test_setitem(shape, dtypes, data): value = data.draw(value_strat, label="value") res = xp.asarray(x, copy=True) - res[key] = value - - ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") - ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape") - f_res = sh.fmt_idx("x", key) - if isinstance(value, get_args(Scalar)): - msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" - if cmath.isnan(value): - assert xp.isnan(res[key]), msg + + repro_snippet = ph.format_snippet(f"{res!r}[{key!r}] = {value!r}") + try: + res[key] = value + + ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") + ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape") + f_res = sh.fmt_idx("x", key) + if isinstance(value, get_args(Scalar)): + msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" + if cmath.isnan(value): + assert xp.isnan(res[key]), msg + else: + assert res[key] == value, msg else: - assert res[key] == value, msg - else: - ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res) - unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices)) - for idx in unaffected_indices: - ph.assert_0d_equals( - "__setitem__", - x_repr=f"old {f_res}", - x_val=x[idx], - out_repr=f"modified {f_res}", - out_val=res[idx], - ) + ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res) + unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices)) + for idx in unaffected_indices: + ph.assert_0d_equals( + "__setitem__", + x_repr=f"old {f_res}", + x_val=x[idx], + out_repr=f"modified {f_res}", + out_val=res[idx], + ) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -178,29 +184,34 @@ def test_getitem_masking(shape, data): x[key] return - out = x[key] + repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]") + try: + out = x[key] - ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) - if key.ndim == 0: - expected_shape = (1,) if key else (0,) - expected_shape += x.shape - else: - size = int(xp.sum(xp.astype(key, xp.uint8))) - expected_shape = (size,) + x.shape[key.ndim :] - ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) - if not any(s == 0 for s in key.shape): - assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios - out_indices = sh.ndindex(out.shape) - for x_idx in sh.ndindex(x.shape): - if key[x_idx]: - out_idx = next(out_indices) - ph.assert_0d_equals( - "__getitem__", - x_repr=f"x[{x_idx}]", - x_val=x[x_idx], - out_repr=f"out[{out_idx}]", - out_val=out[out_idx], - ) + ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) + if key.ndim == 0: + expected_shape = (1,) if key else (0,) + expected_shape += x.shape + else: + size = int(xp.sum(xp.astype(key, xp.uint8))) + expected_shape = (size,) + x.shape[key.ndim :] + ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) + if not any(s == 0 for s in key.shape): + assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios + out_indices = sh.ndindex(out.shape) + for x_idx in sh.ndindex(x.shape): + if key[x_idx]: + out_idx = next(out_indices) + ph.assert_0d_equals( + "__getitem__", + x_repr=f"x[{x_idx}]", + x_val=x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + ) + except Exception as exc: + exc.add_note(repro_snippet) + raise @pytest.mark.unvectorized @@ -213,38 +224,44 @@ def test_setitem_masking(shape, data): ) res = xp.asarray(x, copy=True) - res[key] = value - - ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") - ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype") - scalar_type = dh.get_scalar_type(x.dtype) - for idx in sh.ndindex(x.shape): - if key[idx]: - if isinstance(value, get_args(Scalar)): - ph.assert_scalar_equals( - "__setitem__", - type_=scalar_type, - idx=idx, - out=scalar_type(res[idx]), - expected=value, - repr_name="modified x", - ) + + repro_snippet = ph.format_snippet(f"{res}[{key!r}] = {value!r}") + try: + res[key] = value + + ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") + ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype") + scalar_type = dh.get_scalar_type(x.dtype) + for idx in sh.ndindex(x.shape): + if key[idx]: + if isinstance(value, get_args(Scalar)): + ph.assert_scalar_equals( + "__setitem__", + type_=scalar_type, + idx=idx, + out=scalar_type(res[idx]), + expected=value, + repr_name="modified x", + ) + else: + ph.assert_0d_equals( + "__setitem__", + x_repr="value", + x_val=value, + out_repr=f"modified x[{idx}]", + out_val=res[idx] + ) else: ph.assert_0d_equals( "__setitem__", - x_repr="value", - x_val=value, + x_repr=f"old x[{idx}]", + x_val=x[idx], out_repr=f"modified x[{idx}]", out_val=res[idx] ) - else: - ph.assert_0d_equals( - "__setitem__", - x_repr=f"old x[{idx}]", - x_val=x[idx], - out_repr=f"modified x[{idx}]", - out_val=res[idx] - ) + except Exception as exc: + exc.add_note(repro_snippet) + raise # ### Fancy indexing ### @@ -309,15 +326,20 @@ def _test_getitem_arrays_and_ints(shape, data, idx_max_dims): key.append(data.draw(st.integers(-shape[i], shape[i]-1))) key = tuple(key) - out = x[key] + repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]") + try: + out = x[key] - arrays = [xp.asarray(k) for k in key] - bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays]) - bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays] + arrays = [xp.asarray(k) for k in key] + bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays]) + bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays] - for idx in sh.ndindex(bcast_shape): - tpl = tuple(k[idx] for k in bcast_key) - assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }" + for idx in sh.ndindex(bcast_shape): + tpl = tuple(k[idx] for k in bcast_key) + assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }" + except Exception as exc: + exc.add_note(repro_snippet) + raise def make_scalar_casting_param( From 0c73a6f1cec8b7585bc325ae244e3ea843b35471 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Nov 2025 15:23:37 +0100 Subject: [PATCH 10/30] Update .git-blame-ignore-revs for whitespace heavy commits --- .git-blame-ignore-revs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 4541bd36..eaeedfd4 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -7,3 +7,8 @@ e807ffe526c7330691e8f39d31347dc2b3106de3 bd42e84d2e5aae26ade8d70384e74effd1de89cb f7e822883b7e24b5aa540e2413759a85128b42ef a37f348ba27b6818e92fda8aee2406c653c671ea +# gh-396 +ec5a3b4e185c262b0a5f5b1631b84a09f766d80e +9058908b58ce627467ac34e768098a25f5863d31 +c80e1823c2e738381ca02f27cea1e2b89dde0ac5 + From 5acac26b36ca4215a9a857a318fbec4ebf86f73f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 1 Dec 2025 16:23:31 +0100 Subject: [PATCH 11/30] MAINT: python 3.10 compatible add_note cf https://github.com/data-apis/array-api-tests/pull/398 --- array_api_tests/test_array_object.py | 8 ++++---- array_api_tests/test_indexing_functions.py | 4 ++-- array_api_tests/test_utility_functions.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index ba34716b..8337cd86 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -159,7 +159,7 @@ def test_setitem(shape, dtypes, data): out_val=res[idx], ) except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise @@ -210,7 +210,7 @@ def test_getitem_masking(shape, data): out_val=out[out_idx], ) except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise @@ -260,7 +260,7 @@ def test_setitem_masking(shape, data): out_val=res[idx] ) except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise @@ -338,7 +338,7 @@ def _test_getitem_arrays_and_ints(shape, data, idx_max_dims): tpl = tuple(k[idx] for k in bcast_key) assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }" except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 4996932d..64e4261f 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -70,7 +70,7 @@ def test_take(x, data): with pytest.raises(StopIteration): next(out_indices) except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise @pytest.mark.unvectorized @@ -135,5 +135,5 @@ def test_take_along_axis(x, data): for j in range(new_len): assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}' except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index 9d136dcb..eefee250 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -38,7 +38,7 @@ def test_all(x, data): ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx, out=result, expected=expected, kw=kw) except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise @@ -71,7 +71,7 @@ def test_any(x, data): ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx, out=result, expected=expected, kw=kw) except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise @@ -111,7 +111,7 @@ def test_diff(x, data): l[n_axis] += 1 assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }" except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise @@ -158,6 +158,6 @@ def test_diff_append_prepend(x, data): for idx in sh.ndindex(out.shape): assert out[idx] == out_1[idx], f"{idx = }" except Exception as exc: - exc.add_note(repro_snippet) + ph.add_note(exc, repro_snippet) raise From 538c24b3463b59b3d4eedcc34a84d38c1a239052 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 2 Dec 2025 17:48:50 +0100 Subject: [PATCH 12/30] CI: bump versions of GH actions on CI --- .github/workflows/lint.yml | 4 ++-- .github/workflows/test.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9b49c09b..15bd4d0f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,9 +7,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" - name: Run pre-commit hook diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7f406d81..48ef897f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,11 +11,11 @@ jobs: python-version: ["3.10", "3.12", "3.13", "3.14"] steps: - name: Checkout array-api-tests - uses: actions/checkout@v1 + uses: actions/checkout@v6 with: submodules: 'true' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install dependencies From bdc84e8316046cb5bdc637067460057eef17d0f1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 Jan 2026 15:09:07 +0100 Subject: [PATCH 13/30] add "repro snippets" to test_operators_and_elementwise_functions.py --- ...est_operators_and_elementwise_functions.py | 1605 ++++++++++------- 1 file changed, 970 insertions(+), 635 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 84bcaa28..f074bd8e 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -737,48 +737,63 @@ def test_abs(ctx, data): if x.dtype in dh.int_dtypes: assume(xp.all(x > dh.dtype_ranges[x.dtype].min)) - out = ctx.func(x) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({x!r})") + try: + out = ctx.func(x) - if x.dtype in dh.complex_dtypes: - assert out.dtype == dh.dtype_components[x.dtype] - else: - ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl( - ctx.func_name, - x, - out, - abs, # type: ignore - res_stype=float if x.dtype in dh.complex_dtypes else None, - expr_template="abs({})={}", - # filter_=lambda s: ( - # s == float("infinity") or (cmath.isfinite(s) and not ph.is_neg_zero(s)) - # ), - ) + if x.dtype in dh.complex_dtypes: + assert out.dtype == dh.dtype_components[x.dtype] + else: + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl( + ctx.func_name, + x, + out, + abs, # type: ignore + res_stype=float if x.dtype in dh.complex_dtypes else None, + expr_template="abs({})={}", + # filter_=lambda s: ( + # s == float("infinity") or (cmath.isfinite(s) and not ph.is_neg_zero(s)) + # ), + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_acos(x): - out = xp.acos(x) - ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("acos", out_shape=out.shape, expected=x.shape) - refimpl = cmath.acos if x.dtype in dh.complex_dtypes else math.acos - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 - unary_assert_against_refimpl( - "acos", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.acos({x!r})") + try: + out = xp.acos(x) + ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("acos", out_shape=out.shape, expected=x.shape) + refimpl = cmath.acos if x.dtype in dh.complex_dtypes else math.acos + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 + unary_assert_against_refimpl( + "acos", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_acosh(x): - out = xp.acosh(x) - ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.acosh if x.dtype in dh.complex_dtypes else math.acosh - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 1 - unary_assert_against_refimpl( - "acosh", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.acosh({x!r})") + try: + out = xp.acosh(x) + ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.acosh if x.dtype in dh.complex_dtypes else math.acosh + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 1 + unary_assert_against_refimpl( + "acosh", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes)) @@ -787,71 +802,101 @@ def test_add(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - with hh.reject_overflow(): - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + with hh.reject_overflow(): + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_asin(x): - out = xp.asin(x) - ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("asin", out_shape=out.shape, expected=x.shape) - refimpl = cmath.asin if x.dtype in dh.complex_dtypes else math.asin - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 - unary_assert_against_refimpl( - "asin", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.asin({x!r})") + try: + out = xp.asin(x) + ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("asin", out_shape=out.shape, expected=x.shape) + refimpl = cmath.asin if x.dtype in dh.complex_dtypes else math.asin + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1 + unary_assert_against_refimpl( + "asin", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_asinh(x): - out = xp.asinh(x) - ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.asinh if x.dtype in dh.complex_dtypes else math.asinh - unary_assert_against_refimpl("asinh", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.asinh({x!r})") + try: + out = xp.asinh(x) + ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.asinh if x.dtype in dh.complex_dtypes else math.asinh + unary_assert_against_refimpl("asinh", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_atan(x): - out = xp.atan(x) - ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("atan", out_shape=out.shape, expected=x.shape) - refimpl = cmath.atan if x.dtype in dh.complex_dtypes else math.atan - unary_assert_against_refimpl("atan", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.atan({x!r})") + try: + out = xp.atan(x) + ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("atan", out_shape=out.shape, expected=x.shape) + refimpl = cmath.atan if x.dtype in dh.complex_dtypes else math.atan + unary_assert_against_refimpl("atan", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_atan2(x1, x2): - out = xp.atan2(x1, x2) - _assert_correctness_binary( - "atan", - cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out, - ) + repro_snippet = ph.format_snippet(f"xp.atan2({x1!r}, {x2!r})") + try: + out = xp.atan2(x1, x2) + _assert_correctness_binary( + "atan", + cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_atanh(x): - out = xp.atanh(x) - ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.atanh if x.dtype in dh.complex_dtypes else math.atanh - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 < s < 1 - unary_assert_against_refimpl( - "atanh", - x, - out, - refimpl, - filter_=filter_, - ) + repro_snippet = ph.format_snippet(f"xp.atanh({x!r})") + try: + out = xp.atanh(x) + ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.atanh if x.dtype in dh.complex_dtypes else math.atanh + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 < s < 1 + unary_assert_against_refimpl( + "atanh", + x, + out, + refimpl, + filter_=filter_, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -862,15 +907,20 @@ def test_bitwise_and(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - if left.dtype == xp.bool: - refimpl = operator.and_ - else: - refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype) - binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + if left.dtype == xp.bool: + refimpl = operator.and_ + else: + refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -885,14 +935,19 @@ def test_bitwise_left_shift(ctx, data): else: assume(not xp.any(ah.isnegative(right))) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - nbits = dh.dtype_nbits[res.dtype] - binary_param_assert_against_refimpl( - ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0 - ) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + nbits = dh.dtype_nbits[res.dtype] + binary_param_assert_against_refimpl( + ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0 + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -902,15 +957,20 @@ def test_bitwise_left_shift(ctx, data): def test_bitwise_invert(ctx, data): x = data.draw(ctx.strat, label="x") - out = ctx.func(x) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({x!r})") + try: + out = ctx.func(x) - ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) - if x.dtype == xp.bool: - refimpl = operator.not_ - else: - refimpl = lambda s: mock_int_dtype(~s, x.dtype) - unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}") + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) + if x.dtype == xp.bool: + refimpl = operator.not_ + else: + refimpl = lambda s: mock_int_dtype(~s, x.dtype) + unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}") + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -921,15 +981,20 @@ def test_bitwise_or(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - if left.dtype == xp.bool: - refimpl = operator.or_ - else: - refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype) - binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + if left.dtype == xp.bool: + refimpl = operator.or_ + else: + refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -944,13 +1009,18 @@ def test_bitwise_right_shift(ctx, data): else: assume(not xp.any(ah.isnegative(right))) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl( - ctx, left, right, res, ">>", lambda l, r: mock_int_dtype(l >> r, res.dtype) - ) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl( + ctx, left, right, res, ">>", lambda l, r: mock_int_dtype(l >> r, res.dtype) + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize( @@ -961,24 +1031,32 @@ def test_bitwise_xor(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - if left.dtype == xp.bool: - refimpl = operator.xor - else: - refimpl = lambda l, r: mock_int_dtype(l ^ r, res.dtype) - binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + if left.dtype == xp.bool: + refimpl = operator.xor + else: + refimpl = lambda l, r: mock_int_dtype(l ^ r, res.dtype) + binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes())) def test_ceil(x): - out = xp.ceil(x) - ph.assert_dtype("ceil", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("ceil", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) - + repro_snippet = ph.format_snippet(f"xp.ceil({x!r})") + try: + out = xp.ceil(x) + ph.assert_dtype("ceil", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("ceil", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data()) @@ -1009,141 +1087,163 @@ def test_clip(x, data): ("max", max, None)), label="kwargs") - out = xp.clip(x, **kw) - - # min and max do not participate in type promotion - ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype) - - shapes = [x.shape] - if min is not None and not dh.is_scalar(min): - shapes.append(min.shape) - if max is not None and not dh.is_scalar(max): - shapes.append(max.shape) - expected_shape = sh.broadcast_shapes(*shapes) - ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape) - - # This is based on right_scalar_assert_against_refimpl and - # binary_assert_against_refimpl. clip() is currently the only ternary - # elementwise function and the only function that supports arrays and - # scalars. However, where() (in test_searching_functions) is similar - # and if scalar support is added to it, we may want to factor out and - # reuse this logic. - - def refimpl(_x, _min, _max): - # Skip cases where _min and _max are integers whose values do not - # fit in the dtype of _x, since this behavior is unspecified. - if dh.is_int_dtype(x.dtype): - if _min is not None and _min not in dh.dtype_ranges[x.dtype]: - return None - if _max is not None and _max not in dh.dtype_ranges[x.dtype]: - return None - - # If min or max are float64 and x is float32, they will need to be - # downcast to float32. This could result in a round in the wrong - # direction meaning the resulting clipped value might not actually be - # between min and max. This behavior is unspecified, so skip any cases - # where x is within the rounding error of downcasting min or max. - if x.dtype == xp.float32: - if min is not None and not dh.is_scalar(min) and min.dtype == xp.float64 and math.isfinite(_min): - _min_float32 = float(xp.asarray(_min, dtype=xp.float32)) - if math.isinf(_min_float32): - return None - tol = abs(_min - _min_float32) - if math.isclose(_min, _min_float32, abs_tol=tol): - return None - if max is not None and not dh.is_scalar(max) and max.dtype == xp.float64 and math.isfinite(_max): - _max_float32 = float(xp.asarray(_max, dtype=xp.float32)) - if math.isinf(_max_float32): + repro_snippet = ph.format_snippet(f"xp.clip({x!r}, **kw) with {kw = }") + try: + out = xp.clip(x, **kw) + + # min and max do not participate in type promotion + ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype) + + shapes = [x.shape] + if min is not None and not dh.is_scalar(min): + shapes.append(min.shape) + if max is not None and not dh.is_scalar(max): + shapes.append(max.shape) + expected_shape = sh.broadcast_shapes(*shapes) + ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape) + + # This is based on right_scalar_assert_against_refimpl and + # binary_assert_against_refimpl. clip() is currently the only ternary + # elementwise function and the only function that supports arrays and + # scalars. However, where() (in test_searching_functions) is similar + # and if scalar support is added to it, we may want to factor out and + # reuse this logic. + + def refimpl(_x, _min, _max): + # Skip cases where _min and _max are integers whose values do not + # fit in the dtype of _x, since this behavior is unspecified. + if dh.is_int_dtype(x.dtype): + if _min is not None and _min not in dh.dtype_ranges[x.dtype]: return None - tol = abs(_max - _max_float32) - if math.isclose(_max, _max_float32, abs_tol=tol): + if _max is not None and _max not in dh.dtype_ranges[x.dtype]: return None - if (math.isnan(_x) - or (_min is not None and math.isnan(_min)) - or (_max is not None and math.isnan(_max))): - return math.nan - if _min is _max is None: - return _x - if _max is None: - return builtins.max(_x, _min) - if _min is None: - return builtins.min(_x, _max) - return builtins.min(builtins.max(_x, _min), _max) - - stype = dh.get_scalar_type(x.dtype) - min_shape = () if min is None or dh.is_scalar(min) else min.shape - max_shape = () if max is None or dh.is_scalar(max) else max.shape - - for x_idx, min_idx, max_idx, o_idx in sh.iter_indices( - x.shape, min_shape, max_shape, out.shape): - x_val = stype(x[x_idx]) - if min is None or dh.is_scalar(min): - min_val = min - else: - min_val = stype(min[min_idx]) - if max is None or dh.is_scalar(max): - max_val = max - else: - max_val = stype(max[max_idx]) - expected = refimpl(x_val, min_val, max_val) - if expected is None: - continue - out_val = stype(out[o_idx]) - if math.isnan(expected): - assert math.isnan(out_val), ( - f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n" - f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" - ) - else: - if out.dtype == xp.float32: - # conversion to builtin float is prone to roundoff errors - close_enough = math.isclose(out_val, expected, rel_tol=EPS32) + # If min or max are float64 and x is float32, they will need to be + # downcast to float32. This could result in a round in the wrong + # direction meaning the resulting clipped value might not actually be + # between min and max. This behavior is unspecified, so skip any cases + # where x is within the rounding error of downcasting min or max. + if x.dtype == xp.float32: + if min is not None and not dh.is_scalar(min) and min.dtype == xp.float64 and math.isfinite(_min): + _min_float32 = float(xp.asarray(_min, dtype=xp.float32)) + if math.isinf(_min_float32): + return None + tol = abs(_min - _min_float32) + if math.isclose(_min, _min_float32, abs_tol=tol): + return None + if max is not None and not dh.is_scalar(max) and max.dtype == xp.float64 and math.isfinite(_max): + _max_float32 = float(xp.asarray(_max, dtype=xp.float32)) + if math.isinf(_max_float32): + return None + tol = abs(_max - _max_float32) + if math.isclose(_max, _max_float32, abs_tol=tol): + return None + + if (math.isnan(_x) + or (_min is not None and math.isnan(_min)) + or (_max is not None and math.isnan(_max))): + return math.nan + if _min is _max is None: + return _x + if _max is None: + return builtins.max(_x, _min) + if _min is None: + return builtins.min(_x, _max) + return builtins.min(builtins.max(_x, _min), _max) + + stype = dh.get_scalar_type(x.dtype) + min_shape = () if min is None or dh.is_scalar(min) else min.shape + max_shape = () if max is None or dh.is_scalar(max) else max.shape + + for x_idx, min_idx, max_idx, o_idx in sh.iter_indices( + x.shape, min_shape, max_shape, out.shape): + x_val = stype(x[x_idx]) + if min is None or dh.is_scalar(min): + min_val = min else: - close_enough = out_val == expected - - assert close_enough, ( - f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n" - f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" - ) + min_val = stype(min[min_idx]) + if max is None or dh.is_scalar(max): + max_val = max + else: + max_val = stype(max[max_idx]) + expected = refimpl(x_val, min_val, max_val) + if expected is None: + continue + out_val = stype(out[o_idx]) + if math.isnan(expected): + assert math.isnan(out_val), ( + f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n" + f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" + ) + else: + if out.dtype == xp.float32: + # conversion to builtin float is prone to roundoff errors + close_enough = math.isclose(out_val, expected, rel_tol=EPS32) + else: + close_enough = out_val == expected + + assert close_enough, ( + f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n" + f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2022.12") @pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) def test_conj(x): - out = xp.conj(x) - ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("conj", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) + repro_snippet = ph.format_snippet(f"xp.conj({x!r})") + try: + out = xp.conj(x) + ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("conj", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_copysign(x1, x2): - out = xp.copysign(x1, x2) - ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("copysign", x1, x2, out, math.copysign) - + repro_snippet = ph.format_snippet(f"xp.copysign({x1!r}, {x2!r})") + try: + out = xp.copysign(x1, x2) + ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) + binary_assert_against_refimpl("copysign", x1, x2, out, math.copysign) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): - out = xp.cos(x) - ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("cos", out_shape=out.shape, expected=x.shape) - refimpl = cmath.cos if x.dtype in dh.complex_dtypes else math.cos - unary_assert_against_refimpl("cos", x, out, refimpl) - + repro_snippet = ph.format_snippet(f"xp.cos({x!r})") + try: + out = xp.cos(x) + ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("cos", out_shape=out.shape, expected=x.shape) + refimpl = cmath.cos if x.dtype in dh.complex_dtypes else math.cos + unary_assert_against_refimpl("cos", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cosh(x): - out = xp.cosh(x) - ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.cosh if x.dtype in dh.complex_dtypes else math.cosh - unary_assert_against_refimpl("cosh", x, out, refimpl) - + repro_snippet = ph.format_snippet(f"xp.cosh({x!r})") + try: + out = xp.cosh(x) + ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.cosh if x.dtype in dh.complex_dtypes else math.cosh + unary_assert_against_refimpl("cosh", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes)) @given(data=st.data()) @@ -1153,19 +1253,24 @@ def test_divide(ctx, data): if ctx.right_is_scalar: assume # TODO: assume what? - res = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl( - ctx, - left, - right, - res, - "/", - operator.truediv, - filter_=lambda s: cmath.isfinite(s) and s != 0, - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl( + ctx, + left, + right, + res, + "/", + operator.truediv, + filter_=lambda s: cmath.isfinite(s) and s != 0, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes)) @@ -1174,72 +1279,91 @@ def test_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # We manually promote the dtypes as incorrect internal type promotion - # could lead to false positives. For example - # - # >>> xp.equal( - # ... xp.asarray(1.0, dtype=xp.float32), - # ... xp.asarray(1.00000001, dtype=xp.float64), - # ... ) - # - # would erroneously be True if float64 downcasted to float32. - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, "==", operator.eq, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # We manually promote the dtypes as incorrect internal type promotion + # could lead to false positives. For example + # + # >>> xp.equal( + # ... xp.asarray(1.0, dtype=xp.float32), + # ... xp.asarray(1.00000001, dtype=xp.float64), + # ... ) + # + # would erroneously be True if float64 downcasted to float32. + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "==", operator.eq, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_exp(x): - out = xp.exp(x) - ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("exp", out_shape=out.shape, expected=x.shape) - refimpl = cmath.exp if x.dtype in dh.complex_dtypes else math.exp - unary_assert_against_refimpl("exp", x, out, refimpl) - + repro_snippet = ph.format_snippet(f"xp.exp({x!r})") + try: + out = xp.exp(x) + ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("exp", out_shape=out.shape, expected=x.shape) + refimpl = cmath.exp if x.dtype in dh.complex_dtypes else math.exp + unary_assert_against_refimpl("exp", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_expm1(x): - out = xp.expm1(x) - ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - def refimpl(z): - # There's no cmath.expm1. Use - # - # exp(x+yi) - 1 - # = exp(x)exp(yi) - 1 - # = exp(x)(cos(y) + sin(y)i) - 1 - # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i - # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i - # - # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of - # significance near y = 0. - re, im = z.real, z.imag - return math.expm1(re)*math.cos(im) - 2*math.sin(im/2)**2 + 1j*math.exp(re)*math.sin(im) - else: - refimpl = math.expm1 - unary_assert_against_refimpl("expm1", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.expm1({x!r})") + try: + out = xp.expm1(x) + ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + # There's no cmath.expm1. Use + # + # exp(x+yi) - 1 + # = exp(x)exp(yi) - 1 + # = exp(x)(cos(y) + sin(y)i) - 1 + # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i + # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i + # + # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of + # significance near y = 0. + re, im = z.real, z.imag + return math.expm1(re)*math.cos(im) - 2*math.sin(im/2)**2 + 1j*math.exp(re)*math.sin(im) + else: + refimpl = math.expm1 + unary_assert_against_refimpl("expm1", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes())) def test_floor(x): - out = xp.floor(x) - ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("floor", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - def refimpl(z): - return complex(math.floor(z.real), math.floor(z.imag)) - else: - refimpl = math.floor - unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True) + repro_snippet = ph.format_snippet(f"xp.floor({x!r})") + try: + out = xp.floor(x) + ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("floor", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + return complex(math.floor(z.real), math.floor(z.imag)) + else: + refimpl = math.floor + unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes)) @@ -1254,11 +1378,16 @@ def test_floor_divide(ctx, data): else: assume(not xp.any(right == 0)) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes)) @@ -1267,18 +1396,23 @@ def test_greater(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, ">", operator.gt, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, ">", operator.gt, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes)) @@ -1287,69 +1421,99 @@ def test_greater_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, ">=", operator.ge, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, ">=", operator.ge, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_hypot(x1, x2): - out = xp.hypot(x1, x2) - _assert_correctness_binary( - "hypot", - math.hypot, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out - ) + repro_snippet = ph.format_snippet(f"xp.hypot({x1!r}, {x2!r})") + try: + out = xp.hypot(x1, x2) + _assert_correctness_binary( + "hypot", + math.hypot, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2022.12") @pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) def test_imag(x): - out = xp.imag(x) - ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) - ph.assert_shape("imag", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) + repro_snippet = ph.format_snippet(f"xp.imag({x!r})") + try: + out = xp.imag(x) + ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) + ph.assert_shape("imag", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_isfinite(x): - out = xp.isfinite(x) - ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape) - refimpl = cmath.isfinite if x.dtype in dh.complex_dtypes else math.isfinite - unary_assert_against_refimpl("isfinite", x, out, refimpl, res_stype=bool) + repro_snippet = ph.format_snippet(f"xp.isfinite({x!r})") + try: + out = xp.isfinite(x) + ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape) + refimpl = cmath.isfinite if x.dtype in dh.complex_dtypes else math.isfinite + unary_assert_against_refimpl("isfinite", x, out, refimpl, res_stype=bool) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_isinf(x): - out = xp.isinf(x) - ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape) - refimpl = cmath.isinf if x.dtype in dh.complex_dtypes else math.isinf - unary_assert_against_refimpl("isinf", x, out, refimpl, res_stype=bool) + repro_snippet = ph.format_snippet(f"xp.isinf({x!r})") + try: + out = xp.isinf(x) + ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape) + refimpl = cmath.isinf if x.dtype in dh.complex_dtypes else math.isinf + unary_assert_against_refimpl("isinf", x, out, refimpl, res_stype=bool) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_isnan(x): - out = xp.isnan(x) - ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape) - refimpl = cmath.isnan if x.dtype in dh.complex_dtypes else math.isnan - unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool) + repro_snippet = ph.format_snippet(f"xp.isnan({x!r})") + try: + out = xp.isnan(x) + ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape) + refimpl = cmath.isnan if x.dtype in dh.complex_dtypes else math.isnan + unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes)) @@ -1358,18 +1522,23 @@ def test_less(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, "<", operator.lt, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "<", operator.lt, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes)) @@ -1378,81 +1547,106 @@ def test_less_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, "<=", operator.le, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "<=", operator.le, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log(x): - out = xp.log(x) - ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("log", out_shape=out.shape, expected=x.shape) - refimpl = cmath.log if x.dtype in dh.complex_dtypes else math.log - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 - unary_assert_against_refimpl( - "log", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.log({x!r})") + try: + out = xp.log(x) + ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log", out_shape=out.shape, expected=x.shape) + refimpl = cmath.log if x.dtype in dh.complex_dtypes else math.log + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 + unary_assert_against_refimpl( + "log", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log1p(x): - out = xp.log1p(x) - ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape) - # There isn't a cmath.log1p, and implementing one isn't straightforward - # (see - # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy). - # For now, just use log(1+p) for complex inputs, which should hopefully be - # fine given the very loose numerical tolerances we use. If it isn't, we - # can try using something like a series expansion for small p. - if x.dtype in dh.complex_dtypes: - refimpl = lambda z: cmath.log(1+z) - else: - refimpl = math.log1p - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > -1 - unary_assert_against_refimpl( - "log1p", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.log1p({x!r})") + try: + out = xp.log1p(x) + ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape) + # There isn't a cmath.log1p, and implementing one isn't straightforward + # (see + # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy). + # For now, just use log(1+p) for complex inputs, which should hopefully be + # fine given the very loose numerical tolerances we use. If it isn't, we + # can try using something like a series expansion for small p. + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(1+z) + else: + refimpl = math.log1p + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > -1 + unary_assert_against_refimpl( + "log1p", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log2(x): - out = xp.log2(x) - ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("log2", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - refimpl = lambda z: cmath.log(z)/math.log(2) - else: - refimpl = math.log2 - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 - unary_assert_against_refimpl( - "log2", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.log2({x!r})") + try: + out = xp.log2(x) + ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log2", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(2) + else: + refimpl = math.log2 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 + unary_assert_against_refimpl( + "log2", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log10(x): - out = xp.log10(x) - ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("log10", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - refimpl = lambda z: cmath.log(z)/math.log(10) - else: - refimpl = math.log10 - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 - unary_assert_against_refimpl( - "log10", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.log10({x!r})") + try: + out = xp.log10(x) + ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log10", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(10) + else: + refimpl = math.log10 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 + unary_assert_against_refimpl( + "log10", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise def logaddexp_refimpl(l: float, r: float) -> float: @@ -1465,85 +1659,120 @@ def logaddexp_refimpl(l: float, r: float) -> float: @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_logaddexp(x1, x2): - out = xp.logaddexp(x1, x2) - _assert_correctness_binary( - "logaddexp", - logaddexp_refimpl, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out - ) + repro_snippet = ph.format_snippet(f"xp.logaddexp({x1!r}, {x2!r})") + try: + out = xp.logaddexp(x1, x2) + _assert_correctness_binary( + "logaddexp", + logaddexp_refimpl, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=xp.bool, shape=hh.shapes())) def test_logical_not(x): - out = xp.logical_not(x) - ph.assert_dtype("logical_not", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("logical_not", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl( - "logical_not", x, out, operator.not_, expr_template="(not {})={}" - ) + repro_snippet = ph.format_snippet(f"xp.logical_not({x!r})") + try: + out = xp.logical_not(x) + ph.assert_dtype("logical_not", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("logical_not", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl( + "logical_not", x, out, operator.not_, expr_template="(not {})={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_and(x1, x2): - out = xp.logical_and(x1, x2) - _assert_correctness_binary( - "logical_and", - operator.and_, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out, - expr_template="({} and {})={}" - ) + repro_snippet = ph.format_snippet(f"xp.logical_and({x1!r}, {x2!r})") + try: + out = xp.logical_and(x1, x2) + _assert_correctness_binary( + "logical_and", + operator.and_, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} and {})={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_or(x1, x2): - out = xp.logical_or(x1, x2) - _assert_correctness_binary( - "logical_or", - operator.or_, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out, - expr_template="({} or {})={}" - ) + repro_snippet = ph.format_snippet(f"xp.logical_or({x1!r}, {x2!r})") + try: + out = xp.logical_or(x1, x2) + _assert_correctness_binary( + "logical_or", + operator.or_, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} or {})={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_xor(x1, x2): - out = xp.logical_xor(x1, x2) - _assert_correctness_binary( - "logical_xor", - operator.xor, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out, - expr_template="({} ^ {})={}" - ) + repro_snippet = ph.format_snippet(f"xp.logical_xor({x1!r}, {x2!r})") + try: + out = xp.logical_xor(x1, x2) + _assert_correctness_binary( + "logical_xor", + operator.xor, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} ^ {})={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_maximum(x1, x2): - out = xp.maximum(x1, x2) - _assert_correctness_binary( - "maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True - ) + repro_snippet = ph.format_snippet(f"xp.maximum({x1!r}, {x2!r})") + try: + out = xp.maximum(x1, x2) + _assert_correctness_binary( + "maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_minimum(x1, x2): - out = xp.minimum(x1, x2) - _assert_correctness_binary( - "minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True - ) + repro_snippet = ph.format_snippet(f"xp.minumum({x1!r}, {x2!r})") + try: + out = xp.minimum(x1, x2) + _assert_correctness_binary( + "minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes)) @@ -1552,11 +1781,16 @@ def test_multiply(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "*", operator.mul) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "*", operator.mul) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise # TODO: clarify if uints are acceptable, adjust accordingly @@ -1568,14 +1802,18 @@ def test_negative(ctx, data): if x.dtype in dh.int_dtypes: assume(xp.all(x > dh.dtype_ranges[x.dtype].min)) - out = ctx.func(x) - - ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl( - ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({x!r}") + try: + out = ctx.func(x) + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl( + ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes)) @given(data=st.data()) @@ -1583,18 +1821,23 @@ def test_not_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - out = ctx.func(left, right) - - binary_param_assert_dtype(ctx, left, right, out, xp.bool) - binary_param_assert_shape(ctx, left, right, out) - if not ctx.right_is_scalar: - # See test_equal note - promoted_dtype = dh.promotion_table[left.dtype, right.dtype] - left = xp.astype(left, promoted_dtype) - right = xp.astype(right, promoted_dtype) - binary_param_assert_against_refimpl( - ctx, left, right, out, "!=", operator.ne, res_stype=bool - ) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + out = ctx.func(left, right) + + binary_param_assert_dtype(ctx, left, right, out, xp.bool) + binary_param_assert_shape(ctx, left, right, out) + if not ctx.right_is_scalar: + # See test_equal note + promoted_dtype = dh.promotion_table[left.dtype, right.dtype] + left = xp.astype(left, promoted_dtype) + right = xp.astype(right, promoted_dtype) + binary_param_assert_against_refimpl( + ctx, left, right, out, "!=", operator.ne, res_stype=bool + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2024.12") @@ -1607,26 +1850,37 @@ def test_nextafter(shapes, dtype, data): x1 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x1") x2 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x2") - out = xp.nextafter(x1, x2) - _assert_correctness_binary( - "nextafter", - math.nextafter, - in_dtypes=[x1.dtype, x2.dtype], - in_shapes=[x1.shape, x2.shape], - in_arrs=[x1, x2], - out=out - ) + repro_snippet = ph.format_snippet(f"xp.nextafter({x1!r}, {x2!r})") + try: + out = xp.nextafter(x1, x2) + _assert_correctness_binary( + "nextafter", + math.nextafter, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise + @pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes)) @given(data=st.data()) def test_positive(ctx, data): x = data.draw(ctx.strat, label="x") - out = ctx.func(x) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({x!r})") + try: + out = ctx.func(x) - ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) - ph.assert_array_elements(ctx.func_name, out=out, expected=x) + ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape) + ph.assert_array_elements(ctx.func_name, out=out, expected=x) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes)) @@ -1641,38 +1895,53 @@ def test_pow(ctx, data): if dh.is_int_dtype(right.dtype): assume(xp.all(right >= 0)) - with hh.reject_overflow(): - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + with hh.reject_overflow(): + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - # Values testing pow is too finicky + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + # Values testing pow is too finicky + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2022.12") @pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") @given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) def test_real(x): - out = xp.real(x) - ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) - ph.assert_shape("real", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) + repro_snippet = ph.format_snippet(f"xp.real({x!r})") + try: + out = xp.real(x) + ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) + ph.assert_shape("real", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2024.12") @given(hh.arrays(dtype=hh.floating_dtypes, shape=hh.shapes(), elements=finite_kw)) def test_reciprocal(x): - out = xp.reciprocal(x) - ph.assert_dtype("reciprocal", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("reciprocal", out_shape=out.shape, expected=x.shape) - refimpl = lambda x: 1.0 / x - unary_assert_against_refimpl( - "reciprocal", - x, - out, - refimpl, - strict_check=True, - ) + repro_snippet = ph.format_snippet(f"xp.reciprocal({x!r})") + try: + out = xp.reciprocal(x) + ph.assert_dtype("reciprocal", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("reciprocal", out_shape=out.shape, expected=x.shape) + refimpl = lambda x: 1.0 / x + unary_assert_against_refimpl( + "reciprocal", + x, + out, + refimpl, + strict_check=True, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.skip(reason="flaky") @@ -1686,88 +1955,128 @@ def test_remainder(ctx, data): else: assume(not xp.any(right == 0)) - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "%", operator.mod) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "%", operator.mod) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_round(x): - out = xp.round(x) - ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("round", out_shape=out.shape, expected=x.shape) - if x.dtype in dh.complex_dtypes: - refimpl = lambda z: complex(round(z.real), round(z.imag)) - else: - refimpl = round - unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) + repro_snippet = ph.format_snippet(f"xp.round({x!r})") + try: + out = xp.round(x) + ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("round", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: complex(round(z.real), round(z.imag)) + else: + refimpl = round + unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") @given(hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes())) def test_signbit(x): - out = xp.signbit(x) - ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) - ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape) - refimpl = lambda x: math.copysign(1.0, x) < 0 - unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) + repro_snippet = ph.format_snippet(f"xp.signbit({x!r})") + try: + out = xp.signbit(x) + ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape) + refimpl = lambda x: math.copysign(1.0, x) < 0 + unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes(), elements=finite_kw)) def test_sign(x): - out = xp.sign(x) - ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("sign", out_shape=out.shape, expected=x.shape) - refimpl = lambda x: x / abs(x) if x != 0 else 0 - unary_assert_against_refimpl( - "sign", - x, - out, - refimpl, - strict_check=True, - ) + repro_snippet = ph.format_snippet(f"xp.sign({x!r})") + try: + out = xp.sign(x) + ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sign", out_shape=out.shape, expected=x.shape) + refimpl = lambda x: x / abs(x) if x != 0 else 0 + unary_assert_against_refimpl( + "sign", + x, + out, + refimpl, + strict_check=True, + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sin(x): - out = xp.sin(x) - ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("sin", out_shape=out.shape, expected=x.shape) - refimpl = cmath.sin if x.dtype in dh.complex_dtypes else math.sin - unary_assert_against_refimpl("sin", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.sin({x!r})") + try: + out = xp.sin(x) + ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sin", out_shape=out.shape, expected=x.shape) + refimpl = cmath.sin if x.dtype in dh.complex_dtypes else math.sin + unary_assert_against_refimpl("sin", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sinh(x): - out = xp.sinh(x) - ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.sinh if x.dtype in dh.complex_dtypes else math.sinh - unary_assert_against_refimpl("sinh", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.sinh({x!r})") + try: + out = xp.sinh(x) + ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.sinh if x.dtype in dh.complex_dtypes else math.sinh + unary_assert_against_refimpl("sinh", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_square(x): - out = xp.square(x) - ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("square", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl( - "square", x, out, lambda s: s*s, expr_template="{}²={}" - ) + repro_snippet = ph.format_snippet(f"xp.square({x!r})") + try: + out = xp.square(x) + ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("square", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl( + "square", x, out, lambda s: s*s, expr_template="{}²={}" + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): - out = xp.sqrt(x) - ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape) - refimpl = cmath.sqrt if x.dtype in dh.complex_dtypes else math.sqrt - filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 0 - unary_assert_against_refimpl( - "sqrt", x, out, refimpl, filter_=filter_ - ) + repro_snippet = ph.format_snippet(f"xp.sqrt({x!r})") + try: + out = xp.sqrt(x) + ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape) + refimpl = cmath.sqrt if x.dtype in dh.complex_dtypes else math.sqrt + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 0 + unary_assert_against_refimpl( + "sqrt", x, out, refimpl, filter_=filter_ + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes)) @@ -1776,50 +2085,73 @@ def test_subtract(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - with hh.reject_overflow(): - res = ctx.func(left, right) + repro_snippet = ph.format_snippet(f"{ctx.func_name}({left!r}, {right!r})") + try: + with hh.reject_overflow(): + res = ctx.func(left, right) - binary_param_assert_dtype(ctx, left, right, res) - binary_param_assert_shape(ctx, left, right, res) - binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) + binary_param_assert_dtype(ctx, left, right, res) + binary_param_assert_shape(ctx, left, right, res) + binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tan(x): - out = xp.tan(x) - ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("tan", out_shape=out.shape, expected=x.shape) - refimpl = cmath.tan if x.dtype in dh.complex_dtypes else math.tan - unary_assert_against_refimpl("tan", x, out, refimpl) + repro_snippet = ph.format_snippet(f"xp.tan({x!r})") + try: + out = xp.tan(x) + ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("tan", out_shape=out.shape, expected=x.shape) + refimpl = cmath.tan if x.dtype in dh.complex_dtypes else math.tan + unary_assert_against_refimpl("tan", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tanh(x): - out = xp.tanh(x) - ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape) - refimpl = cmath.tanh if x.dtype in dh.complex_dtypes else math.tanh - unary_assert_against_refimpl("tanh", x, out, refimpl) - + repro_snippet = ph.format_snippet(f"xp.tanh({x!r})") + try: + out = xp.tanh(x) + ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.tanh if x.dtype in dh.complex_dtypes else math.tanh + unary_assert_against_refimpl("tanh", x, out, refimpl) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @given(hh.arrays(dtype=hh.real_dtypes, shape=xps.array_shapes())) def test_trunc(x): - out = xp.trunc(x) - ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape) - unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True) - + repro_snippet = ph.format_snippet(f"xp.trunc({x!r})") + try: + out = xp.trunc(x) + ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise def _check_binary_with_scalars(func_data, x1x2): x1, x2 = x1x2 func_name, refimpl, kwds, expected_dtype = func_data func = getattr(xp, func_name) - out = func(x1, x2) - in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2) - _assert_correctness_binary( - func_name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds - ) + repro_snippet = ph.format_snippet(f"xp.{func_name}({x1!r}, {x2!r})") + try: + out = func(x1, x2) + in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2) + _assert_correctness_binary( + func_name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise def _filter_zero(x): return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0)) @@ -1940,16 +2272,19 @@ def test_where_with_scalars(x1x2, data): condition = data.draw(hh.arrays(shape=shape, dtype=xp.bool)) - out = xp.where(condition, x1, x2) - - assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}" - assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}" - - # value test - for idx in sh.ndindex(shape): - if condition[idx]: - assert out[idx] == x1_arr[idx] - else: - assert out[idx] == x2_arr[idx] + repro_snippet = ph.format_snippet(f"xp.where({condition!r}, {x1!r}, {x2!r})") + try: + out = xp.where(condition, x1, x2) + assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}" + assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}" + # value test + for idx in sh.ndindex(shape): + if condition[idx]: + assert out[idx] == x1_arr[idx] + else: + assert out[idx] == x2_arr[idx] + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise From 2835f031b37b2bc2b6b93f9046bd401716380cc8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 Jan 2026 16:02:36 +0100 Subject: [PATCH 14/30] MAINT: ignore a whitespace heavy trivial commit --- .git-blame-ignore-revs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index eaeedfd4..44aa89dc 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -11,4 +11,6 @@ a37f348ba27b6818e92fda8aee2406c653c671ea ec5a3b4e185c262b0a5f5b1631b84a09f766d80e 9058908b58ce627467ac34e768098a25f5863d31 c80e1823c2e738381ca02f27cea1e2b89dde0ac5 +# gh-402 +bdc84e8316046cb5bdc637067460057eef17d0f1 From e8d11653896ec3fc255b9a1c1cd40c0ac181a6b5 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 10 Jan 2026 21:30:50 +0100 Subject: [PATCH 15/30] ENH: expand testing of meshgrid: draw `indexing`, check shapes --- array_api_tests/test_creation_functions.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index c55b2da4..568dd7ce 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -538,8 +538,12 @@ def test_linspace(num, dtype, endpoint, data): raise -@given(dtype=hh.numeric_dtypes, data=st.data()) -def test_meshgrid(dtype, data): +@given( + dtype=hh.numeric_dtypes, + kw=hh.kwargs(indexing=st.sampled_from(["xy", "ij"])), + data=st.data() +) +def test_meshgrid(dtype, kw, data): # The number and size of generated arrays is arbitrarily limited to prevent # meshgrid() running out of memory. shapes = data.draw( @@ -557,11 +561,17 @@ def test_meshgrid(dtype, data): # sanity check # assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE - repro_snippet = ph.format_snippet(f"xp.meshgrid(*arrays) with {arrays = }") + tgt_shape = [a.shape[0] for a in arrays] + if len(tgt_shape) > 1 and kw.get('indexing', 'xy') == 'xy': + tgt_shape[0], tgt_shape[1] = tgt_shape[1], tgt_shape[0] + tgt_shape = tuple(tgt_shape) + + repro_snippet = ph.format_snippet(f"xp.meshgrid(*arrays, **kw) with {arrays = } and {kw = }") try: - out = xp.meshgrid(*arrays) + out = xp.meshgrid(*arrays, **kw) for i, x in enumerate(out): ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype") + ph.assert_shape("meshgrid", out_shape=x.shape, expected=tgt_shape) except Exception as exc: ph.add_note(exc, repro_snippet) raise From f24e54862e477445682e0008320ce664ff33d4f3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 Jan 2026 14:07:32 +0100 Subject: [PATCH 16/30] TST: test_{r,}fftfreq dtype argument --- array_api_tests/test_fft.py | 43 +++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py index 358a8eef..900a3d57 100644 --- a/array_api_tests/test_fft.py +++ b/array_api_tests/test_fft.py @@ -279,18 +279,43 @@ def test_ihfft(x, data): ph.assert_shape("ihfft", out_shape=out.shape, expected=expected_shape) -@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5))) +@given( + n=st.integers(1, 100), + kw=hh.kwargs(d=st.floats(0.1, 5), dtype=hh.real_floating_dtypes), +) def test_fftfreq(n, kw): - out = xp.fft.fftfreq(n, **kw) - ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n}) - + repro_snippet = ph.format_snippet(f"xp.fft.fftfreq({n!r}, **kw) with {kw = }") + try: + out = xp.fft.fftfreq(n, **kw) + ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n}) + + dt = kw.get("dtype", None) + if dt is None: + dt = xp.__array_namespace_info__().default_dtypes()["real floating"] + assert out.dtype == dt + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise -@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5))) +@given( + n=st.integers(1, 100), + kw=hh.kwargs(d=st.floats(0.1, 5), dtype=hh.real_floating_dtypes) +) def test_rfftfreq(n, kw): - out = xp.fft.rfftfreq(n, **kw) - ph.assert_shape( - "rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n} - ) + repro_snippet = ph.format_snippet(f"xp.fft.rfftfreq({n!r}, **kw) with {kw = }") + try: + out = xp.fft.rfftfreq(n, **kw) + ph.assert_shape( + "rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n} + ) + + dt = kw.get("dtype", None) + if dt is None: + dt = xp.__array_namespace_info__().default_dtypes()["real floating"] + assert out.dtype == dt + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"]) From 63317c68dc45089cceb2f181f3dbbc2ab012bc3b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 15 Jan 2026 20:01:34 +0000 Subject: [PATCH 17/30] ENH: parse complex special cases from stubs of unary functions "stub" docstrings include "special cases": > `if x_i is +infinity, sqrt(x_i) is +infinity` etc `test_special_cases` parses these statements from docstrings and makes them into tests cases. So far, parsing only worked for real value cases, and failed for complex-valued cases: > For complex floating-point operands, let a = real(x_i), b = imag(x_i), and > `If a is either +0 or -0 and b is +0, the result is +0 + 0j.` These stanzas simply generate "case for {func} is not machine-readable" `UserWarning`s. Quite a wall of them. Therefore, we update parsing and testing code to take these complex-valued cases into accout. For now, we only consider unary functions. The effect is: $ ARRAY_API_TESTS_MODULE=array_api_compat.torch pytest array_api_tests/test_special_cases.py::test_unary generates - "128 passed, 177 warnings in 0.78s" on master - "11 failed, 241 passed, 49 warnings in 1.82s" on this branch So that there are new failures (from new complex-valued cases) but we 128 less warnings. --- array_api_tests/test_special_cases.py | 289 +++++++++++++++++++++++++- 1 file changed, 281 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index bf05a262..0db3eb29 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -15,6 +15,7 @@ import inspect import math import operator +import os import re from dataclasses import dataclass, field from decimal import ROUND_HALF_EVEN, Decimal @@ -99,7 +100,7 @@ def or_(i: float) -> bool: def make_and(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck: def and_(i: float) -> bool: - return cond1(i) or cond2(i) + return cond1(i) and cond2(i) return and_ @@ -492,6 +493,179 @@ def check_result(result: float) -> bool: return check_result, expr +def parse_complex_value(value_str: str) -> complex: + """ + Parses a complex value string to return a complex number, e.g. + + >>> parse_complex_value('+0 + 0j') + 0j + >>> parse_complex_value('NaN + NaN j') + (nan+nanj) + >>> parse_complex_value('0 + NaN j') + nanj + >>> parse_complex_value('+0 + πj/2') + 1.5707963267948966j + >>> parse_complex_value('+infinity + 3πj/4') + (inf+2.356194490192345j) + + Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M" + """ + m = r_complex_value.match(value_str) + if m is None: + raise ParseError(value_str) + + # Parse real part with its sign + real_sign = m.group(1) if m.group(1) else "+" + real_val_str = m.group(2) + real_val = parse_value(real_sign + real_val_str) + + # Parse imaginary part with its sign + imag_sign = m.group(3) + # Group 4 is πj form (e.g., "πj/2"), group 5 is plain form (e.g., "NaN") + if m.group(4): # πj form + imag_val_str_raw = m.group(4) + # Remove 'j' to get coefficient: "πj/2" -> "π/2" + imag_val_str = imag_val_str_raw.replace('j', '') + else: # plain form + imag_val_str_raw = m.group(5) + # Strip trailing 'j' if present: "0j" -> "0" + imag_val_str = imag_val_str_raw[:-1] if imag_val_str_raw.endswith('j') else imag_val_str_raw + + imag_val = parse_value(imag_sign + imag_val_str) + + return complex(real_val, imag_val) + + +def make_strict_eq_complex(v: complex) -> Callable[[complex], bool]: + """ + Creates a checker for complex values that respects sign of zero and NaN. + """ + real_check = make_strict_eq(v.real) + imag_check = make_strict_eq(v.imag) + + def strict_eq_complex(z: complex) -> bool: + return real_check(z.real) and imag_check(z.imag) + + return strict_eq_complex + + +def parse_complex_cond( + a_cond_str: str, b_cond_str: str +) -> Tuple[Callable[[complex], bool], str, FromDtypeFunc]: + """ + Parses complex condition strings for real (a) and imaginary (b) parts. + + Returns: + - cond: Function that checks if a complex number meets the condition + - expr: String expression for the condition + - from_dtype: Strategy generator for complex numbers meeting the condition + """ + # Parse conditions for real and imaginary parts separately + a_cond, a_expr_template, a_from_dtype = parse_cond(a_cond_str) + b_cond, b_expr_template, b_from_dtype = parse_cond(b_cond_str) + + # Create compound condition + def complex_cond(z: complex) -> bool: + return a_cond(z.real) and b_cond(z.imag) + + # Create expression + a_expr = a_expr_template.replace("{}", "real(x_i)") + b_expr = b_expr_template.replace("{}", "imag(x_i)") + expr = f"{a_expr} and {b_expr}" + + # Create strategy that generates complex numbers + def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]: + assert len(kw) == 0 # sanity check + # For complex dtype, we need to get the corresponding float dtype + # complex64 -> float32, complex128 -> float64 + if hasattr(dtype, 'name'): + if 'complex64' in str(dtype): + float_dtype = xp.float32 + elif 'complex128' in str(dtype): + float_dtype = xp.float64 + else: + # Fallback to float64 + float_dtype = xp.float64 + else: + float_dtype = xp.float64 + + real_strat = a_from_dtype(float_dtype) + imag_strat = b_from_dtype(float_dtype) + return st.builds(complex, real_strat, imag_strat) + + return complex_cond, expr, complex_from_dtype + + +def _check_component_with_tolerance(actual: float, expected: float, allow_any_sign: bool) -> bool: + """ + Helper to check if actual matches expected, with optional sign flexibility and tolerance. + """ + if allow_any_sign and not math.isnan(expected): + return abs(actual) == abs(expected) or math.isclose(abs(actual), abs(expected), abs_tol=0.01) + elif not math.isnan(expected): + check_fn = make_strict_eq(expected) if expected == 0 or math.isinf(expected) else make_rough_eq(expected) + return check_fn(actual) + else: + return math.isnan(actual) + + +def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], str]: + """ + Parses a complex result string to return a checker and expression. + + Handles cases like: + - "``+0 + 0j``" - exact complex value + - "``0 + NaN j`` (sign of the real component is unspecified)" + - "``+0 + πj/2``" - with π expressions (uses approximate equality) + """ + # Check for unspecified sign notes + unspecified_real_sign = "sign of the real component is unspecified" in result_str + unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str + + # Extract the complex value from backticks - need to handle spaces in complex values + # Pattern: ``...`` where ... can contain spaces (for complex values like "0 + NaN j") + m = re.search(r"``([^`]+)``", result_str) + if m: + value_str = m.group(1) + # Check if the value contains π expressions (for approximate comparison) + has_pi = 'π' in value_str + + try: + expected = parse_complex_value(value_str) + except ParseError: + raise ParseError(result_str) + + # Create checker based on whether signs are unspecified and whether π is involved + if has_pi: + # Use approximate equality for both real and imaginary parts if they involve π + def check_result(z: complex) -> bool: + real_match = _check_component_with_tolerance(z.real, expected.real, unspecified_real_sign) + imag_match = _check_component_with_tolerance(z.imag, expected.imag, unspecified_imag_sign) + return real_match and imag_match + elif unspecified_real_sign and not math.isnan(expected.real): + # Allow any sign for real part + def check_result(z: complex) -> bool: + imag_check = make_strict_eq(expected.imag) + return abs(z.real) == abs(expected.real) and imag_check(z.imag) + elif unspecified_imag_sign and not math.isnan(expected.imag): + # Allow any sign for imaginary part + def check_result(z: complex) -> bool: + real_check = make_strict_eq(expected.real) + return real_check(z.real) and abs(z.imag) == abs(expected.imag) + elif unspecified_real_sign and unspecified_imag_sign: + # Allow any sign for both parts + def check_result(z: complex) -> bool: + return abs(z.real) == abs(expected.real) and abs(z.imag) == abs(expected.imag) + else: + # Exact match including signs + check_result = make_strict_eq_complex(expected) + + expr = value_str + return check_result, expr + else: + raise ParseError(result_str) + + class Case(Protocol): cond_expr: str result_expr: str @@ -535,6 +709,7 @@ class UnaryCase(Case): cond: UnaryCheck check_result: UnaryResultCheck raw_case: Optional[str] = field(default=None) + is_complex: bool = field(default=False) r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") @@ -549,6 +724,16 @@ class UnaryCase(Case): "If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, " "the result is ``(.+)``" ) +# Regex patterns for complex special cases +r_complex_marker = re.compile( + r"For complex floating-point operands, let ``a = real\(x_i\)``, ``b = imag\(x_i\)``" +) +r_complex_case = re.compile(r"If ``a`` is (.+) and ``b`` is (.+), the result is (.+)") +# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4" +# Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j +r_complex_value = re.compile( + r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?" +) def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: @@ -630,7 +815,15 @@ def check_result(i: float, result: float) -> bool: return check_result -def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]: +def make_complex_unary_check_result(check_fn: Callable[[complex], bool]) -> UnaryResultCheck: + """Wraps a complex check function for use in UnaryCase.""" + def check_result(in_value, out_value): + # in_value is complex, out_value is complex + return check_fn(out_value) + return check_result + + +def parse_unary_case_block(case_block: str, func_name: str, record_list: Optional[List[str]] = None) -> List[UnaryCase]: """ Parses a Sphinx-formatted docstring of a unary function to return a list of codified unary cases, e.g. @@ -677,8 +870,52 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]: """ cases = [] + # Check if the case block contains complex cases by looking for the marker + in_complex_section = r_complex_marker.search(case_block) is not None + for case_m in r_case.finditer(case_block): case_str = case_m.group(1) + + # Record this special case if a record list is provided + if record_list is not None: + record_list.append(f"{func_name}: {case_str}.") + + + # Try to parse complex cases if we're in the complex section + if in_complex_section and (m := r_complex_case.search(case_str)): + try: + a_cond_str = m.group(1) + b_cond_str = m.group(2) + result_str = m.group(3) + + # Skip cases with complex expressions like "cis(b)" + if "cis" in result_str or "*" in result_str: + warn(f"case for {func_name} not machine-readable: '{case_str}'") + continue + + # Parse the complex condition and result + complex_cond, cond_expr, complex_from_dtype = parse_complex_cond( + a_cond_str, b_cond_str + ) + _check_result, result_expr = parse_complex_result(result_str) + + check_result = make_complex_unary_check_result(_check_result) + + case = UnaryCase( + cond_expr=cond_expr, + cond=complex_cond, + cond_from_dtype=complex_from_dtype, + result_expr=result_expr, + check_result=check_result, + raw_case=case_str, + is_complex=True, + ) + cases.append(case) + except ParseError as e: + warn(f"case for {func_name} not machine-readable: '{e.value}'") + continue + + # Parse regular (real-valued) cases if r_already_int_case.search(case_str): cases.append(already_int_case) elif r_even_round_halves_case.search(case_str): @@ -1103,7 +1340,7 @@ def cond(i1: float, i2: float) -> bool: r_redundant_case = re.compile("result.+determined by the rule already stated above") -def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]: +def parse_binary_case_block(case_block: str, func_name: str, record_list: Optional[List[str]] = None) -> List[BinaryCase]: """ Parses a Sphinx-formatted docstring of a binary function to return a list of codified binary cases, e.g. @@ -1145,6 +1382,11 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] cases = [] for case_m in r_case.finditer(case_block): case_str = case_m.group(1) + + # Record this special case if a record list is provided + if record_list is not None: + record_list.append(f"{func_name}: {case_str}.") + if r_redundant_case.search(case_str): continue if r_binary_case.match(case_str): @@ -1162,6 +1404,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] unary_params = [] binary_params = [] iop_params = [] +special_case_records = [] # List of "func_name: case_str" for all special cases func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()} for stub in category_to_funcs["elementwise"]: func_name = stub.__name__ @@ -1186,7 +1429,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] warn(f"{func=} has no parameters") continue if param_names[0] == "x": - if cases := parse_unary_case_block(case_block, func_name): + if cases := parse_unary_case_block(case_block, func_name, special_case_records): name_to_func = {func_name: func} if func_name in func_to_op.keys(): op_name = func_to_op[func_name] @@ -1204,7 +1447,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") continue if param_names[0] == "x1" and param_names[1] == "x2": - if cases := parse_binary_case_block(case_block, func_name): + if cases := parse_binary_case_block(case_block, func_name, special_case_records): name_to_func = {func_name: func} if func_name in func_to_op.keys(): op_name = func_to_op[func_name] @@ -1249,6 +1492,22 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] assert len(iop_params) != 0 +@pytest.fixture(scope="session", autouse=True) +def emit_special_case_records(): + """Emit all special case records at the start of test session.""" + # This runs once at the beginning of the test session + if os.environ.get('ARRAY_API_TESTS_SPECIAL_CASES_VERBOSE') == '1': + print("\n" + "="*80) + print("SPECIAL CASE RECORDS") + print("="*80) + for record in special_case_records: + print(record) + print("="*80) + print(f"Total special cases: {len(special_case_records)}") + print("="*80 + "\n") + yield # Tests run after this point + + @pytest.mark.parametrize("func_name, func, case", unary_params) def test_unary(func_name, func, case): with catch_warnings(): @@ -1257,10 +1516,24 @@ def test_unary(func_name, func, case): # drawing multiple examples like a normal test, or just hard-coding a # single example test case without using hypothesis. filterwarnings('ignore', category=NonInteractiveExampleWarning) - in_value = case.cond_from_dtype(xp.float64).example() - x = xp.asarray(in_value, dtype=xp.float64) + + # Use the is_complex flag to determine the appropriate dtype + if case.is_complex: + dtype = xp.complex128 + in_value = case.cond_from_dtype(dtype).example() + else: + dtype = xp.float64 + in_value = case.cond_from_dtype(dtype).example() + + # Create array and compute result based on dtype + x = xp.asarray(in_value, dtype=dtype) out = func(x) - out_value = float(out) + + if case.is_complex: + out_value = complex(out) + else: + out_value = float(out) + assert case.check_result(in_value, out_value), ( f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n" ) From eae7d2ab485fb5e83364e21eb44567ee0e5b8c29 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 16 Jan 2026 20:22:19 +0100 Subject: [PATCH 18/30] TST: update array-api-strict-skips.txt for the complex cases --- array-api-strict-skips.txt | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/array-api-strict-skips.txt b/array-api-strict-skips.txt index afc1b845..10a55fdd 100644 --- a/array-api-strict-skips.txt +++ b/array-api-strict-skips.txt @@ -32,3 +32,18 @@ array_api_tests/test_data_type_functions.py::test_finfo array_api_tests/test_data_type_functions.py::test_finfo_dtype array_api_tests/test_data_type_functions.py::test_iinfo array_api_tests/test_data_type_functions.py::test_iinfo_dtype + + +# complex special cases which failed "forever" +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] + +array_api_tests/test_special_cases.py::test_unary[sign((real(x_i) is -0 or real(x_i) == +0) and (imag(x_i) is -0 or imag(x_i) == +0)) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] From add6964cbdccae68e4ae2b2452811bfb4bbcf33b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 17 Jan 2026 11:12:00 +0100 Subject: [PATCH 19/30] MAINT: cleaner float<->complex mappings --- array_api_tests/dtype_helpers.py | 33 +++++++++++++++++++++++++++ array_api_tests/test_special_cases.py | 13 ++--------- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index f7fa306b..2fe3c1b2 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -199,6 +199,39 @@ def is_scalar(x): return isinstance(x, (int, float, complex, bool)) +def complex_dtype_for(dtyp): + """Complex dtype for a float or complex.""" + if dtyp in complex_dtypes: + return dtyp + if dtyp not in real_float_dtypes: + raise ValueError(f"no complex dtype to match {dtyp}") + + real_name = dtype_to_name[dtyp] + complex_name = {"float32": "complex64", "float64": "complex128"}[real_name] + + complex_dtype = _name_to_dtype.get(complex_name, None) + if complex_dtype is None: + raise ValueError(f"no complex dtype to match {dtyp}") + return complex_dtype + + +def real_dtype_for(dtyp): + """Real float dtype for a float or complex.""" + if dtyp in real_float_dtypes: + return dtyp + if dtyp not in complex_dtypes: + raise ValueError(f"no real float dtype to match {dtyp}") + + complex_name = dtype_to_name[dtyp] + real_name = {"complex64": "float32", "complex128": "float64"}[complex_name] + + real_dtype = _name_to_dtype.get(real_name, None) + if real_dtype is None: + raise ValueError(f"no real dtype to match {dtyp}") + return real_dtype + + + def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping: dtype_value_pairs = [] for name, value in mapping.items(): diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 0db3eb29..e05fd02d 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -578,17 +578,8 @@ def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]: assert len(kw) == 0 # sanity check # For complex dtype, we need to get the corresponding float dtype # complex64 -> float32, complex128 -> float64 - if hasattr(dtype, 'name'): - if 'complex64' in str(dtype): - float_dtype = xp.float32 - elif 'complex128' in str(dtype): - float_dtype = xp.float64 - else: - # Fallback to float64 - float_dtype = xp.float64 - else: - float_dtype = xp.float64 - + float_dtype = dh.real_dtype_for(dtype) + real_strat = a_from_dtype(float_dtype) imag_strat = b_from_dtype(float_dtype) return st.builds(complex, real_strat, imag_strat) From 29733fb962a9a9e11238fc8e2210365873a1f44d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 19 Jan 2026 12:05:41 +0000 Subject: [PATCH 20/30] =?UTF-8?q?ENH:=20Implement=20=C2=B1=20symbol=20hand?= =?UTF-8?q?ling=20in=20complex=20value=20parsing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update r_complex_value regex to match ± in sign positions - Normalize ± to + in parse_complex_value() function - Detect ± symbols in parse_complex_result() to set unspecified sign flags - Results: 246 passed (+4), 45 warnings (-4) Handle spaces after ± symbol in complex value parsing Update regex to allow optional whitespace between sign and value: ([±+-]?)\s* This now handles cases like "± 0 + 0j" in addition to "±0 + 0j" --- array_api_tests/test_special_cases.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index e05fd02d..a919fe08 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -515,12 +515,18 @@ def parse_complex_value(value_str: str) -> complex: raise ParseError(value_str) # Parse real part with its sign + # Normalize ± to + (we choose positive arbitrarily since sign is unspecified) real_sign = m.group(1) if m.group(1) else "+" + if '±' in real_sign: + real_sign = '+' real_val_str = m.group(2) real_val = parse_value(real_sign + real_val_str) # Parse imaginary part with its sign + # Normalize ± to + for imaginary part as well imag_sign = m.group(3) + if '±' in imag_sign: + imag_sign = '+' # Group 4 is πj form (e.g., "πj/2"), group 5 is plain form (e.g., "NaN") if m.group(4): # πj form imag_val_str_raw = m.group(4) @@ -609,7 +615,7 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st - "``0 + NaN j`` (sign of the real component is unspecified)" - "``+0 + πj/2``" - with π expressions (uses approximate equality) """ - # Check for unspecified sign notes + # Check for unspecified sign notes (text-based detection) unspecified_real_sign = "sign of the real component is unspecified" in result_str unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str @@ -618,6 +624,20 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st m = re.search(r"``([^`]+)``", result_str) if m: value_str = m.group(1) + + # Check for ± symbols in the value string (symbol-based detection) + # This works in addition to the text-based detection above + if '±' in value_str: + # Parse the value to determine which component has ± + m_val = r_complex_value.match(value_str) + if m_val: + # Check if real part has ± + if m_val.group(1) and '±' in m_val.group(1): + unspecified_real_sign = True + # Check if imaginary part has ± + if m_val.group(3) and '±' in m_val.group(3): + unspecified_imag_sign = True + # Check if the value contains π expressions (for approximate comparison) has_pi = 'π' in value_str @@ -722,8 +742,9 @@ class UnaryCase(Case): r_complex_case = re.compile(r"If ``a`` is (.+) and ``b`` is (.+), the result is (.+)") # Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4" # Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j +# Also handles ± symbol for unspecified signs (with or without spaces after the sign) r_complex_value = re.compile( - r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?" + r"([±+-]?)\s*([^\s]+)\s*([±+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?" ) From 45e566d781acb9fead36d7bc639d522fb958002e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 19 Jan 2026 13:34:57 +0100 Subject: [PATCH 21/30] TST: update array-api-strict-skips.txt with an acosh case --- array-api-strict-skips.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/array-api-strict-skips.txt b/array-api-strict-skips.txt index 10a55fdd..c2a2f902 100644 --- a/array-api-strict-skips.txt +++ b/array-api-strict-skips.txt @@ -46,4 +46,7 @@ array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real array_api_tests/test_special_cases.py::test_unary[sign((real(x_i) is -0 or real(x_i) == +0) and (imag(x_i) is -0 or imag(x_i) == +0)) -> 0 + 0j] array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] +# this acosh failure is only seen with python==3.10 and numpy==2.2.6, and not e.g. python 3.12 & numpy 2.4.1 +array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] + array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] From 48879e517cf5f2ea14b483d1fd2811050587ed8b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 17 Jan 2026 20:11:21 +0100 Subject: [PATCH 22/30] MAINT: clean up scalars/array_and_py_scalar strategies Get rid of ad hoc mM and positive arguments. Use the same `min_value`, `max_value` kwargs for scalars and arrays of varying types/dtypes. --- array_api_tests/hypothesis_helpers.py | 32 +++++++++++-------- ...est_operators_and_elementwise_functions.py | 2 +- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index e1df108c..0da2da8d 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -457,13 +457,12 @@ def scalars(draw, dtypes, finite=False, **kwds): dtypes should be one of the shared_* dtypes strategies. """ dtype = draw(dtypes) - mM = kwds.pop('mM', None) if dh.is_int_dtype(dtype): - if mM is None: - m, M = dh.dtype_ranges[dtype] - else: - m, M = mM - return draw(integers(m, M)) + m, M = dh.dtype_ranges[dtype] + min_value = kwds.get('min_value', m) + max_value = kwds.get('max_value', M) + + return draw(integers(min_value, max_value)) elif dtype == bool_dtype: return draw(booleans()) elif dtype == float64: @@ -593,20 +592,25 @@ def two_mutual_arrays( @composite -def array_and_py_scalar(draw, dtypes, mM=None, positive=False): +def array_and_py_scalar(draw, dtypes, **kwds): """Draw a pair: (array, scalar) or (scalar, array).""" dtype = draw(sampled_from(dtypes)) - scalar_var = draw(scalars(just(dtype), finite=True, mM=mM)) - if positive: - assume (scalar_var > 0) + scalar_var = draw(scalars(just(dtype), finite=True, **kwds)) elements={} if dtype in dh.real_float_dtypes: - elements = {'allow_nan': False, 'allow_infinity': False, - 'min_value': 1.0 / (2<<5), 'max_value': 2<<5} - if positive: - elements = {'min_value': 0} + elements = { + 'allow_nan': False, + 'allow_infinity': False, + 'min_value': kwds.get('min_value', 1.0 / (2<<5)), + 'max_value': kwds.get('max_value', 2<<5) + } + elif dtype in dh.int_dtypes: + elements = { + 'min_value': kwds.get('min_value', None), + 'max_value': kwds.get('max_value', None) + } array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements)) if draw(booleans()): diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index f074bd8e..c19fe3cd 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -2246,7 +2246,7 @@ def test_binary_with_scalars_bitwise(func_data, x1x2): ], ids=lambda func_data: func_data[0] # use names for test IDs ) -@given(x1x2=hh.array_and_py_scalar([xp.int32], positive=True, mM=(1, 3))) +@given(x1x2=hh.array_and_py_scalar([xp.int32], min_value=1, max_value=3)) def test_binary_with_scalars_bitwise_shifts(func_data, x1x2): func_name, refimpl, kwargs, expected = func_data # repack the refimpl From 05e89f345cb06a6937819fc9421ee5bf318182da Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 17 Jan 2026 20:46:18 +0100 Subject: [PATCH 23/30] ENH: test func(float_array, int_scalar) Previously, we only tested the matching type/dtype combinations, i.e. - float_array, float_scalar - int_array, int_scalar - bool_array, bool_scalar --- array_api_tests/hypothesis_helpers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 0da2da8d..25d68324 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -596,8 +596,15 @@ def array_and_py_scalar(draw, dtypes, **kwds): """Draw a pair: (array, scalar) or (scalar, array).""" dtype = draw(sampled_from(dtypes)) - scalar_var = draw(scalars(just(dtype), finite=True, **kwds)) + # draw the scalar: for float arrays, draw a float or an int + if dtype in dh.real_float_dtypes: + scalar_strategy = sampled_from([xp.int32, dtype]) + else: + scalar_strategy = just(dtype) + scalar_var = draw(scalars(scalar_strategy, finite=True, **kwds)) + # draw the array. + # XXX artificially limit the range of values for floats, otherwise value testing is flaky elements={} if dtype in dh.real_float_dtypes: elements = { From 54e58620b5d5ec7689bace11196a5668a20f5912 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Nov 2025 14:05:46 +0100 Subject: [PATCH 24/30] ENH: test expand_dims with tuple axis cf data-apis#760 for discussion We test here that expand_dims with multiple axes is equivalent to expanding axes one by one---the key is that the axes to add need to be pre-sorted. --- .../test_manipulation_functions.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 3af7b959..6954e5f9 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -156,6 +156,42 @@ def test_expand_dims(x, axis): raise +@given( + x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(max_dims=4)), + axes=shared_shapes().flatmap( + lambda s: st.lists( + st.integers(2*(-len(s)-1), 2*len(s)), + min_size=0 if len(s)==0 else 1, + max_size=len(s) + ).map(tuple) + ) +) +def test_expand_dims_tuples(x, axes): + # normalize the axes + y_ndim = x.ndim + len(axes) + n_axes = tuple(ax + y_ndim if ax < 0 else ax for ax in axes) + unique_axes = set(n_axes) + + if any(ax < 0 or ax >= y_ndim for ax in n_axes) or len(n_axes) != len(unique_axes): + with pytest.raises((IndexError, ValueError)): + xp.expand_dims(x, axis=axes) + return + + repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axes!r})") + try: + y = xp.expand_dims(x, axis=axes) + + ye = x + for ax in sorted(n_axes): + ye = xp.expand_dims(ye, axis=ax) + assert y.shape == ye.shape + # TODO value tests; check that y.shape is 1s and items from x.shape, in order + + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise + + @pytest.mark.min_version("2023.12") @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), data=st.data()) def test_moveaxis(x, data): From d597b5c7d4ee0d752c399d683372c5b293da5923 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 7 Feb 2026 15:27:12 +0100 Subject: [PATCH 25/30] MAINT: move expand_dims tests to a class --- .../test_manipulation_functions.py | 127 +++++++++--------- 1 file changed, 64 insertions(+), 63 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 6954e5f9..a5e969a5 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -122,74 +122,75 @@ def test_concat(dtypes, base_shape, data): raise -@pytest.mark.unvectorized -@given( - x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()), - axis=shared_shapes().flatmap( - # Generate both valid and invalid axis - lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) - ), -) -def test_expand_dims(x, axis): - if axis < -x.ndim - 1 or axis > x.ndim: - with pytest.raises(IndexError): - xp.expand_dims(x, axis=axis) - return - - repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axis!r})") - try: - out = xp.expand_dims(x, axis=axis) +class TestExpandDims: + @pytest.mark.unvectorized + @given( + x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()), + axis=shared_shapes().flatmap( + # Generate both valid and invalid axis + lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) + ), + ) + def test_expand_dims(self, x, axis): + if axis < -x.ndim - 1 or axis > x.ndim: + with pytest.raises(IndexError): + xp.expand_dims(x, axis=axis) + return - ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype) + repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axis!r})") + try: + out = xp.expand_dims(x, axis=axis) - shape = [side for side in x.shape] - index = axis if axis >= 0 else x.ndim + axis + 1 - shape.insert(index, 1) - shape = tuple(shape) - ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape) + ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype) - assert_array_ndindex( - "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape) + shape = [side for side in x.shape] + index = axis if axis >= 0 else x.ndim + axis + 1 + shape.insert(index, 1) + shape = tuple(shape) + ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape) + + assert_array_ndindex( + "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape) + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise + + + @given( + x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(max_dims=4)), + axes=shared_shapes().flatmap( + lambda s: st.lists( + st.integers(2*(-len(s)-1), 2*len(s)), + min_size=0 if len(s)==0 else 1, + max_size=len(s) + ).map(tuple) ) - except Exception as exc: - ph.add_note(exc, repro_snippet) - raise - - -@given( - x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(max_dims=4)), - axes=shared_shapes().flatmap( - lambda s: st.lists( - st.integers(2*(-len(s)-1), 2*len(s)), - min_size=0 if len(s)==0 else 1, - max_size=len(s) - ).map(tuple) ) -) -def test_expand_dims_tuples(x, axes): - # normalize the axes - y_ndim = x.ndim + len(axes) - n_axes = tuple(ax + y_ndim if ax < 0 else ax for ax in axes) - unique_axes = set(n_axes) - - if any(ax < 0 or ax >= y_ndim for ax in n_axes) or len(n_axes) != len(unique_axes): - with pytest.raises((IndexError, ValueError)): - xp.expand_dims(x, axis=axes) - return - - repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axes!r})") - try: - y = xp.expand_dims(x, axis=axes) - - ye = x - for ax in sorted(n_axes): - ye = xp.expand_dims(ye, axis=ax) - assert y.shape == ye.shape - # TODO value tests; check that y.shape is 1s and items from x.shape, in order - - except Exception as exc: - ph.add_note(exc, repro_snippet) - raise + def test_expand_dims_tuples(self, x, axes): + # normalize the axes + y_ndim = x.ndim + len(axes) + n_axes = tuple(ax + y_ndim if ax < 0 else ax for ax in axes) + unique_axes = set(n_axes) + + if any(ax < 0 or ax >= y_ndim for ax in n_axes) or len(n_axes) != len(unique_axes): + with pytest.raises((IndexError, ValueError)): + xp.expand_dims(x, axis=axes) + return + + repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axes!r})") + try: + y = xp.expand_dims(x, axis=axes) + + ye = x + for ax in sorted(n_axes): + ye = xp.expand_dims(ye, axis=ax) + assert y.shape == ye.shape + # TODO value tests; check that y.shape is 1s and items from x.shape, in order + + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12") From 779e1c3fc34d0578747d973fcc4c37790ee1061d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 7 Feb 2026 14:34:28 +0000 Subject: [PATCH 26/30] Initial plan From 7a9935dcdf069404afeecefa5f3236f3f175e55d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 7 Feb 2026 14:39:04 +0000 Subject: [PATCH 27/30] Fix unvectorized marker for class methods Co-authored-by: ev-br <2133832+ev-br@users.noreply.github.com> --- conftest.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/conftest.py b/conftest.py index dc50c9ae..f77aa81e 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,7 @@ from functools import lru_cache from pathlib import Path import argparse +import inspect import warnings import os @@ -253,9 +254,20 @@ def pytest_collection_modifyitems(config, items): # reduce max generated Hypothesis example for unvectorized tests if any(m.name == "unvectorized" for m in markers): # TODO: limit generated examples when settings already applied - if not hasattr(item.obj, "_hypothesis_internal_settings_applied"): + # For class methods, we need to access the underlying function + test_func = item.obj + if inspect.ismethod(test_func): + test_func = test_func.__func__ + + if not hasattr(test_func, "_hypothesis_internal_settings_applied"): try: - item.obj = settings(max_examples=unvectorized_max_examples)(item.obj) + decorated_func = settings(max_examples=unvectorized_max_examples)(test_func) + # For class methods, replace the function in the class + if inspect.ismethod(item.obj): + # Get the class and method name + setattr(item.obj.__self__.__class__, item.obj.__name__, decorated_func) + else: + item.obj = decorated_func except InvalidArgument as e: warnings.warn( f"Tried decorating {item.name} with settings() but got " From 2c4cbbab0a62342a74dc7dcb9b205abc53e7496a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 7 Feb 2026 14:39:04 +0000 Subject: [PATCH 28/30] Fix unvectorized marker for class methods Co-authored-by: ev-br <2133832+ev-br@users.noreply.github.com> --- conftest.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/conftest.py b/conftest.py index dc50c9ae..f77aa81e 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,7 @@ from functools import lru_cache from pathlib import Path import argparse +import inspect import warnings import os @@ -253,9 +254,20 @@ def pytest_collection_modifyitems(config, items): # reduce max generated Hypothesis example for unvectorized tests if any(m.name == "unvectorized" for m in markers): # TODO: limit generated examples when settings already applied - if not hasattr(item.obj, "_hypothesis_internal_settings_applied"): + # For class methods, we need to access the underlying function + test_func = item.obj + if inspect.ismethod(test_func): + test_func = test_func.__func__ + + if not hasattr(test_func, "_hypothesis_internal_settings_applied"): try: - item.obj = settings(max_examples=unvectorized_max_examples)(item.obj) + decorated_func = settings(max_examples=unvectorized_max_examples)(test_func) + # For class methods, replace the function in the class + if inspect.ismethod(item.obj): + # Get the class and method name + setattr(item.obj.__self__.__class__, item.obj.__name__, decorated_func) + else: + item.obj = decorated_func except InvalidArgument as e: warnings.warn( f"Tried decorating {item.name} with settings() but got " From 417384059f42d2f66a40fc156a6f9cdb1e633537 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 7 Feb 2026 14:52:41 +0000 Subject: [PATCH 29/30] Use pytest-provided item.cls instead of manual class access Co-authored-by: ev-br <2133832+ev-br@users.noreply.github.com> --- conftest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/conftest.py b/conftest.py index f77aa81e..0cbad101 100644 --- a/conftest.py +++ b/conftest.py @@ -262,10 +262,10 @@ def pytest_collection_modifyitems(config, items): if not hasattr(test_func, "_hypothesis_internal_settings_applied"): try: decorated_func = settings(max_examples=unvectorized_max_examples)(test_func) - # For class methods, replace the function in the class + # For class methods, use pytest's item.cls to access the class if inspect.ismethod(item.obj): - # Get the class and method name - setattr(item.obj.__self__.__class__, item.obj.__name__, decorated_func) + # Use pytest-provided item.cls instead of manually accessing the class + setattr(item.cls, item.obj.__name__, decorated_func) else: item.obj = decorated_func except InvalidArgument as e: From 4cc1072ae54c381660ef45b7ba8c15e207d67407 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 7 Feb 2026 16:02:26 +0100 Subject: [PATCH 30/30] . --- conftest.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/conftest.py b/conftest.py index 0cbad101..7b3499af 100644 --- a/conftest.py +++ b/conftest.py @@ -258,13 +258,11 @@ def pytest_collection_modifyitems(config, items): test_func = item.obj if inspect.ismethod(test_func): test_func = test_func.__func__ - + if not hasattr(test_func, "_hypothesis_internal_settings_applied"): try: decorated_func = settings(max_examples=unvectorized_max_examples)(test_func) - # For class methods, use pytest's item.cls to access the class if inspect.ismethod(item.obj): - # Use pytest-provided item.cls instead of manually accessing the class setattr(item.cls, item.obj.__name__, decorated_func) else: item.obj = decorated_func