From 34cc8eb0c752cdd71d4be82346c751b4ff7e73ac Mon Sep 17 00:00:00 2001
From: Jaan Erik Pihel
Date: Thu, 27 Nov 2025 11:12:32 +0200
Subject: [PATCH] Implement AdvancedIncSubtensorExplicit

---
 pytensor/tensor/subtensor.py   | 254 ++++++++++++++++++++++++++++++++-
 tests/tensor/test_subtensor.py |  72 +++++++++-
 2 files changed, 317 insertions(+), 9 deletions(-)

diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index d7fc1bedbc..2c1e869ff7 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -2738,10 +2738,42 @@ def is_bool_index(idx):
         assert node.outputs[0].ndim == len(res_shape)
         return [res_shape]
 
+    def _broadcast_indices(self, x, indices):
+        new_indices = []
+        x_dim = 0
+        for idx in indices:
+            if idx is None:
+                new_indices.append(idx)
+                continue
+            if isinstance(idx, slice):
+                x_dim += 1
+                new_indices.append(idx)
+                continue
+
+            # Boolean masks consume as many dimensions of x as they have
+            if hasattr(idx, "dtype") and idx.dtype == bool:
+                x_dim += idx.ndim
+                new_indices.append(idx)
+                continue
+
+            # Integer array index
+            if x_dim < x.ndim and x.shape[x_dim] == 1:
+                # The indexed dimension is broadcastable (length 1), so the
+                # only valid index is 0: replace the index with zeros
+                new_indices.append(np.zeros_like(idx))
+            else:
+                new_indices.append(idx)
+            x_dim += 1
+        return tuple(new_indices)
+
     def perform(self, node, inputs, out_):
         (out,) = out_
-        check_advanced_indexing_dimensions(inputs[0], inputs[1:])
-        rval = inputs[0].__getitem__(tuple(inputs[1:]))
+        x = inputs[0]
+        indices = inputs[1:]
+
+        indices = self._broadcast_indices(x, indices)
+
+        check_advanced_indexing_dimensions(x, indices)
+        rval = x.__getitem__(tuple(indices))
         # When there are no arrays, we are not actually doing advanced
         # indexing, so __getitem__ will not return a copy.
         # Since no view_map is set, we need to copy the returned value
@@ -2807,6 +2839,97 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
 advanced_subtensor = AdvancedSubtensor()
 
 
+class BatchedSliceType(Type):
+    def filter(self, x, strict=False, allow_downcast=None):
+        return x
+
+    def __str__(self):
+        return "BatchedSliceType"
+
+
+batched_slice_type = BatchedSliceType()
+
+
+class BatchedSlice(Op):
+    __props__ = ()
+
+    def make_node(self, start, stop, step):
+        return Apply(self, [start, stop, step], [batched_slice_type()])
+
+    def perform(self, node, inp, out_):
+        raise NotImplementedError("BatchedSlice is a placeholder")
+
+
+@_vectorize_node.register(MakeSlice)
+def vectorize_make_slice(op, node, *batched_inputs):
+    # A slice component is batched when its replacement gained dimensions
+    is_batched = any(
+        hasattr(batched.type, "ndim")
+        and hasattr(orig.type, "ndim")
+        and batched.type.ndim > orig.type.ndim
+        for orig, batched in zip(node.inputs, batched_inputs)
+    )
+
+    if is_batched:
+        return BatchedSlice().make_node(*batched_inputs)
+    return op.make_node(*batched_inputs)
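+
+
+# Illustrative note (comment only): vectorizing make_slice(s, NoneConst,
+# NoneConst), where s gained a batch dimension, yields
+#   BatchedSlice()(s_batched, NoneConst, NoneConst)
+# The BatchedSlice output is never evaluated (its perform raises); it only
+# marks the slice so that advanced_inc_subtensor_vectorize_node below can
+# unpack its components for AdvancedIncSubtensorExplicit.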
+
+
+class AdvancedIncSubtensorExplicit(Op):
+    """Like AdvancedIncSubtensor, but slices are passed as flattened scalar
+    components so the Op can be wrapped in Blockwise.
+
+    `structure` records, per index, whether it is a "slice" (six scalar
+    inputs: value/is-none pairs for start, stop, step) or a "tensor"
+    (one input).
+    """
+
+    __props__ = ("structure", "set_instead_of_inc", "inplace", "ignore_duplicates")
+
+    def __init__(
+        self,
+        structure,
+        set_instead_of_inc=False,
+        inplace=False,
+        ignore_duplicates=False,
+    ):
+        self.structure = structure
+        self.set_instead_of_inc = set_instead_of_inc
+        self.inplace = inplace
+        self.ignore_duplicates = ignore_duplicates
+        if inplace:
+            # Declare that x (input 0) is destroyed when running in place
+            self.destroy_map = {0: [0]}
+
+    def make_node(self, x, y, *inputs):
+        return Apply(self, [x, y, *inputs], [x.type()])
+
+    def perform(self, node, inputs, out_):
+        x, y, *flat_indices = inputs
+
+        # Rebuild the actual index tuple from the flattened inputs
+        indices = []
+        idx_ptr = 0
+        for kind in self.structure:
+            if kind == "slice":
+                (
+                    start_val,
+                    start_none,
+                    stop_val,
+                    stop_none,
+                    step_val,
+                    step_none,
+                ) = flat_indices[idx_ptr : idx_ptr + 6]
+                start = None if start_none else start_val
+                stop = None if stop_none else stop_val
+                step = None if step_none else step_val
+                indices.append(slice(start, stop, step))
+                idx_ptr += 6
+            else:
+                indices.append(flat_indices[idx_ptr])
+                idx_ptr += 1
+
+        (out,) = out_
+        if not self.inplace:
+            out[0] = x.copy()
+        else:
+            out[0] = x
+
+        if self.set_instead_of_inc:
+            out[0][tuple(indices)] = y
+        elif self.ignore_duplicates:
+            out[0][tuple(indices)] += y
+        else:
+            np.add.at(out[0], tuple(indices), y)
+
+
 @_vectorize_node.register(AdvancedSubtensor)
 def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
     x, *idxs = node.inputs
@@ -2967,9 +3090,130 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
 advanced_inc_subtensor = AdvancedIncSubtensor()
 advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
 advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True)
-advanced_set_subtensor_nodup = AdvancedIncSubtensor(
-    set_instead_of_inc=True, ignore_duplicates=True
-)
+
+
+@_vectorize_node.register(AdvancedIncSubtensor)
+def advanced_inc_subtensor_vectorize_node(op, node, *batched_inputs):
+    x, y, *indices = batched_inputs
+
+    # Check whether any index is a batched slice (BatchedSliceType)
+    has_batched_slices = any(isinstance(idx.type, BatchedSliceType) for idx in indices)
+
+    if has_batched_slices:
+        from pytensor.tensor.blockwise import Blockwise, safe_signature
+
+        # Construct the inputs for AdvancedIncSubtensorExplicit.
+        # x and y are passed as is; Blockwise handles their batch dims.
+        structure = []
+        flat_inputs = []
+        for idx in indices:
+            if isinstance(idx.type, BatchedSliceType):
+                structure.append("slice")
+                # Unwrap the BatchedSlice: idx is the Variable output of a
+                # BatchedSlice node whose inputs are (start, stop, step)
+                bs_node = idx.owner
+                assert isinstance(bs_node.op, BatchedSlice)
+
+                for comp in bs_node.inputs:
+                    if isinstance(comp.type, NoneTypeT):
+                        # Pass a dummy value and a True "is None" flag
+                        flat_inputs.append(tensor_constant(0, dtype="int8"))
+                        flat_inputs.append(tensor_constant(True, dtype="bool"))
+                    else:
+                        # Pass the component and a False "is None" flag
+                        flat_inputs.append(comp)
+                        flat_inputs.append(tensor_constant(False, dtype="bool"))
+            else:
+                structure.append("tensor")
+                flat_inputs.append(idx)
+
+        core_op = AdvancedIncSubtensorExplicit(
+            structure=tuple(structure),
+            set_instead_of_inc=op.set_instead_of_inc,
+            inplace=op.inplace,
+            ignore_duplicates=op.ignore_duplicates,
+        )
+
+        # Build the Blockwise signature:
+        # x: (n, m, ...), y: (n, m, ...), indices... -> (n, m, ...)
+        x_core_ndim = node.inputs[0].ndim
+        y_core_ndim = node.inputs[1].ndim
+
+        input_core_ndims = [x_core_ndim, y_core_ndim]
+
+        for i, idx in enumerate(indices):
+            if isinstance(idx.type, BatchedSliceType):
+                # Six scalar (0-d) components encode each slice
+                input_core_ndims.extend([0] * 6)
+            else:
+                # Tensor index: the core ndim is that of the original index
+                input_core_ndims.append(node.inputs[2 + i].ndim)
+
+        output_core_ndims = [node.outputs[0].ndim]
+
+        signature = safe_signature(input_core_ndims, output_core_ndims)
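+        # Worked example (illustrative, mirroring the batched-slice test
+        # added below): for set_subtensor(x[s:, [0, 0]], 0) with s batched
+        # to a vector, structure == ("slice", "tensor") and
+        #   flat_inputs == [s, False, <dummy>, True, <dummy>, True, [0, 0]]
+        # i.e. six scalar slice components followed by the integer index.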
+        return Blockwise(core_op, signature=signature).make_node(x, y, *flat_inputs)
+
+    x_orig = node.inputs[0]
+    x_batch_ndim = x.ndim - x_orig.ndim
+
+    y_orig = node.inputs[1]
+    y_batch_ndim = y.ndim - y_orig.ndim
+
+    batch_ndim = max(x_batch_ndim, y_batch_ndim)
+
+    if batch_ndim == 0:
+        return op.make_node(x, y, *indices)
+
+    if x_batch_ndim < batch_ndim:
+        # Broadcast x to match the batch dimensions. The batch dimensions
+        # are assumed to be the leading batch_ndim ones, and y must carry
+        # them all (since y_batch_ndim >= x_batch_ndim here).
+        from pytensor.tensor import concatenate
+
+        batch_shape = y.shape[:batch_ndim]
+        full_shape = concatenate([batch_shape, x.shape])
+        x = alloc(x, *full_shape)
+
+    # Check whether any index gained batch dimensions
+    any_batched_index = any(
+        hasattr(idx, "ndim")
+        and hasattr(orig_idx, "ndim")
+        and idx.ndim > orig_idx.ndim
+        for idx, orig_idx in zip(indices, node.inputs[2:])
+    )
+
+    if not any_batched_index:
+        # Simple case: prepend slice(None) for each batch dimension
+        sl = make_slice(None, None, None)
+        new_indices = [sl] * batch_ndim + list(indices)
+        return op.make_node(x, y, *new_indices)
+
+    from pytensor.tensor import arange
+
+    # Some index is batched: index each batch dimension explicitly with an
+    # arange. Each arange is padded with broadcastable dims on the right up
+    # to the largest index ndim so that it broadcasts against the batched
+    # index arrays instead of clashing with their trailing dimensions
+    max_index_ndim = max(
+        (idx.ndim for idx in indices if hasattr(idx, "ndim")), default=0
+    )
+    batch_indices = []
+    for d in range(batch_ndim):
+        dim_len = x.shape[d]
+        idx = arange(dim_len)
+        pattern = ["x"] * max(max_index_ndim, batch_ndim)
+        pattern[d] = 0
+        idx = idx.dimshuffle(pattern)
+        batch_indices.append(idx)
+
+    new_indices = batch_indices + list(indices)
+    return op.make_node(x, y, *new_indices)
 
 
 def take(a, indices, axis=None, mode="raise"):
diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py
index d8dadf0009..fa11232a0d 100644
--- a/tests/tensor/test_subtensor.py
+++ b/tests/tensor/test_subtensor.py
@@ -17,7 +17,7 @@
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad
-from pytensor.graph import Constant
+from pytensor.graph import Constant, vectorize_graph
 from pytensor.graph.basic import equal_computations
 from pytensor.graph.op import get_test_value
 from pytensor.graph.rewriting.utils import is_same_graph
@@ -3047,15 +3047,12 @@ def core_fn(x, start):
             (2,),
             False,
         ),
-        # (this is currently failing because PyTensor tries to vectorize the slice(None) operation,
-        # due to the exact same None constant being used there and in the np.newaxis)
         pytest.param(
             (lambda x, idx: x[:, idx, None]),
             "(7,5,3),(2)->(7,2,1,3)",
             (11, 7, 5, 3),
             (2,),
             False,
-            marks=pytest.mark.xfail(raises=NotImplementedError),
         ),
         (
             (lambda x, idx: x[:, idx, idx, :]),
@@ -3218,3 +3215,70 @@ def test_advanced_incsubtensor1(self, func, static_shape, gc, benchmark):
         )
         fn.vm.allow_gc = gc
         benchmark(fn, x_values)
+
+
+class TestAdvancedIncSubtensorVectorization:
+    def test_vectorize_advanced_inc_subtensor_slice(self):
+        # Regression test: vectorization of AdvancedIncSubtensor with slice inputs
+        x = matrix("x")
+        y = vector("y")
+        idx = as_tensor([0, 1])
+
+        # x[0:2, idx] = y
+        # This uses AdvancedIncSubtensor because the vector index is mixed
+        # with a slice
+        z = set_subtensor(x[0:2, idx], y)
+
+        # Vectorize over a batch dimension:
+        #   batched_x: (B, N, M)
+        #   batched_y: (B, 2)
+        batched_x = tensor3("bx")
+        batched_y = matrix("by")
+
+        out_batched = vectorize_graph(z, replace={x: batched_x, y: batched_y})
+
+        f = function([batched_x, batched_y], out_batched)
+
+        bx_val = np.zeros((2, 5, 5), dtype=config.floatX)
+        by_val = np.ones((2, 2), dtype=config.floatX)
+
+        res = f(bx_val, by_val)
+
+        # For each batch b, res[b, 0:2, [0, 1]] should be 1 and the
+        # untouched rows should remain 0
+        assert np.all(res[:, 0:2, [0, 1]] == 1)
+        assert np.all(res[:, 2:, :] == 0)
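+
+        # NumPy reference for what the batched graph computes (sketch):
+        #   expected = bx_val.copy()
+        #   for b in range(2):
+        #       expected[b, 0:2, [0, 1]] = by_val[b]
+        #   assert np.array_equal(res, expected)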
+
+    def test_vectorize_advanced_inc_subtensor_batched_slice(self):
+        # Regression test: vectorization of AdvancedIncSubtensor with
+        # batched slice parameters
+        x = matrix("x")
+        s = lscalar("s")
+        # x[s:, [0, 0]] = 0
+        out = set_subtensor(x[s:, [0, 0]], 0)
+
+        # Vectorize the scalar slice start s into a vector z
+        z = lvector("z")
+
+        out_batched = vectorize_graph(out, replace={s: z})
+
+        f = function([x, z], out_batched)
+
+        x_val = np.arange(12).reshape((4, 3)).astype(config.floatX)
+        z_val = np.array([1, 2], dtype="int64")
+
+        # For z=1: x[1:, [0, 0]] = 0 zeroes column 0 of rows 1, 2, 3
+        # For z=2: x[2:, [0, 0]] = 0 zeroes column 0 of rows 2, 3
+        res = f(x_val, z_val)
+
+        # One result per batched slice start
+        assert res.shape == (2, 4, 3)
+
+        expected_0 = x_val.copy()
+        expected_0[1:, [0, 0]] = 0
+
+        expected_1 = x_val.copy()
+        expected_1[2:, [0, 0]] = 0
+
+        assert np.allclose(res[0], expected_0)
+        assert np.allclose(res[1], expected_1)
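+
+        # Sanity-check sketch (illustrative only; assumes the new Op and the
+        # Blockwise wrapper are importable as named below):
+        #   from pytensor.graph.basic import applys_between
+        #   from pytensor.tensor.blockwise import Blockwise
+        #   from pytensor.tensor.subtensor import AdvancedIncSubtensorExplicit
+        #   assert any(
+        #       isinstance(a.op, Blockwise)
+        #       and isinstance(a.op.core_op, AdvancedIncSubtensorExplicit)
+        #       for a in applys_between([x, z], [out_batched])
+        #   )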