Skip to content

Commit fb9ee9e

Browse files
authored
BUG: create_diagonal: remove delegation for ndim >= 2 (#544)
1 parent 1132a27 commit fb9ee9e

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

src/array_api_extra/_delegation.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,14 @@ def create_diagonal(
226226
if is_torch_namespace(xp):
227227
return xp.diag_embed(x, offset=offset, dim1=-2, dim2=-1)
228228

229-
if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2:
229+
if (
230+
is_dask_namespace(xp)
231+
or is_cupy_namespace(xp)
232+
or is_numpy_namespace(xp)
233+
or is_jax_namespace(xp)
234+
) and (x.ndim < 2):
230235
return xp.diag(x, k=offset)
231236

232-
if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3:
233-
batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset)
234-
return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n))
235-
236237
return _funcs.create_diagonal(x, offset=offset, xp=xp)
237238

238239

tests/test_funcs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@
5555
lazy_xp_function(setdiff1d, jax_jit=False)
5656
lazy_xp_function(sinc)
5757

58-
NestedFloatList = list[float] | list["NestedFloatList"]
59-
6058

6159
class TestApplyWhere:
6260
@staticmethod
@@ -711,6 +709,7 @@ def test_0d_raises(self, xp: ModuleType):
711709
(0, 1),
712710
(1, 0),
713711
(0, 0),
712+
(2, 3),
714713
(4, 2, 1),
715714
(1, 1, 7),
716715
(0, 0, 1),

0 commit comments

Comments
 (0)