Skip to content

Commit de6aca8

Browse files
committed
Infer n_splits in split helper
1 parent 3a7f214 commit de6aca8

File tree

5 files changed

+13
-9
lines changed

5 files changed

+13
-9
lines changed

pytensor/tensor/basic.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,9 +2176,13 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable:
21762176
return swapaxes(x, -1, -2)
21772177

21782178

2179-
def split(x, splits_size, n_splits, axis=0):
2180-
the_split = Split(n_splits)
2181-
return the_split(x, axis, splits_size)
2179+
def split(x, splits_size, *, n_splits=None, axis=0):
2180+
if n_splits is None:
2181+
if isinstance(splits_size, Variable):
2182+
n_splits = get_vector_length(splits_size)
2183+
else:
2184+
n_splits = len(splits_size)
2185+
return Split(n_splits)(x, axis, splits_size)
21822186

21832187

21842188
class Split(COp):

pytensor/tensor/fourier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ def infer_shape(self, fgraph, node, in_shapes):
116116
l = len(shape_a)
117117
shape_a = stack(shape_a)
118118
out_shape = concatenate((shape_a[0:axis], [n], shape_a[axis + 1 :]))
119-
n_splits = [1] * l
120-
out_shape = split(out_shape, n_splits, l)
119+
splits = [1] * l
120+
out_shape = split(out_shape, splits, n_splits=l)
121121
out_shape = [a[0] for a in out_shape]
122122
return [out_shape]
123123

tests/link/mlx/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,15 @@ def test_empty_dynamic_shape():
148148
def test_split_const_axis_const_splits_compiled():
149149
x = pt.vector("x")
150150
splits = [2, 3]
151-
outs = pt.split(x, splits, len(splits), axis=0)
151+
outs = pt.split(x, splits, n_splits=len(splits), axis=0)
152152
compare_mlx_and_py([x], outs, [np.arange(5, dtype="float32")])
153153

154154

155155
def test_split_dynamic_axis_const_splits():
156156
x = pt.matrix("x")
157157
axis = pt.scalar("axis", dtype="int64")
158158
splits = [1, 2, 3]
159-
outs = pt.split(x, splits, len(splits), axis=axis)
159+
outs = pt.split(x, splits, n_splits=len(splits), axis=axis)
160160

161161
test_input = np.arange(12).astype(config.floatX).reshape(2, 6)
162162

tests/link/numba/test_tensor_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def test_Join(vals, axis):
209209
def test_Split(n_splits, axis, values, sizes):
210210
values, values_test = values
211211
sizes, sizes_test = sizes
212-
g = pt.split(values, sizes, n_splits, axis=axis)
212+
g = pt.split(values, sizes, n_splits=n_splits, axis=axis)
213213
assert len(g) == n_splits
214214
if n_splits == 0:
215215
return

tests/link/pytorch/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def test_ScalarLoop_Elemwise_multi_carries():
518518
def test_Split(n_splits, axis, values, sizes):
519519
i = pt.tensor("i", shape=values.shape, dtype=config.floatX)
520520
s = pt.vector("s", dtype="int64")
521-
g = pt.split(i, s, n_splits, axis=axis)
521+
g = pt.split(i, s, n_splits=n_splits, axis=axis)
522522
assert len(g) == n_splits
523523
if n_splits == 0:
524524
return

0 commit comments

Comments
 (0)