From 9ff8eeaaa2169cdb2b284a060bc348d75c398705 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Sep 2025 13:19:14 +0000 Subject: [PATCH 1/7] Initial plan From 3cfbd0d7e0af57dfcb4fa94d59241aee4b417366 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Sep 2025 13:38:46 +0000 Subject: [PATCH 2/7] Implement core refactoring of AdvancedSubtensor and AdvancedIncSubtensor Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/subtensor.py | 398 +++++++++++++++++++++++++++-------- 1 file changed, 305 insertions(+), 93 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index d7fc1bedbc..09b4287660 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -2576,48 +2576,98 @@ def check_advanced_indexing_dimensions(input, idx_list): class AdvancedSubtensor(Op): """Implements NumPy's advanced indexing.""" - __props__ = () + __props__ = ("idx_list",) + + def __init__(self, idx_list): + """ + Initialize AdvancedSubtensor with index list. + + Parameters + ---------- + idx_list : tuple + List of indices where slices and newaxis are stored as-is, + and numerical indices are replaced by their types. + """ + self.idx_list = tuple(map(index_vars_to_types, idx_list)) + + def make_node(self, x, *inputs): + """ + Parameters + ---------- + x + The tensor to take a subtensor of. + inputs + A list of pytensor Scalars and Tensors (numerical indices only). - def make_node(self, x, *indices): + """ x = as_tensor_variable(x) - indices = tuple(map(as_index_variable, indices)) + inputs = tuple(as_tensor_variable(a) for a in inputs) + + idx_list = list(self.idx_list) + if len(idx_list) > x.type.ndim: + raise IndexError("too many indices for array") + # Get input types from idx_list - only process numerical indices + input_types = [] + input_idx = 0 explicit_indices = [] new_axes = [] - for idx in indices: - if isinstance(idx.type, TensorType) and idx.dtype == "bool": - if idx.type.ndim == 0: - raise NotImplementedError( - "Indexing with scalar booleans not supported" - ) + + for i, entry in enumerate(idx_list): + if isinstance(entry, slice): + # Slices are stored in idx_list, not passed as inputs + explicit_indices.append(entry) + elif entry is np.newaxis: + # Newaxis stored in idx_list, not passed as inputs + new_axes.append(len(explicit_indices)) + explicit_indices.append(entry) + elif isinstance(entry, Type): + # This is a numerical index - should have corresponding input + if input_idx >= len(inputs): + raise ValueError(f"Missing input for index {i}") + inp = inputs[input_idx] + + # Handle boolean indices + if inp.dtype == "bool": + if inp.type.ndim == 0: + raise NotImplementedError( + "Indexing with scalar booleans not supported" + ) - # Check static shape aligned - axis = len(explicit_indices) - len(new_axes) - indexed_shape = x.type.shape[axis : axis + idx.type.ndim] - for j, (indexed_length, indexer_length) in enumerate( - zip(indexed_shape, idx.type.shape) - ): - if ( - indexed_length is not None - and indexer_length is not None - and indexed_length != indexer_length + # Check static shape aligned + axis = len(explicit_indices) - len(new_axes) + indexed_shape = x.type.shape[axis : axis + inp.type.ndim] + for j, (indexed_length, indexer_length) in enumerate( + zip(indexed_shape, inp.type.shape) ): - raise IndexError( - f"boolean index did not match indexed tensor along axis {axis + j};" - 
f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}" - ) - # Convert boolean indices to integer with nonzero, to reason about static shape next - if isinstance(idx, Constant): - nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()] + if ( + indexed_length is not None + and indexer_length is not None + and indexed_length != indexer_length + ): + raise IndexError( + f"boolean index did not match indexed tensor along axis {axis + j};" + f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}" + ) + # Convert boolean indices to integer with nonzero, to reason about static shape next + if isinstance(inp, Constant): + nonzero_indices = [tensor_constant(i) for i in inp.data.nonzero()] + else: + # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero + # and seeing that other integer indices cannot possible match it + nonzero_indices = inp.nonzero() + explicit_indices.extend(nonzero_indices) else: - # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero - # and seeing that other integer indices cannot possible match it - nonzero_indices = idx.nonzero() - explicit_indices.extend(nonzero_indices) + # Regular numerical index + explicit_indices.append(inp) + + input_types.append(entry) + input_idx += 1 else: - if isinstance(idx.type, NoneTypeT): - new_axes.append(len(explicit_indices)) - explicit_indices.append(idx) + raise ValueError(f"Invalid entry in idx_list: {entry}") + + if input_idx != len(inputs): + raise ValueError(f"Expected {input_idx} inputs but got {len(inputs)}") if (len(explicit_indices) - len(new_axes)) > x.type.ndim: raise IndexError( @@ -2633,21 +2683,13 @@ def make_node(self, x, *indices): np.insert(np.array(x.type.shape, dtype=object), 1, new_axes) ) for i, (idx, dim_length) in enumerate( - zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst) + zip_longest(explicit_indices, expanded_x_shape, fillvalue=slice(None)) ): - if isinstance(idx.type, NoneTypeT): + if idx is np.newaxis: basic_group_shape.append(1) # New-axis - elif isinstance(idx.type, SliceType): - if isinstance(idx, Constant): - basic_group_shape.append(slice_static_length(idx.data, dim_length)) - elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice): - basic_group_shape.append( - slice_static_length(slice(*idx.owner.inputs), dim_length) - ) - else: - # Symbolic root slice (owner is None), or slice operation we don't understand - basic_group_shape.append(None) - else: # TensorType + elif isinstance(idx, slice): + basic_group_shape.append(slice_static_length(idx, dim_length)) + else: # TensorType (advanced index) # Keep track of advanced group axis if adv_group_axis is None: # First time we see an advanced index @@ -2682,7 +2724,7 @@ def make_node(self, x, *indices): return Apply( self, - [x, *indices], + [x, *inputs], [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))], ) @@ -2698,19 +2740,41 @@ def is_bool_index(idx): or getattr(idx, "dtype", None) == "bool" ) + # Reconstruct full index list from idx_list and inputs indices = node.inputs[1:] + full_indices = [] + input_idx = 0 + + for entry in self.idx_list: + if isinstance(entry, slice): + full_indices.append(entry) + elif entry is np.newaxis: + full_indices.append(entry) + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(indices): + full_indices.append(indices[input_idx]) + input_idx += 1 + else: 
+ raise ValueError("Mismatch between idx_list and inputs") + index_shapes = [] - for idx, ishape in zip(indices, ishapes[1:], strict=True): - # Mixed bool indexes are converted to nonzero entries - shape0_op = Shape_i(0) - if is_bool_index(idx): - index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx)) - # The `ishapes` entries for `SliceType`s will be None, and - # we need to give `indexed_result_shape` the actual slices. - elif isinstance(getattr(idx, "type", None), SliceType): + for idx in full_indices: + if isinstance(idx, slice): index_shapes.append(idx) + elif idx is np.newaxis: + index_shapes.append(idx) + elif hasattr(idx, 'type'): + # Mixed bool indexes are converted to nonzero entries + shape0_op = Shape_i(0) + if is_bool_index(idx): + index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx)) + else: + # Get ishape for this input + input_shape_idx = indices.index(idx) + 1 # +1 because ishapes[0] is x + index_shapes.append(ishapes[input_shape_idx]) else: - index_shapes.append(ishape) + index_shapes.append(idx) res_shape = list( indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True) @@ -2740,14 +2804,37 @@ def is_bool_index(idx): def perform(self, node, inputs, out_): (out,) = out_ - check_advanced_indexing_dimensions(inputs[0], inputs[1:]) - rval = inputs[0].__getitem__(tuple(inputs[1:])) + + # Reconstruct the full tuple of indices from idx_list and inputs + x = inputs[0] + tensor_inputs = inputs[1:] + + full_indices = [] + input_idx = 0 + + for entry in self.idx_list: + if isinstance(entry, slice): + full_indices.append(entry) + elif entry is np.newaxis: + full_indices.append(np.newaxis) + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(tensor_inputs): + full_indices.append(tensor_inputs[input_idx]) + input_idx += 1 + else: + raise ValueError("Mismatch between idx_list and inputs") + + check_advanced_indexing_dimensions(x, full_indices) + rval = x.__getitem__(tuple(full_indices)) # When there are no arrays, we are not actually doing advanced # indexing, so __getitem__ will not return a copy. # Since no view_map is set, we need to copy the returned value - if not any( - isinstance(v.type, TensorType) and v.ndim > 0 for v in node.inputs[1:] - ): + has_tensor_indices = any( + isinstance(entry, Type) and not getattr(entry, 'broadcastable', (False,))[0] + for entry in self.idx_list + ) + if not has_tensor_indices: rval = rval.copy() out[0] = rval @@ -2785,7 +2872,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: This function checks if the advanced indexing is non-consecutive, in which case the advanced index dimensions are placed on the left of the - output array, regardless of their opriginal position. + output array, regardless of their original position. See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing @@ -2800,11 +2887,29 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: bool True if the advanced indexing is non-consecutive, False otherwise. 
""" - _, *idxs = node.inputs - return _non_consecutive_adv_indexing(idxs) + # Reconstruct the full indices from idx_list and inputs to check consecutivity + op = node.op + tensor_inputs = node.inputs[1:] + + full_indices = [] + input_idx = 0 + + for entry in op.idx_list: + if isinstance(entry, slice): + full_indices.append(slice(None)) # Represent as basic slice + elif entry is np.newaxis: + full_indices.append(np.newaxis) + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(tensor_inputs): + full_indices.append(tensor_inputs[input_idx]) + input_idx += 1 + + return _non_consecutive_adv_indexing(full_indices) -advanced_subtensor = AdvancedSubtensor() +# Note: This is now a factory function since AdvancedSubtensor needs idx_list +# The old global instance approach won't work anymore @_vectorize_node.register(AdvancedSubtensor) @@ -2824,30 +2929,25 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): # which would put the indexed results to the left of the batch dimensions! # TODO: Not all cases must be handled by Blockwise, but the logic is complex - # Blockwise doesn't accept None or Slices types so we raise informative error here - # TODO: Implement these internally, so Blockwise is always a safe fallback - if any(not isinstance(idx, TensorVariable) for idx in idxs): - raise NotImplementedError( - "Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing " - "and slices or newaxis is currently not supported." - ) - else: - return vectorize_node_fallback(op, node, batch_x, *batch_idxs) + # With the new interface, all inputs are tensors, so Blockwise can handle them + return vectorize_node_fallback(op, node, batch_x, *batch_idxs) # Otherwise we just need to add None slices for every new batch dim x_batch_ndim = batch_x.type.ndim - x.type.ndim empty_slices = (slice(None),) * x_batch_ndim - return op.make_node(batch_x, *empty_slices, *batch_idxs) + new_idx_list = empty_slices + op.idx_list + return AdvancedSubtensor(new_idx_list).make_node(batch_x, *batch_idxs) class AdvancedIncSubtensor(Op): """Increments a subtensor using advanced indexing.""" - __props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates") + __props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates", "idx_list") def __init__( - self, inplace=False, set_instead_of_inc=False, ignore_duplicates=False + self, idx_list, inplace=False, set_instead_of_inc=False, ignore_duplicates=False ): + self.idx_list = tuple(map(index_vars_to_types, idx_list)) self.set_instead_of_inc = set_instead_of_inc self.inplace = inplace if inplace: @@ -2865,6 +2965,11 @@ def make_node(self, x, y, *inputs): x = as_tensor_variable(x) y = as_tensor_variable(y) + # Validate that we have the right number of tensor inputs for our idx_list + expected_tensor_inputs = sum(1 for entry in self.idx_list if isinstance(entry, Type)) + if len(inputs) != expected_tensor_inputs: + raise ValueError(f"Expected {expected_tensor_inputs} tensor inputs but got {len(inputs)}") + new_inputs = [] for inp in inputs: if isinstance(inp, list | tuple): @@ -2877,9 +2982,26 @@ def make_node(self, x, y, *inputs): ) def perform(self, node, inputs, out_): - x, y, *indices = inputs + x, y, *tensor_inputs = inputs - check_advanced_indexing_dimensions(x, indices) + # Reconstruct the full tuple of indices from idx_list and inputs + full_indices = [] + input_idx = 0 + + for entry in self.idx_list: + if isinstance(entry, slice): + full_indices.append(entry) + elif entry is 
np.newaxis: + full_indices.append(np.newaxis) + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(tensor_inputs): + full_indices.append(tensor_inputs[input_idx]) + input_idx += 1 + else: + raise ValueError("Mismatch between idx_list and inputs") + + check_advanced_indexing_dimensions(x, full_indices) (out,) = out_ if not self.inplace: @@ -2888,11 +3010,11 @@ def perform(self, node, inputs, out_): out[0] = x if self.set_instead_of_inc: - out[0][tuple(indices)] = y + out[0][tuple(full_indices)] = y elif self.ignore_duplicates: - out[0][tuple(indices)] += y + out[0][tuple(full_indices)] += y else: - np.add.at(out[0], tuple(indices), y) + np.add.at(out[0], tuple(full_indices), y) def infer_shape(self, fgraph, node, ishapes): return [ishapes[0]] @@ -2922,10 +3044,12 @@ def grad(self, inpt, output_gradients): raise NotImplementedError("No support for complex grad yet") else: if self.set_instead_of_inc: - gx = advanced_set_subtensor(outgrad, y.zeros_like(), *idxs) + gx = AdvancedIncSubtensor(self.idx_list, set_instead_of_inc=True).make_node( + outgrad, y.zeros_like(), *idxs + ).outputs[0] else: gx = outgrad - gy = advanced_subtensor(outgrad, *idxs) + gy = AdvancedSubtensor(self.idx_list).make_node(outgrad, *idxs).outputs[0] # Make sure to sum gy over the dimensions of y that have been # added or broadcasted gy = _sum_grad_over_bcasted_dims(y, gy) @@ -2945,7 +3069,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: This function checks if the advanced indexing is non-consecutive, in which case the advanced index dimensions are placed on the left of the - output array, regardless of their opriginal position. + output array, regardless of their original position. See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing @@ -2960,16 +3084,104 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: bool True if the advanced indexing is non-consecutive, False otherwise. """ - _, _, *idxs = node.inputs - return _non_consecutive_adv_indexing(idxs) + # Reconstruct the full indices from idx_list and inputs to check consecutivity + op = node.op + tensor_inputs = node.inputs[2:] # Skip x and y + + full_indices = [] + input_idx = 0 + + for entry in op.idx_list: + if isinstance(entry, slice): + full_indices.append(slice(None)) # Represent as basic slice + elif entry is np.newaxis: + full_indices.append(np.newaxis) + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(tensor_inputs): + full_indices.append(tensor_inputs[input_idx]) + input_idx += 1 + + return _non_consecutive_adv_indexing(full_indices) + + +def advanced_subtensor(x, *args): + """Create an AdvancedSubtensor operation. + + This function processes the arguments to separate numerical indices from + slice/newaxis information and creates the appropriate AdvancedSubtensor op. 
+ """ + # Process args to extract idx_list and numerical inputs + idx_list = [] + numerical_inputs = [] + + for arg in args: + if arg is None: + idx_list.append(np.newaxis) + elif isinstance(arg, slice): + idx_list.append(arg) + elif isinstance(arg, Variable) and isinstance(arg.type, SliceType): + # Convert SliceType variable back to slice - this should be a constant + if isinstance(arg, Constant): + idx_list.append(arg.data) + elif arg.owner and isinstance(arg.owner.op, MakeSlice): + # Convert MakeSlice back to slice + start, stop, step = arg.owner.inputs + start_val = start.data if isinstance(start, Constant) else start + stop_val = stop.data if isinstance(stop, Constant) else stop + step_val = step.data if isinstance(step, Constant) else step + idx_list.append(slice(start_val, stop_val, step_val)) + else: + # This is a symbolic slice that we need to handle + # For now, convert to a generic slice - this may need more work + idx_list.append(slice(None)) + elif isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT): + idx_list.append(np.newaxis) + else: + # This is a numerical index (tensor, scalar, etc.) + idx_list.append(index_vars_to_types(as_tensor_variable(arg))) + numerical_inputs.append(as_tensor_variable(arg)) + + return AdvancedSubtensor(idx_list).make_node(x, *numerical_inputs).outputs[0] + + +def advanced_inc_subtensor(x, y, *args, **kwargs): + """Create an AdvancedIncSubtensor operation for incrementing.""" + # Process args to extract idx_list and numerical inputs + idx_list = [] + numerical_inputs = [] + + for arg in args: + if arg is None: + idx_list.append(np.newaxis) + elif isinstance(arg, slice): + idx_list.append(arg) + elif isinstance(arg, Variable) and isinstance(arg.type, SliceType): + # Convert SliceType variable back to slice + if isinstance(arg, Constant): + idx_list.append(arg.data) + elif arg.owner and isinstance(arg.owner.op, MakeSlice): + # Convert MakeSlice back to slice + start, stop, step = arg.owner.inputs + start_val = start.data if isinstance(start, Constant) else start + stop_val = stop.data if isinstance(stop, Constant) else stop + step_val = step.data if isinstance(step, Constant) else step + idx_list.append(slice(start_val, stop_val, step_val)) + else: + idx_list.append(slice(None)) + elif isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT): + idx_list.append(np.newaxis) + else: + # This is a numerical index + idx_list.append(index_vars_to_types(as_tensor_variable(arg))) + numerical_inputs.append(as_tensor_variable(arg)) + + return AdvancedIncSubtensor(idx_list, **kwargs).make_node(x, y, *numerical_inputs).outputs[0] -advanced_inc_subtensor = AdvancedIncSubtensor() -advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True) -advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True) -advanced_set_subtensor_nodup = AdvancedIncSubtensor( - set_instead_of_inc=True, ignore_duplicates=True -) +def advanced_set_subtensor(x, y, *args, **kwargs): + """Create an AdvancedIncSubtensor operation for setting.""" + return advanced_inc_subtensor(x, y, *args, set_instead_of_inc=True, **kwargs) def take(a, indices, axis=None, mode="raise"): From c18b322f32f938eaf1c486fc2aa85501b9106560 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Sep 2025 13:46:17 +0000 Subject: [PATCH 3/7] Complete refactoring with improved factory functions and proper slice handling Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com> --- 
pytensor/tensor/subtensor.py | 264 ++++++++++++++++++++++++----------- 1 file changed, 185 insertions(+), 79 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 09b4287660..d462aad9ba 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -2604,28 +2604,48 @@ def make_node(self, x, *inputs): inputs = tuple(as_tensor_variable(a) for a in inputs) idx_list = list(self.idx_list) - if len(idx_list) > x.type.ndim: + if (len([entry for entry in idx_list if entry is not np.newaxis]) > x.type.ndim): raise IndexError("too many indices for array") - # Get input types from idx_list - only process numerical indices - input_types = [] - input_idx = 0 + # Validate input count matches expected from idx_list + expected_inputs = get_slice_elements(idx_list, lambda entry: isinstance(entry, Type)) + if len(inputs) != len(expected_inputs): + raise ValueError(f"Expected {len(expected_inputs)} inputs but got {len(inputs)}") + + # Build explicit_indices for shape inference explicit_indices = [] new_axes = [] + input_idx = 0 for i, entry in enumerate(idx_list): - if isinstance(entry, slice): - # Slices are stored in idx_list, not passed as inputs - explicit_indices.append(entry) - elif entry is np.newaxis: - # Newaxis stored in idx_list, not passed as inputs + if entry is np.newaxis: new_axes.append(len(explicit_indices)) - explicit_indices.append(entry) + explicit_indices.append(np.newaxis) + elif isinstance(entry, slice): + # Reconstruct slice with actual values from inputs + if entry.start is not None and isinstance(entry.start, Type): + start_val = inputs[input_idx] + input_idx += 1 + else: + start_val = entry.start + + if entry.stop is not None and isinstance(entry.stop, Type): + stop_val = inputs[input_idx] + input_idx += 1 + else: + stop_val = entry.stop + + if entry.step is not None and isinstance(entry.step, Type): + step_val = inputs[input_idx] + input_idx += 1 + else: + step_val = entry.step + + explicit_indices.append(slice(start_val, stop_val, step_val)) elif isinstance(entry, Type): - # This is a numerical index - should have corresponding input - if input_idx >= len(inputs): - raise ValueError(f"Missing input for index {i}") + # This is a numerical index inp = inputs[input_idx] + input_idx += 1 # Handle boolean indices if inp.dtype == "bool": @@ -2649,26 +2669,18 @@ def make_node(self, x, *inputs): f"boolean index did not match indexed tensor along axis {axis + j};" f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}" ) - # Convert boolean indices to integer with nonzero, to reason about static shape next + # Convert boolean indices to integer with nonzero if isinstance(inp, Constant): nonzero_indices = [tensor_constant(i) for i in inp.data.nonzero()] else: - # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero - # and seeing that other integer indices cannot possible match it nonzero_indices = inp.nonzero() explicit_indices.extend(nonzero_indices) else: # Regular numerical index explicit_indices.append(inp) - - input_types.append(entry) - input_idx += 1 else: raise ValueError(f"Invalid entry in idx_list: {entry}") - if input_idx != len(inputs): - raise ValueError(f"Expected {input_idx} inputs but got {len(inputs)}") - if (len(explicit_indices) - len(new_axes)) > x.type.ndim: raise IndexError( f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed" @@ -2740,20 +2752,40 @@ def 
is_bool_index(idx): or getattr(idx, "dtype", None) == "bool" ) - # Reconstruct full index list from idx_list and inputs - indices = node.inputs[1:] + # Reconstruct the full indices from idx_list and inputs (like perform method) + inputs = node.inputs[1:] + full_indices = [] input_idx = 0 for entry in self.idx_list: - if isinstance(entry, slice): - full_indices.append(entry) - elif entry is np.newaxis: - full_indices.append(entry) + if entry is np.newaxis: + full_indices.append(np.newaxis) + elif isinstance(entry, slice): + # Reconstruct slice from idx_list and inputs + if entry.start is not None and isinstance(entry.start, Type): + start_val = inputs[input_idx] + input_idx += 1 + else: + start_val = entry.start + + if entry.stop is not None and isinstance(entry.stop, Type): + stop_val = inputs[input_idx] + input_idx += 1 + else: + stop_val = entry.stop + + if entry.step is not None and isinstance(entry.step, Type): + step_val = inputs[input_idx] + input_idx += 1 + else: + step_val = entry.step + + full_indices.append(slice(start_val, stop_val, step_val)) elif isinstance(entry, Type): # This is a numerical index - get from inputs - if input_idx < len(indices): - full_indices.append(indices[input_idx]) + if input_idx < len(inputs): + full_indices.append(inputs[input_idx]) input_idx += 1 else: raise ValueError("Mismatch between idx_list and inputs") @@ -2771,7 +2803,7 @@ def is_bool_index(idx): index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx)) else: # Get ishape for this input - input_shape_idx = indices.index(idx) + 1 # +1 because ishapes[0] is x + input_shape_idx = inputs.index(idx) + 1 # +1 because ishapes[0] is x index_shapes.append(ishapes[input_shape_idx]) else: index_shapes.append(idx) @@ -2813,10 +2845,29 @@ def perform(self, node, inputs, out_): input_idx = 0 for entry in self.idx_list: - if isinstance(entry, slice): - full_indices.append(entry) - elif entry is np.newaxis: + if entry is np.newaxis: full_indices.append(np.newaxis) + elif isinstance(entry, slice): + # Reconstruct slice from idx_list and inputs + if entry.start is not None and isinstance(entry.start, Type): + start_val = tensor_inputs[input_idx] + input_idx += 1 + else: + start_val = entry.start + + if entry.stop is not None and isinstance(entry.stop, Type): + stop_val = tensor_inputs[input_idx] + input_idx += 1 + else: + stop_val = entry.stop + + if entry.step is not None and isinstance(entry.step, Type): + step_val = tensor_inputs[input_idx] + input_idx += 1 + else: + step_val = entry.step + + full_indices.append(slice(start_val, stop_val, step_val)) elif isinstance(entry, Type): # This is a numerical index - get from inputs if input_idx < len(tensor_inputs): @@ -2989,10 +3040,29 @@ def perform(self, node, inputs, out_): input_idx = 0 for entry in self.idx_list: - if isinstance(entry, slice): - full_indices.append(entry) - elif entry is np.newaxis: + if entry is np.newaxis: full_indices.append(np.newaxis) + elif isinstance(entry, slice): + # Reconstruct slice from idx_list and inputs + if entry.start is not None and isinstance(entry.start, Type): + start_val = tensor_inputs[input_idx] + input_idx += 1 + else: + start_val = entry.start + + if entry.stop is not None and isinstance(entry.stop, Type): + stop_val = tensor_inputs[input_idx] + input_idx += 1 + else: + stop_val = entry.stop + + if entry.step is not None and isinstance(entry.step, Type): + step_val = tensor_inputs[input_idx] + input_idx += 1 + else: + step_val = entry.step + + full_indices.append(slice(start_val, stop_val, step_val)) elif 
isinstance(entry, Type): # This is a numerical index - get from inputs if input_idx < len(tensor_inputs): @@ -3108,75 +3178,111 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: def advanced_subtensor(x, *args): """Create an AdvancedSubtensor operation. - This function processes the arguments to separate numerical indices from - slice/newaxis information and creates the appropriate AdvancedSubtensor op. + This function converts the arguments to work with the new AdvancedSubtensor + interface that separates slice structure from variable inputs. """ - # Process args to extract idx_list and numerical inputs - idx_list = [] - numerical_inputs = [] - + # Convert raw args to proper form first + processed_args = [] for arg in args: if arg is None: - idx_list.append(np.newaxis) + processed_args.append(NoneConst.clone()) elif isinstance(arg, slice): - idx_list.append(arg) - elif isinstance(arg, Variable) and isinstance(arg.type, SliceType): - # Convert SliceType variable back to slice - this should be a constant + processed_args.append(make_slice(arg)) + else: + processed_args.append(as_tensor_variable(arg)) + + # Now create idx_list and extract inputs + idx_list = [] + input_vars = [] + + for arg in processed_args: + if isinstance(arg.type, NoneTypeT): + idx_list.append(np.newaxis) + elif isinstance(arg.type, SliceType): + # Handle SliceType - extract components and structure if isinstance(arg, Constant): + # Constant slice idx_list.append(arg.data) elif arg.owner and isinstance(arg.owner.op, MakeSlice): - # Convert MakeSlice back to slice + # Variable slice - extract components start, stop, step = arg.owner.inputs - start_val = start.data if isinstance(start, Constant) else start - stop_val = stop.data if isinstance(stop, Constant) else stop - step_val = step.data if isinstance(step, Constant) else step - idx_list.append(slice(start_val, stop_val, step_val)) + + # Convert components to types for idx_list + start_type = index_vars_to_types(start, False) if not isinstance(start.type, NoneTypeT) else None + stop_type = index_vars_to_types(stop, False) if not isinstance(stop.type, NoneTypeT) else None + step_type = index_vars_to_types(step, False) if not isinstance(step.type, NoneTypeT) else None + + idx_list.append(slice(start_type, stop_type, step_type)) + + # Add variable components to inputs + if not isinstance(start.type, NoneTypeT): + input_vars.append(start) + if not isinstance(stop.type, NoneTypeT): + input_vars.append(stop) + if not isinstance(step.type, NoneTypeT): + input_vars.append(step) else: - # This is a symbolic slice that we need to handle - # For now, convert to a generic slice - this may need more work + # Other slice case idx_list.append(slice(None)) - elif isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT): - idx_list.append(np.newaxis) else: - # This is a numerical index (tensor, scalar, etc.) 
- idx_list.append(index_vars_to_types(as_tensor_variable(arg))) - numerical_inputs.append(as_tensor_variable(arg)) + # Tensor index + idx_list.append(index_vars_to_types(arg)) + input_vars.append(arg) - return AdvancedSubtensor(idx_list).make_node(x, *numerical_inputs).outputs[0] + return AdvancedSubtensor(idx_list).make_node(x, *input_vars).outputs[0] def advanced_inc_subtensor(x, y, *args, **kwargs): """Create an AdvancedIncSubtensor operation for incrementing.""" - # Process args to extract idx_list and numerical inputs - idx_list = [] - numerical_inputs = [] - + # Convert raw args to proper form first + processed_args = [] for arg in args: if arg is None: - idx_list.append(np.newaxis) + processed_args.append(NoneConst.clone()) elif isinstance(arg, slice): - idx_list.append(arg) - elif isinstance(arg, Variable) and isinstance(arg.type, SliceType): - # Convert SliceType variable back to slice + processed_args.append(make_slice(arg)) + else: + processed_args.append(as_tensor_variable(arg)) + + # Now create idx_list and extract inputs + idx_list = [] + input_vars = [] + + for arg in processed_args: + if isinstance(arg.type, NoneTypeT): + idx_list.append(np.newaxis) + elif isinstance(arg.type, SliceType): + # Handle SliceType - extract components and structure if isinstance(arg, Constant): + # Constant slice idx_list.append(arg.data) elif arg.owner and isinstance(arg.owner.op, MakeSlice): - # Convert MakeSlice back to slice + # Variable slice - extract components start, stop, step = arg.owner.inputs - start_val = start.data if isinstance(start, Constant) else start - stop_val = stop.data if isinstance(stop, Constant) else stop - step_val = step.data if isinstance(step, Constant) else step - idx_list.append(slice(start_val, stop_val, step_val)) + + # Convert components to types for idx_list + start_type = index_vars_to_types(start, False) if not isinstance(start.type, NoneTypeT) else None + stop_type = index_vars_to_types(stop, False) if not isinstance(stop.type, NoneTypeT) else None + step_type = index_vars_to_types(step, False) if not isinstance(step.type, NoneTypeT) else None + + idx_list.append(slice(start_type, stop_type, step_type)) + + # Add variable components to inputs + if not isinstance(start.type, NoneTypeT): + input_vars.append(start) + if not isinstance(stop.type, NoneTypeT): + input_vars.append(stop) + if not isinstance(step.type, NoneTypeT): + input_vars.append(step) else: + # Other slice case idx_list.append(slice(None)) - elif isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT): - idx_list.append(np.newaxis) else: - # This is a numerical index - idx_list.append(index_vars_to_types(as_tensor_variable(arg))) - numerical_inputs.append(as_tensor_variable(arg)) + # Tensor index + idx_list.append(index_vars_to_types(arg)) + input_vars.append(arg) - return AdvancedIncSubtensor(idx_list, **kwargs).make_node(x, y, *numerical_inputs).outputs[0] + return AdvancedIncSubtensor(idx_list, **kwargs).make_node(x, y, *input_vars).outputs[0] def advanced_set_subtensor(x, y, *args, **kwargs): From 737b8cb276749e8e6c5180b4ef9a05f71a3dd050 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Sep 2025 13:50:35 +0000 Subject: [PATCH 4/7] Final fix: use as_index_variable consistently with original implementation Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/subtensor.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git 
a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index d462aad9ba..1f0edb2417 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -3181,15 +3181,8 @@ def advanced_subtensor(x, *args): This function converts the arguments to work with the new AdvancedSubtensor interface that separates slice structure from variable inputs. """ - # Convert raw args to proper form first - processed_args = [] - for arg in args: - if arg is None: - processed_args.append(NoneConst.clone()) - elif isinstance(arg, slice): - processed_args.append(make_slice(arg)) - else: - processed_args.append(as_tensor_variable(arg)) + # Convert args using as_index_variable (like original AdvancedSubtensor did) + processed_args = tuple(map(as_index_variable, args)) # Now create idx_list and extract inputs idx_list = [] @@ -3234,15 +3227,8 @@ def advanced_subtensor(x, *args): def advanced_inc_subtensor(x, y, *args, **kwargs): """Create an AdvancedIncSubtensor operation for incrementing.""" - # Convert raw args to proper form first - processed_args = [] - for arg in args: - if arg is None: - processed_args.append(NoneConst.clone()) - elif isinstance(arg, slice): - processed_args.append(make_slice(arg)) - else: - processed_args.append(as_tensor_variable(arg)) + # Convert args using as_index_variable (like original AdvancedIncSubtensor would) + processed_args = tuple(map(as_index_variable, args)) # Now create idx_list and extract inputs idx_list = [] From a3634dda25ff43a0e811d5f2aa1d85dac8b7754f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 28 Sep 2025 17:16:46 +0000 Subject: [PATCH 5/7] Refactor newaxis handling: move to __getitem__ level, unify with Subtensor approach Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/subtensor.py | 74 +++++++++++----------------- pytensor/tensor/variable.py | 94 ++++++++++++++++++------------------ 2 files changed, 75 insertions(+), 93 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 1f0edb2417..0da34b6fd0 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -2612,16 +2612,12 @@ def make_node(self, x, *inputs): if len(inputs) != len(expected_inputs): raise ValueError(f"Expected {len(expected_inputs)} inputs but got {len(inputs)}") - # Build explicit_indices for shape inference + # Build explicit_indices for shape inference (newaxis handled by __getitem__) explicit_indices = [] - new_axes = [] input_idx = 0 for i, entry in enumerate(idx_list): - if entry is np.newaxis: - new_axes.append(len(explicit_indices)) - explicit_indices.append(np.newaxis) - elif isinstance(entry, slice): + if isinstance(entry, slice): # Reconstruct slice with actual values from inputs if entry.start is not None and isinstance(entry.start, Type): start_val = inputs[input_idx] @@ -2655,7 +2651,7 @@ def make_node(self, x, *inputs): ) # Check static shape aligned - axis = len(explicit_indices) - len(new_axes) + axis = len(explicit_indices) indexed_shape = x.type.shape[axis : axis + inp.type.ndim] for j, (indexed_length, indexer_length) in enumerate( zip(indexed_shape, inp.type.shape) @@ -2681,25 +2677,20 @@ def make_node(self, x, *inputs): else: raise ValueError(f"Invalid entry in idx_list: {entry}") - if (len(explicit_indices) - len(new_axes)) > x.type.ndim: + if len(explicit_indices) > x.type.ndim: raise IndexError( - f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) 
- len(new_axes)} were indexed" + f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices)} were indexed" ) - # Perform basic and advanced indexing shape inference separately + # Perform basic and advanced indexing shape inference separately (no newaxis) basic_group_shape = [] advanced_indices = [] adv_group_axis = None last_adv_group_axis = None - expanded_x_shape = tuple( - np.insert(np.array(x.type.shape, dtype=object), 1, new_axes) - ) for i, (idx, dim_length) in enumerate( - zip_longest(explicit_indices, expanded_x_shape, fillvalue=slice(None)) + zip_longest(explicit_indices, x.type.shape, fillvalue=slice(None)) ): - if idx is np.newaxis: - basic_group_shape.append(1) # New-axis - elif isinstance(idx, slice): + if isinstance(idx, slice): basic_group_shape.append(slice_static_length(idx, dim_length)) else: # TensorType (advanced index) # Keep track of advanced group axis @@ -2752,16 +2743,14 @@ def is_bool_index(idx): or getattr(idx, "dtype", None) == "bool" ) - # Reconstruct the full indices from idx_list and inputs (like perform method) + # Reconstruct the full indices from idx_list and inputs (newaxis handled by __getitem__) inputs = node.inputs[1:] full_indices = [] input_idx = 0 for entry in self.idx_list: - if entry is np.newaxis: - full_indices.append(np.newaxis) - elif isinstance(entry, slice): + if isinstance(entry, slice): # Reconstruct slice from idx_list and inputs if entry.start is not None and isinstance(entry.start, Type): start_val = inputs[input_idx] @@ -2794,8 +2783,6 @@ def is_bool_index(idx): for idx in full_indices: if isinstance(idx, slice): index_shapes.append(idx) - elif idx is np.newaxis: - index_shapes.append(idx) elif hasattr(idx, 'type'): # Mixed bool indexes are converted to nonzero entries shape0_op = Shape_i(0) @@ -2837,7 +2824,7 @@ def is_bool_index(idx): def perform(self, node, inputs, out_): (out,) = out_ - # Reconstruct the full tuple of indices from idx_list and inputs + # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__) x = inputs[0] tensor_inputs = inputs[1:] @@ -2845,9 +2832,7 @@ def perform(self, node, inputs, out_): input_idx = 0 for entry in self.idx_list: - if entry is np.newaxis: - full_indices.append(np.newaxis) - elif isinstance(entry, slice): + if isinstance(entry, slice): # Reconstruct slice from idx_list and inputs if entry.start is not None and isinstance(entry.start, Type): start_val = tensor_inputs[input_idx] @@ -2938,7 +2923,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: bool True if the advanced indexing is non-consecutive, False otherwise. 
""" - # Reconstruct the full indices from idx_list and inputs to check consecutivity + # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__) op = node.op tensor_inputs = node.inputs[1:] @@ -2948,8 +2933,6 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: for entry in op.idx_list: if isinstance(entry, slice): full_indices.append(slice(None)) # Represent as basic slice - elif entry is np.newaxis: - full_indices.append(np.newaxis) elif isinstance(entry, Type): # This is a numerical index - get from inputs if input_idx < len(tensor_inputs): @@ -3035,14 +3018,12 @@ def make_node(self, x, y, *inputs): def perform(self, node, inputs, out_): x, y, *tensor_inputs = inputs - # Reconstruct the full tuple of indices from idx_list and inputs + # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__) full_indices = [] input_idx = 0 for entry in self.idx_list: - if entry is np.newaxis: - full_indices.append(np.newaxis) - elif isinstance(entry, slice): + if isinstance(entry, slice): # Reconstruct slice from idx_list and inputs if entry.start is not None and isinstance(entry.start, Type): start_val = tensor_inputs[input_idx] @@ -3154,7 +3135,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: bool True if the advanced indexing is non-consecutive, False otherwise. """ - # Reconstruct the full indices from idx_list and inputs to check consecutivity + # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__) op = node.op tensor_inputs = node.inputs[2:] # Skip x and y @@ -3164,8 +3145,6 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: for entry in op.idx_list: if isinstance(entry, slice): full_indices.append(slice(None)) # Represent as basic slice - elif entry is np.newaxis: - full_indices.append(np.newaxis) elif isinstance(entry, Type): # This is a numerical index - get from inputs if input_idx < len(tensor_inputs): @@ -3180,6 +3159,9 @@ def advanced_subtensor(x, *args): This function converts the arguments to work with the new AdvancedSubtensor interface that separates slice structure from variable inputs. + + Note: newaxis (None) should be handled by __getitem__ using dimshuffle + before calling this function. """ # Convert args using as_index_variable (like original AdvancedSubtensor did) processed_args = tuple(map(as_index_variable, args)) @@ -3189,9 +3171,7 @@ def advanced_subtensor(x, *args): input_vars = [] for arg in processed_args: - if isinstance(arg.type, NoneTypeT): - idx_list.append(np.newaxis) - elif isinstance(arg.type, SliceType): + if isinstance(arg.type, SliceType): # Handle SliceType - extract components and structure if isinstance(arg, Constant): # Constant slice @@ -3218,7 +3198,7 @@ def advanced_subtensor(x, *args): # Other slice case idx_list.append(slice(None)) else: - # Tensor index + # Tensor index (should not be NoneType since newaxis handled in __getitem__) idx_list.append(index_vars_to_types(arg)) input_vars.append(arg) @@ -3226,7 +3206,11 @@ def advanced_subtensor(x, *args): def advanced_inc_subtensor(x, y, *args, **kwargs): - """Create an AdvancedIncSubtensor operation for incrementing.""" + """Create an AdvancedIncSubtensor operation for incrementing. + + Note: newaxis (None) should be handled by __getitem__ using dimshuffle + before calling this function. 
+ """ # Convert args using as_index_variable (like original AdvancedIncSubtensor would) processed_args = tuple(map(as_index_variable, args)) @@ -3235,9 +3219,7 @@ def advanced_inc_subtensor(x, y, *args, **kwargs): input_vars = [] for arg in processed_args: - if isinstance(arg.type, NoneTypeT): - idx_list.append(np.newaxis) - elif isinstance(arg.type, SliceType): + if isinstance(arg.type, SliceType): # Handle SliceType - extract components and structure if isinstance(arg, Constant): # Constant slice @@ -3264,7 +3246,7 @@ def advanced_inc_subtensor(x, y, *args, **kwargs): # Other slice case idx_list.append(slice(None)) else: - # Tensor index + # Tensor index (should not be NoneType since newaxis handled in __getitem__) idx_list.append(index_vars_to_types(arg)) input_vars.append(arg) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 31e08fd39b..33f0ed3a81 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -539,55 +539,55 @@ def is_empty_array(val): else: advanced = True - if advanced: - return pt.subtensor.advanced_subtensor(self, *args) - else: - if np.newaxis in args or NoneConst in args: - # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new - # broadcastable dimension at this location". Since PyTensor adds - # new broadcastable dimensions via the `DimShuffle` `Op`, the - # following code uses said `Op` to add one of the new axes and - # then uses recursion to apply any other indices and add any - # remaining new axes. - - counter = 0 - pattern = [] - new_args = [] - for arg in args: - if arg is np.newaxis or arg is NoneConst: - pattern.append("x") - new_args.append(slice(None, None, None)) - else: - pattern.append(counter) - counter += 1 - new_args.append(arg) - - pattern.extend(list(range(counter, self.ndim))) - - view = self.dimshuffle(pattern) - full_slices = True - for arg in new_args: - # We can't do arg == slice(None, None, None) as in - # Python 2.7, this call __lt__ if we have a slice - # with some symbolic variable. - if not ( - isinstance(arg, slice) - and (arg.start is None or arg.start is NoneConst) - and (arg.stop is None or arg.stop is NoneConst) - and (arg.step is None or arg.step is NoneConst) - ): - full_slices = False - if full_slices: - return view + # Handle newaxis (None) for both basic and advanced indexing + if np.newaxis in args or NoneConst in args: + # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new + # broadcastable dimension at this location". Since PyTensor adds + # new broadcastable dimensions via the `DimShuffle` `Op`, the + # following code uses said `Op` to add one of the new axes and + # then uses recursion to apply any other indices and add any + # remaining new axes. + + counter = 0 + pattern = [] + new_args = [] + for arg in args: + if arg is np.newaxis or arg is NoneConst: + pattern.append("x") + new_args.append(slice(None, None, None)) else: - return view.__getitem__(tuple(new_args)) + pattern.append(counter) + counter += 1 + new_args.append(arg) + + pattern.extend(list(range(counter, self.ndim))) + + view = self.dimshuffle(pattern) + full_slices = True + for arg in new_args: + # We can't do arg == slice(None, None, None) as in + # Python 2.7, this call __lt__ if we have a slice + # with some symbolic variable. 
+ if not ( + isinstance(arg, slice) + and (arg.start is None or arg.start is NoneConst) + and (arg.stop is None or arg.stop is NoneConst) + and (arg.step is None or arg.step is NoneConst) + ): + full_slices = False + if full_slices: + return view else: - return pt.subtensor.Subtensor(args)( - self, - *pt.subtensor.get_slice_elements( - args, lambda entry: isinstance(entry, Variable) - ), - ) + return view.__getitem__(tuple(new_args)) + elif advanced: + return pt.subtensor.advanced_subtensor(self, *args) + else: + return pt.subtensor.Subtensor(args)( + self, + *pt.subtensor.get_slice_elements( + args, lambda entry: isinstance(entry, Variable) + ), + ) def __setitem__(self, key, value): raise TypeError( From 53adf9ad156a1fed8c1b1b427bbe111b9b65673e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:45:10 +0000 Subject: [PATCH 6/7] Update dispatch functions and rewrite rules for new AdvancedSubtensor interface, store expected_inputs_len Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com> --- pytensor/link/jax/dispatch/subtensor.py | 9 ++- pytensor/link/numba/dispatch/subtensor.py | 44 ++++++------ pytensor/link/pytorch/dispatch/subtensor.py | 21 ++++-- pytensor/tensor/rewriting/subtensor.py | 78 +++++++++++++++++++-- pytensor/tensor/subtensor.py | 20 +++--- 5 files changed, 127 insertions(+), 45 deletions(-) diff --git a/pytensor/link/jax/dispatch/subtensor.py b/pytensor/link/jax/dispatch/subtensor.py index 1c659be29b..cd8f78575a 100644 --- a/pytensor/link/jax/dispatch/subtensor.py +++ b/pytensor/link/jax/dispatch/subtensor.py @@ -77,6 +77,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): @jax_funcify.register(AdvancedIncSubtensor) def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) + if getattr(op, "set_instead_of_inc", False): def jax_fn(x, indices, y): @@ -87,8 +89,11 @@ def jax_fn(x, indices, y): def jax_fn(x, indices, y): return x.at[indices].add(y) - def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn): - return jax_fn(x, ilist, y) + def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): + indices = indices_from_subtensor(ilist, idx_list) + if len(indices) == 1: + indices = indices[0] + return jax_fn(x, indices, y) return advancedincsubtensor diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 51787daf41..7e7353f60e 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -239,28 +239,30 @@ def {function_name}({", ".join(input_names)}): @register_funcify_and_cache_key(AdvancedIncSubtensor) def numba_funcify_AdvancedSubtensor(op, node, **kwargs): if isinstance(op, AdvancedSubtensor): - _x, _y, idxs = node.inputs[0], None, node.inputs[1:] + x, y, tensor_inputs = node.inputs[0], None, node.inputs[1:] else: - _x, _y, *idxs = node.inputs - - basic_idxs = [ - idx - for idx in idxs - if ( - isinstance(idx.type, NoneTypeT) - or (isinstance(idx.type, SliceType) and not is_full_slice(idx)) - ) - ] - adv_idxs = [ - { - "axis": i, - "dtype": idx.type.dtype, - "bcast": idx.type.broadcastable, - "ndim": idx.type.ndim, - } - for i, idx in enumerate(idxs) - if isinstance(idx.type, TensorType) - ] + x, y, *tensor_inputs = node.inputs + + # Reconstruct indexing information from idx_list and tensor inputs + basic_idxs = [] + adv_idxs = [] + input_idx = 0 + + for i, entry in enumerate(op.idx_list): + if isinstance(entry, 
slice): + # Basic slice index + basic_idxs.append(entry) + elif isinstance(entry, Type): + # Advanced tensor index + if input_idx < len(tensor_inputs): + idx_input = tensor_inputs[input_idx] + adv_idxs.append({ + "axis": i, + "dtype": idx_input.type.dtype, + "bcast": idx_input.type.broadcastable, + "ndim": idx_input.type.ndim, + }) + input_idx += 1 # Special implementation for consecutive integer vector indices if ( diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py index 26b4fd0f7f..786ec46fe4 100644 --- a/pytensor/link/pytorch/dispatch/subtensor.py +++ b/pytensor/link/pytorch/dispatch/subtensor.py @@ -63,7 +63,10 @@ def makeslice(start, stop, step): @pytorch_funcify.register(AdvancedSubtensor1) @pytorch_funcify.register(AdvancedSubtensor) def pytorch_funcify_AdvSubtensor(op, node, **kwargs): - def advsubtensor(x, *indices): + idx_list = getattr(op, "idx_list", None) + + def advsubtensor(x, *flattened_indices): + indices = indices_from_subtensor(flattened_indices, idx_list) check_negative_steps(indices) return x[indices] @@ -102,12 +105,14 @@ def inc_subtensor(x, y, *flattened_indices): @pytorch_funcify.register(AdvancedIncSubtensor) @pytorch_funcify.register(AdvancedIncSubtensor1) def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): + idx_list = getattr(op, "idx_list", None) inplace = op.inplace ignore_duplicates = getattr(op, "ignore_duplicates", False) if op.set_instead_of_inc: - def adv_set_subtensor(x, y, *indices): + def adv_set_subtensor(x, y, *flattened_indices): + indices = indices_from_subtensor(flattened_indices, idx_list) check_negative_steps(indices) if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(node, x, y, indices) @@ -120,7 +125,8 @@ def adv_set_subtensor(x, y, *indices): elif ignore_duplicates: - def adv_inc_subtensor_no_duplicates(x, y, *indices): + def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices): + indices = indices_from_subtensor(flattened_indices, idx_list) check_negative_steps(indices) if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(node, x, y, indices) @@ -132,13 +138,16 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices): return adv_inc_subtensor_no_duplicates else: - if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]): + # Check if we have slice indexing in idx_list + has_slice_indexing = any(isinstance(entry, slice) for entry in idx_list) if idx_list else False + if has_slice_indexing: raise NotImplementedError( "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch" ) - def adv_inc_subtensor(x, y, *indices): - # Not needed because slices aren't supported + def adv_inc_subtensor(x, y, *flattened_indices): + indices = indices_from_subtensor(flattened_indices, idx_list) + # Not needed because slices aren't supported in this path # check_negative_steps(indices) if not inplace: x = x.clone() diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index fbe97b9a68..599e3497d3 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -228,7 +228,18 @@ def local_replace_AdvancedSubtensor(fgraph, node): return indexed_var = node.inputs[0] - indices = node.inputs[1:] + tensor_inputs = node.inputs[1:] + + # Reconstruct indices from idx_list and tensor inputs + indices = [] + input_idx = 0 + for entry in node.op.idx_list: + if isinstance(entry, slice): + indices.append(entry) + elif isinstance(entry, Type): + if 
input_idx < len(tensor_inputs): + indices.append(tensor_inputs[input_idx]) + input_idx += 1 axis = get_advsubtensor_axis(indices) @@ -255,7 +266,18 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): res = node.inputs[0] val = node.inputs[1] - indices = node.inputs[2:] + tensor_inputs = node.inputs[2:] + + # Reconstruct indices from idx_list and tensor inputs + indices = [] + input_idx = 0 + for entry in node.op.idx_list: + if isinstance(entry, slice): + indices.append(entry) + elif isinstance(entry, Type): + if input_idx < len(tensor_inputs): + indices.append(tensor_inputs[input_idx]) + input_idx += 1 axis = get_advsubtensor_axis(indices) @@ -1751,9 +1773,22 @@ def ravel_multidimensional_bool_idx(fgraph, node): x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape) """ if isinstance(node.op, AdvancedSubtensor): - x, *idxs = node.inputs + x = node.inputs[0] + tensor_inputs = node.inputs[1:] else: - x, y, *idxs = node.inputs + x, y = node.inputs[0], node.inputs[1] + tensor_inputs = node.inputs[2:] + + # Reconstruct indices from idx_list and tensor inputs + idxs = [] + input_idx = 0 + for entry in node.op.idx_list: + if isinstance(entry, slice): + idxs.append(entry) + elif isinstance(entry, Type): + if input_idx < len(tensor_inputs): + idxs.append(tensor_inputs[input_idx]) + input_idx += 1 if any( ( @@ -1791,12 +1826,41 @@ def ravel_multidimensional_bool_idx(fgraph, node): new_idxs[bool_idx_pos] = raveled_bool_idx if isinstance(node.op, AdvancedSubtensor): - new_out = node.op(raveled_x, *new_idxs) + # Create new AdvancedSubtensor with updated idx_list + new_idx_list = list(node.op.idx_list) + new_tensor_inputs = list(tensor_inputs) + + # Update the idx_list and tensor_inputs for the raveled boolean index + input_idx = 0 + for i, entry in enumerate(node.op.idx_list): + if isinstance(entry, Type): + if input_idx == bool_idx_pos: + new_tensor_inputs[input_idx] = raveled_bool_idx + input_idx += 1 + + new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs) else: + # Create new AdvancedIncSubtensor with updated idx_list + new_idx_list = list(node.op.idx_list) + new_tensor_inputs = list(tensor_inputs) + + # Update the tensor_inputs for the raveled boolean index + input_idx = 0 + for i, entry in enumerate(node.op.idx_list): + if isinstance(entry, Type): + if input_idx == bool_idx_pos: + new_tensor_inputs[input_idx] = raveled_bool_idx + input_idx += 1 + # The dimensions of y that correspond to the boolean indices # must already be raveled in the original graph, so we don't need to do anything to it - new_out = node.op(raveled_x, y, *new_idxs) - # But we must reshape the output to math the original shape + new_out = AdvancedIncSubtensor( + new_idx_list, + inplace=node.op.inplace, + set_instead_of_inc=node.op.set_instead_of_inc, + ignore_duplicates=node.op.ignore_duplicates + )(raveled_x, y, *new_tensor_inputs) + # But we must reshape the output to match the original shape new_out = new_out.reshape(x_shape) return [copy_stack_trace(node.outputs[0], new_out)] diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 0da34b6fd0..eeda92bccf 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -2585,10 +2585,12 @@ def __init__(self, idx_list): Parameters ---------- idx_list : tuple - List of indices where slices and newaxis are stored as-is, + List of indices where slices are stored as-is, and numerical indices are replaced by their types. 
""" self.idx_list = tuple(map(index_vars_to_types, idx_list)) + # Store expected number of tensor inputs for validation + self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type))) def make_node(self, x, *inputs): """ @@ -2604,15 +2606,14 @@ def make_node(self, x, *inputs): inputs = tuple(as_tensor_variable(a) for a in inputs) idx_list = list(self.idx_list) - if (len([entry for entry in idx_list if entry is not np.newaxis]) > x.type.ndim): + if len(idx_list) > x.type.ndim: raise IndexError("too many indices for array") # Validate input count matches expected from idx_list - expected_inputs = get_slice_elements(idx_list, lambda entry: isinstance(entry, Type)) - if len(inputs) != len(expected_inputs): - raise ValueError(f"Expected {len(expected_inputs)} inputs but got {len(inputs)}") + if len(inputs) != self.expected_inputs_len: + raise ValueError(f"Expected {self.expected_inputs_len} inputs but got {len(inputs)}") - # Build explicit_indices for shape inference (newaxis handled by __getitem__) + # Build explicit_indices for shape inference explicit_indices = [] input_idx = 0 @@ -2982,6 +2983,8 @@ def __init__( self, idx_list, inplace=False, set_instead_of_inc=False, ignore_duplicates=False ): self.idx_list = tuple(map(index_vars_to_types, idx_list)) + # Store expected number of tensor inputs for validation + self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type))) self.set_instead_of_inc = set_instead_of_inc self.inplace = inplace if inplace: @@ -3000,9 +3003,8 @@ def make_node(self, x, y, *inputs): y = as_tensor_variable(y) # Validate that we have the right number of tensor inputs for our idx_list - expected_tensor_inputs = sum(1 for entry in self.idx_list if isinstance(entry, Type)) - if len(inputs) != expected_tensor_inputs: - raise ValueError(f"Expected {expected_tensor_inputs} tensor inputs but got {len(inputs)}") + if len(inputs) != self.expected_inputs_len: + raise ValueError(f"Expected {self.expected_inputs_len} tensor inputs but got {len(inputs)}") new_inputs = [] for inp in inputs: From a4a305c35cbb983ae5659367aa2089508f9c89fe Mon Sep 17 00:00:00 2001 From: Jaan Erik Pihel Date: Thu, 4 Dec 2025 12:24:07 +0200 Subject: [PATCH 7/7] Finish Copilot code --- pytensor/link/jax/dispatch/subtensor.py | 33 +- pytensor/link/numba/dispatch/subtensor.py | 23 +- pytensor/link/pytorch/dispatch/subtensor.py | 12 +- pytensor/sparse/basic.py | 124 +++++- pytensor/sparse/math.py | 124 +----- pytensor/tensor/basic.py | 27 ++ pytensor/tensor/rewriting/subtensor.py | 63 ++- pytensor/tensor/subtensor.py | 455 +++++++++++++++----- pytensor/tensor/variable.py | 98 +++-- tests/tensor/test_subtensor.py | 14 +- 10 files changed, 646 insertions(+), 327 deletions(-) diff --git a/pytensor/link/jax/dispatch/subtensor.py b/pytensor/link/jax/dispatch/subtensor.py index cd8f78575a..3658717e51 100644 --- a/pytensor/link/jax/dispatch/subtensor.py +++ b/pytensor/link/jax/dispatch/subtensor.py @@ -31,11 +31,18 @@ """ +@jax_funcify.register(AdvancedSubtensor1) +def jax_funcify_AdvancedSubtensor1(op, node, **kwargs): + def advanced_subtensor1(x, ilist): + return x[ilist] + + return advanced_subtensor1 + + @jax_funcify.register(Subtensor) @jax_funcify.register(AdvancedSubtensor) -@jax_funcify.register(AdvancedSubtensor1) def jax_funcify_Subtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) + idx_list = op.idx_list def subtensor(x, *ilists): indices = indices_from_subtensor(ilists, idx_list) @@ -47,10 +54,24 @@ 
def subtensor(x, *ilists): return subtensor -@jax_funcify.register(IncSubtensor) @jax_funcify.register(AdvancedIncSubtensor1) +def jax_funcify_AdvancedIncSubtensor1(op, node, **kwargs): + if getattr(op, "set_instead_of_inc", False): + + def jax_fn(x, y, ilist): + return x.at[ilist].set(y) + + else: + + def jax_fn(x, y, ilist): + return x.at[ilist].add(y) + + return jax_fn + + +@jax_funcify.register(IncSubtensor) def jax_funcify_IncSubtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) + idx_list = op.idx_list if getattr(op, "set_instead_of_inc", False): @@ -77,8 +98,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): @jax_funcify.register(AdvancedIncSubtensor) def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) - + idx_list = op.idx_list + if getattr(op, "set_instead_of_inc", False): def jax_fn(x, indices, y): diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 7e7353f60e..3d4bc1f185 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -20,7 +20,6 @@ ) from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy from pytensor.tensor import TensorType -from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -29,7 +28,7 @@ IncSubtensor, Subtensor, ) -from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType +from pytensor.tensor.type_other import MakeSlice def slice_new(self, start, stop, step): @@ -239,15 +238,15 @@ def {function_name}({", ".join(input_names)}): @register_funcify_and_cache_key(AdvancedIncSubtensor) def numba_funcify_AdvancedSubtensor(op, node, **kwargs): if isinstance(op, AdvancedSubtensor): - x, y, tensor_inputs = node.inputs[0], None, node.inputs[1:] + tensor_inputs = node.inputs[1:] else: - x, y, *tensor_inputs = node.inputs + tensor_inputs = node.inputs[2:] # Reconstruct indexing information from idx_list and tensor inputs basic_idxs = [] adv_idxs = [] input_idx = 0 - + for i, entry in enumerate(op.idx_list): if isinstance(entry, slice): # Basic slice index @@ -256,12 +255,14 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): # Advanced tensor index if input_idx < len(tensor_inputs): idx_input = tensor_inputs[input_idx] - adv_idxs.append({ - "axis": i, - "dtype": idx_input.type.dtype, - "bcast": idx_input.type.broadcastable, - "ndim": idx_input.type.ndim, - }) + adv_idxs.append( + { + "axis": i, + "dtype": idx_input.type.dtype, + "bcast": idx_input.type.broadcastable, + "ndim": idx_input.type.ndim, + } + ) input_idx += 1 # Special implementation for consecutive integer vector indices diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py index 786ec46fe4..9a5e4b2ce1 100644 --- a/pytensor/link/pytorch/dispatch/subtensor.py +++ b/pytensor/link/pytorch/dispatch/subtensor.py @@ -9,7 +9,7 @@ Subtensor, indices_from_subtensor, ) -from pytensor.tensor.type_other import MakeSlice, SliceType +from pytensor.tensor.type_other import MakeSlice def check_negative_steps(indices): @@ -63,8 +63,8 @@ def makeslice(start, stop, step): @pytorch_funcify.register(AdvancedSubtensor1) @pytorch_funcify.register(AdvancedSubtensor) def pytorch_funcify_AdvSubtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) - + idx_list = op.idx_list + def advsubtensor(x, *flattened_indices): indices = 
indices_from_subtensor(flattened_indices, idx_list) check_negative_steps(indices) @@ -105,7 +105,7 @@ def inc_subtensor(x, y, *flattened_indices): @pytorch_funcify.register(AdvancedIncSubtensor) @pytorch_funcify.register(AdvancedIncSubtensor1) def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) + idx_list = op.idx_list inplace = op.inplace ignore_duplicates = getattr(op, "ignore_duplicates", False) @@ -139,7 +139,9 @@ def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices): else: # Check if we have slice indexing in idx_list - has_slice_indexing = any(isinstance(entry, slice) for entry in idx_list) if idx_list else False + has_slice_indexing = ( + any(isinstance(entry, slice) for entry in idx_list) if idx_list else False + ) if has_slice_indexing: raise NotImplementedError( "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch" diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 60ac79f149..75a0d4d855 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -1317,8 +1317,6 @@ def perform(self, node, inputs, outputs): z[0] = y def grad(self, inputs, gout): - from pytensor.sparse.math import sp_sum - (x, s) = inputs (gz,) = gout return [col_scale(gz, s), sp_sum(x * gz, axis=0)] @@ -1368,8 +1366,6 @@ def perform(self, node, inputs, outputs): z[0] = scipy.sparse.csc_matrix((y_data, indices, indptr), (M, N)) def grad(self, inputs, gout): - from pytensor.sparse.math import sp_sum - (x, s) = inputs (gz,) = gout return [row_scale(gz, s), sp_sum(x * gz, axis=1)] @@ -1435,6 +1431,126 @@ def row_scale(x, s): return col_scale(x.T, s).T +class SpSum(Op): + """ + + WARNING: judgement call... + We are not using the structured in the comparison or hashing + because it doesn't change the perform method therefore, we + *do* want Sums with different structured values to be merged + by the merge optimization and this requires them to compare equal. 
+ """ + + __props__ = ("axis",) + + def __init__(self, axis=None, sparse_grad=True): + super().__init__() + self.axis = axis + self.structured = sparse_grad + if self.axis not in (None, 0, 1): + raise ValueError("Illegal value for self.axis.") + + def make_node(self, x): + x = as_sparse_variable(x) + assert x.format in ("csr", "csc") + + if self.axis is not None: + out_shape = (None,) + else: + out_shape = () + + z = TensorType(dtype=x.dtype, shape=out_shape)() + return Apply(self, [x], [z]) + + def perform(self, node, inputs, outputs): + (x,) = inputs + (z,) = outputs + if self.axis is None: + z[0] = np.asarray(x.sum()) + else: + z[0] = np.asarray(x.sum(self.axis)).ravel() + + def grad(self, inputs, gout): + (x,) = inputs + (gz,) = gout + if x.dtype not in continuous_dtypes: + return [x.zeros_like(dtype=config.floatX)] + if self.structured: + if self.axis is None: + r = gz * sp_ones_like(x) + elif self.axis == 0: + r = col_scale(sp_ones_like(x), gz) + elif self.axis == 1: + r = row_scale(sp_ones_like(x), gz) + else: + raise ValueError("Illegal value for self.axis.") + else: + o_format = x.format + x = dense_from_sparse(x) + if _is_sparse_variable(gz): + gz = dense_from_sparse(gz) + if self.axis is None: + r = ptb.second(x, gz) + else: + ones = ptb.ones_like(x) + if self.axis == 0: + r = specify_broadcastable(gz.dimshuffle("x", 0), 0) * ones + elif self.axis == 1: + r = specify_broadcastable(gz.dimshuffle(0, "x"), 1) * ones + else: + raise ValueError("Illegal value for self.axis.") + r = SparseFromDense(o_format)(r) + return [r] + + def infer_shape(self, fgraph, node, shapes): + r = None + if self.axis is None: + r = [()] + elif self.axis == 0: + r = [(shapes[0][1],)] + else: + r = [(shapes[0][0],)] + return r + + def __str__(self): + return f"{self.__class__.__name__}{{axis={self.axis}}}" + + +def sp_sum(x, axis=None, sparse_grad=False): + """ + Calculate the sum of a sparse matrix along the specified axis. + + It operates a reduction along the specified axis. When `axis` is `None`, + it is applied along all axes. + + Parameters + ---------- + x + Sparse matrix. + axis + Axis along which the sum is applied. Integer or `None`. + sparse_grad : bool + `True` to have a structured grad. + + Returns + ------- + object + The sum of `x` in a dense format. + + Notes + ----- + The grad implementation is controlled with the `sparse_grad` parameter. + `True` will provide a structured grad and `False` will provide a regular + grad. For both choices, the grad returns a sparse matrix having the same + format as `x`. + + This op does not return a sparse matrix, but a dense tensor matrix. + + """ + + return SpSum(axis, sparse_grad)(x) + + class Diag(Op): """Extract the diagonal of a square sparse matrix as a dense vector. diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index 972de80d89..3dffa12265 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -8,11 +8,11 @@ import pytensor.sparse.basic as psb import pytensor.tensor.basic as ptb import pytensor.tensor.math as ptm -from pytensor import config from pytensor.gradient import grad_not_implemented from pytensor.graph import Apply, Op from pytensor.link.c.op import COp -from pytensor.tensor import TensorType, Variable, specify_broadcastable, tensor +from pytensor.sparse.basic import sp_sum +from pytensor.tensor import TensorType, Variable, tensor from pytensor.tensor.type import complex_dtypes @@ -295,126 +295,6 @@ def conjugate(x): structured_conjugate = conjugate -class SpSum(Op): - """ - - WARNING: judgement call... 
- We are not using the structured in the comparison or hashing - because it doesn't change the perform method therefore, we - *do* want Sums with different structured values to be merged - by the merge optimization and this requires them to compare equal. - """ - - __props__ = ("axis",) - - def __init__(self, axis=None, sparse_grad=True): - super().__init__() - self.axis = axis - self.structured = sparse_grad - if self.axis not in (None, 0, 1): - raise ValueError("Illegal value for self.axis.") - - def make_node(self, x): - x = psb.as_sparse_variable(x) - assert x.format in ("csr", "csc") - - if self.axis is not None: - out_shape = (None,) - else: - out_shape = () - - z = TensorType(dtype=x.dtype, shape=out_shape)() - return Apply(self, [x], [z]) - - def perform(self, node, inputs, outputs): - (x,) = inputs - (z,) = outputs - if self.axis is None: - z[0] = np.asarray(x.sum()) - else: - z[0] = np.asarray(x.sum(self.axis)).ravel() - - def grad(self, inputs, gout): - (x,) = inputs - (gz,) = gout - if x.dtype not in psb.continuous_dtypes: - return [x.zeros_like(dtype=config.floatX)] - if self.structured: - if self.axis is None: - r = gz * psb.sp_ones_like(x) - elif self.axis == 0: - r = psb.col_scale(psb.sp_ones_like(x), gz) - elif self.axis == 1: - r = psb.row_scale(psb.sp_ones_like(x), gz) - else: - raise ValueError("Illegal value for self.axis.") - else: - o_format = x.format - x = psb.dense_from_sparse(x) - if psb._is_sparse_variable(gz): - gz = psb.dense_from_sparse(gz) - if self.axis is None: - r = ptb.second(x, gz) - else: - ones = ptb.ones_like(x) - if self.axis == 0: - r = specify_broadcastable(gz.dimshuffle("x", 0), 0) * ones - elif self.axis == 1: - r = specify_broadcastable(gz.dimshuffle(0, "x"), 1) * ones - else: - raise ValueError("Illegal value for self.axis.") - r = psb.SparseFromDense(o_format)(r) - return [r] - - def infer_shape(self, fgraph, node, shapes): - r = None - if self.axis is None: - r = [()] - elif self.axis == 0: - r = [(shapes[0][1],)] - else: - r = [(shapes[0][0],)] - return r - - def __str__(self): - return f"{self.__class__.__name__}{{axis={self.axis}}}" - - -def sp_sum(x, axis=None, sparse_grad=False): - """ - Calculate the sum of a sparse matrix along the specified axis. - - It operates a reduction along the specified axis. When `axis` is `None`, - it is applied along all axes. - - Parameters - ---------- - x - Sparse matrix. - axis - Axis along which the sum is applied. Integer or `None`. - sparse_grad : bool - `True` to have a structured grad. - - Returns - ------- - object - The sum of `x` in a dense format. - - Notes - ----- - The grad implementation is controlled with the `sparse_grad` parameter. - `True` will provide a structured grad and `False` will provide a regular - grad. For both choices, the grad returns a sparse matrix having the same - format as `x`. - - This op does not return a sparse matrix, but a dense tensor matrix. - - """ - - return SpSum(axis, sparse_grad)(x) - - class AddSS(Op): # add(sparse, sparse). # see the doc of add() for more detail. diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e789659474..9bb31482c4 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1818,6 +1818,33 @@ def do_constant_folding(self, fgraph, node): return True +@_vectorize_node.register(Alloc) +def vectorize_alloc(op: Alloc, node: Apply, batch_val, *batch_shapes): + # batch_shapes are usually not batched (they are scalars for the shape) + # batch_val is the value being allocated. 
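Editor's note on the `vectorize_alloc` registration started above: when only the allocated value carries extra leading batch dimensions, those dimensions are prepended to the requested output shape. A NumPy-only sketch of that shape relation (the names and shapes are illustrative, not part of the patch):

    import numpy as np

    # Non-batched case: alloc(val, 4, 5) broadcasts a length-5 vector to (4, 5).
    val = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    out = np.broadcast_to(val, (4, 5))
    assert out.shape == (4, 5)

    # Batched case: batch_val has one extra leading batch dim of size 7, so the
    # vectorized Alloc prepends that dim and the output becomes (7, 4, 5).
    batch_val = np.broadcast_to(val, (7, 5))
    batch_out = np.broadcast_to(batch_val[:, None, :], (7, 4, 5))
    assert batch_out.shape == (7, 4, 5)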
+ + # If shapes are batched, we fall back (complex case) + if any( + b_shp.type.ndim > shp.type.ndim + for b_shp, shp in zip(batch_shapes, node.inputs[1:], strict=True) + ): + return vectorize_node_fallback(op, node, batch_val, *batch_shapes) + + # If value is batched, we need to prepend batch dims to the output shape + val = node.inputs[0] + batch_ndim = batch_val.type.ndim - val.type.ndim + + if batch_ndim == 0: + return op.make_node(batch_val, *batch_shapes) + + # We need the size of the batch dimensions + # batch_val has shape (B1, B2, ..., val_dims...) + batch_dims = [batch_val.shape[i] for i in range(batch_ndim)] + + new_shapes = batch_dims + list(batch_shapes) + return op.make_node(batch_val, *new_shapes) + + alloc = Alloc() pprint.assign(alloc, printing.FunctionPrinter(["alloc"])) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 599e3497d3..b031d30ae6 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -14,6 +14,7 @@ in2out, node_rewriter, ) +from pytensor.graph.type import Type from pytensor.raise_op import Assert from pytensor.scalar import Add, ScalarConstant, ScalarType from pytensor.scalar import constant as scalar_constant @@ -212,6 +213,20 @@ def get_advsubtensor_axis(indices): return axis +def reconstruct_indices(idx_list, tensor_inputs): + """Reconstruct indices from idx_list and tensor inputs.""" + indices = [] + input_idx = 0 + for entry in idx_list: + if isinstance(entry, slice): + indices.append(entry) + elif isinstance(entry, Type): + if input_idx < len(tensor_inputs): + indices.append(tensor_inputs[input_idx]) + input_idx += 1 + return indices + + @register_specialize @node_rewriter([AdvancedSubtensor]) def local_replace_AdvancedSubtensor(fgraph, node): @@ -229,17 +244,9 @@ def local_replace_AdvancedSubtensor(fgraph, node): indexed_var = node.inputs[0] tensor_inputs = node.inputs[1:] - + # Reconstruct indices from idx_list and tensor inputs - indices = [] - input_idx = 0 - for entry in node.op.idx_list: - if isinstance(entry, slice): - indices.append(entry) - elif isinstance(entry, Type): - if input_idx < len(tensor_inputs): - indices.append(tensor_inputs[input_idx]) - input_idx += 1 + indices = reconstruct_indices(node.op.idx_list, tensor_inputs) axis = get_advsubtensor_axis(indices) @@ -267,17 +274,9 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): res = node.inputs[0] val = node.inputs[1] tensor_inputs = node.inputs[2:] - + # Reconstruct indices from idx_list and tensor inputs - indices = [] - input_idx = 0 - for entry in node.op.idx_list: - if isinstance(entry, slice): - indices.append(entry) - elif isinstance(entry, Type): - if input_idx < len(tensor_inputs): - indices.append(tensor_inputs[input_idx]) - input_idx += 1 + indices = reconstruct_indices(node.op.idx_list, tensor_inputs) axis = get_advsubtensor_axis(indices) @@ -1112,6 +1111,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node): def local_inplace_AdvancedIncSubtensor(fgraph, node): if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace: new_op = type(node.op)( + node.op.idx_list, inplace=True, set_instead_of_inc=node.op.set_instead_of_inc, ignore_duplicates=node.op.ignore_duplicates, @@ -1376,6 +1376,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node): z_broad[k] and not same_shape(xi, y, dim_x=k, dim_y=k) and shape_of[y][k] != 1 + and shape_of[xi][k] == 1 ) ] @@ -1778,17 +1779,9 @@ def ravel_multidimensional_bool_idx(fgraph, node): else: x, y = 
node.inputs[0], node.inputs[1] tensor_inputs = node.inputs[2:] - + # Reconstruct indices from idx_list and tensor inputs - idxs = [] - input_idx = 0 - for entry in node.op.idx_list: - if isinstance(entry, slice): - idxs.append(entry) - elif isinstance(entry, Type): - if input_idx < len(tensor_inputs): - idxs.append(tensor_inputs[input_idx]) - input_idx += 1 + idxs = reconstruct_indices(node.op.idx_list, tensor_inputs) if any( ( @@ -1829,7 +1822,7 @@ def ravel_multidimensional_bool_idx(fgraph, node): # Create new AdvancedSubtensor with updated idx_list new_idx_list = list(node.op.idx_list) new_tensor_inputs = list(tensor_inputs) - + # Update the idx_list and tensor_inputs for the raveled boolean index input_idx = 0 for i, entry in enumerate(node.op.idx_list): @@ -1837,13 +1830,13 @@ def ravel_multidimensional_bool_idx(fgraph, node): if input_idx == bool_idx_pos: new_tensor_inputs[input_idx] = raveled_bool_idx input_idx += 1 - + new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs) else: # Create new AdvancedIncSubtensor with updated idx_list new_idx_list = list(node.op.idx_list) new_tensor_inputs = list(tensor_inputs) - + # Update the tensor_inputs for the raveled boolean index input_idx = 0 for i, entry in enumerate(node.op.idx_list): @@ -1851,14 +1844,14 @@ def ravel_multidimensional_bool_idx(fgraph, node): if input_idx == bool_idx_pos: new_tensor_inputs[input_idx] = raveled_bool_idx input_idx += 1 - + # The dimensions of y that correspond to the boolean indices # must already be raveled in the original graph, so we don't need to do anything to it new_out = AdvancedIncSubtensor( new_idx_list, inplace=node.op.inplace, set_instead_of_inc=node.op.set_instead_of_inc, - ignore_duplicates=node.op.ignore_duplicates + ignore_duplicates=node.op.ignore_duplicates, )(raveled_x, y, *new_tensor_inputs) # But we must reshape the output to match the original shape new_out = new_out.reshape(x_shape) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index eeda92bccf..4467300e66 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -1,3 +1,4 @@ +import copy import logging import sys import warnings @@ -63,7 +64,6 @@ from pytensor.tensor.type_other import ( MakeSlice, NoneConst, - NoneSliceConst, NoneTypeT, SliceConstant, SliceType, @@ -706,7 +706,7 @@ def helper(entry): return ret -def index_vars_to_types(entry, slice_ok=True): +def index_vars_to_types(entry, slice_ok=True, allow_advanced=False): r"""Change references to `Variable`s into references to `Type`s. The `Subtensor.idx_list` field is unique to each `Subtensor` instance. It @@ -717,12 +717,13 @@ def index_vars_to_types(entry, slice_ok=True): when would that happen? 
""" - if ( - isinstance(entry, np.ndarray | Variable) - and hasattr(entry, "dtype") - and entry.dtype == "bool" - ): - raise AdvancedIndexingError("Invalid index type or slice for Subtensor") + if not allow_advanced: + if ( + isinstance(entry, np.ndarray | Variable) + and hasattr(entry, "dtype") + and entry.dtype == "bool" + ): + raise AdvancedIndexingError("Invalid index type or slice for Subtensor") if isinstance(entry, Variable) and ( entry.type in invalid_scal_types or entry.type in invalid_tensor_types @@ -742,13 +743,37 @@ def index_vars_to_types(entry, slice_ok=True): return ps.get_scalar_type(entry.type.dtype) elif isinstance(entry, Type) and entry in tensor_types and all(entry.broadcastable): return ps.get_scalar_type(entry.dtype) + elif ( + allow_advanced + and isinstance(entry, Variable) + and isinstance(entry.type, TensorType) + ): + return entry.type + elif allow_advanced and isinstance(entry, TensorType): + return entry + elif ( + allow_advanced + and isinstance(entry, Variable) + and isinstance(entry.type, SliceType) + ): + return entry.type + elif allow_advanced and isinstance(entry, SliceType): + return entry + elif ( + allow_advanced + and isinstance(entry, Variable) + and isinstance(entry.type, SliceType) + ): + return entry.type + elif allow_advanced and isinstance(entry, SliceType): + return entry elif slice_ok and isinstance(entry, slice): a = entry.start b = entry.stop c = entry.step if a is not None: - slice_a = index_vars_to_types(a, False) + slice_a = index_vars_to_types(a, False, allow_advanced) else: slice_a = None @@ -756,18 +781,18 @@ def index_vars_to_types(entry, slice_ok=True): # The special "maxsize" case is probably not needed here, # as slices containing maxsize are not generated by # __getslice__ anymore. - slice_b = index_vars_to_types(b, False) + slice_b = index_vars_to_types(b, False, allow_advanced) else: slice_b = None if c is not None: - slice_c = index_vars_to_types(c, False) + slice_c = index_vars_to_types(c, False, allow_advanced) else: slice_c = None return slice(slice_a, slice_b, slice_c) elif isinstance(entry, int | np.integer): - raise TypeError() + return entry else: raise AdvancedIndexingError("Invalid index type or slice for Subtensor") @@ -1564,7 +1589,10 @@ def inc_subtensor( ilist = x.owner.inputs[1] if ignore_duplicates: the_op = AdvancedIncSubtensor( - inplace, set_instead_of_inc=set_instead_of_inc, ignore_duplicates=True + [ilist], + inplace, + set_instead_of_inc=set_instead_of_inc, + ignore_duplicates=True, ) else: the_op = AdvancedIncSubtensor1( @@ -1575,6 +1603,7 @@ def inc_subtensor( real_x = x.owner.inputs[0] ilist = x.owner.inputs[1:] the_op = AdvancedIncSubtensor( + x.owner.op.idx_list, inplace, set_instead_of_inc=set_instead_of_inc, ignore_duplicates=ignore_duplicates, @@ -2581,16 +2610,31 @@ class AdvancedSubtensor(Op): def __init__(self, idx_list): """ Initialize AdvancedSubtensor with index list. - + Parameters ---------- idx_list : tuple List of indices where slices are stored as-is, and numerical indices are replaced by their types. 
""" - self.idx_list = tuple(map(index_vars_to_types, idx_list)) + self.idx_list = tuple( + index_vars_to_types(idx, allow_advanced=True) for idx in idx_list + ) # Store expected number of tensor inputs for validation - self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type))) + self.expected_inputs_len = len( + get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)) + ) + + def __hash__(self): + msg = [] + for entry in self.idx_list: + if isinstance(entry, slice): + msg += [(entry.start, entry.stop, entry.step)] + else: + msg += [entry] + + idx_list = tuple(msg) + return hash((type(self), idx_list)) def make_node(self, x, *inputs): """ @@ -2603,7 +2647,13 @@ def make_node(self, x, *inputs): """ x = as_tensor_variable(x) - inputs = tuple(as_tensor_variable(a) for a in inputs) + processed_inputs = [] + for a in inputs: + if isinstance(a, Variable) and isinstance(a.type, SliceType): + processed_inputs.append(a) + else: + processed_inputs.append(as_tensor_variable(a)) + inputs = tuple(processed_inputs) idx_list = list(self.idx_list) if len(idx_list) > x.type.ndim: @@ -2611,12 +2661,14 @@ def make_node(self, x, *inputs): # Validate input count matches expected from idx_list if len(inputs) != self.expected_inputs_len: - raise ValueError(f"Expected {self.expected_inputs_len} inputs but got {len(inputs)}") + raise ValueError( + f"Expected {self.expected_inputs_len} inputs but got {len(inputs)}" + ) # Build explicit_indices for shape inference explicit_indices = [] input_idx = 0 - + for i, entry in enumerate(idx_list): if isinstance(entry, slice): # Reconstruct slice with actual values from inputs @@ -2625,27 +2677,27 @@ def make_node(self, x, *inputs): input_idx += 1 else: start_val = entry.start - + if entry.stop is not None and isinstance(entry.stop, Type): stop_val = inputs[input_idx] input_idx += 1 else: stop_val = entry.stop - + if entry.step is not None and isinstance(entry.step, Type): step_val = inputs[input_idx] input_idx += 1 else: step_val = entry.step - + explicit_indices.append(slice(start_val, stop_val, step_val)) elif isinstance(entry, Type): # This is a numerical index inp = inputs[input_idx] input_idx += 1 - + # Handle boolean indices - if inp.dtype == "bool": + if hasattr(inp, "dtype") and inp.dtype == "bool": if inp.type.ndim == 0: raise NotImplementedError( "Indexing with scalar booleans not supported" @@ -2668,7 +2720,9 @@ def make_node(self, x, *inputs): ) # Convert boolean indices to integer with nonzero if isinstance(inp, Constant): - nonzero_indices = [tensor_constant(i) for i in inp.data.nonzero()] + nonzero_indices = [ + tensor_constant(i) for i in inp.data.nonzero() + ] else: nonzero_indices = inp.nonzero() explicit_indices.extend(nonzero_indices) @@ -2693,6 +2747,8 @@ def make_node(self, x, *inputs): ): if isinstance(idx, slice): basic_group_shape.append(slice_static_length(idx, dim_length)) + elif isinstance(idx, Variable) and isinstance(idx.type, SliceType): + basic_group_shape.append(None) else: # TensorType (advanced index) # Keep track of advanced group axis if adv_group_axis is None: @@ -2746,10 +2802,10 @@ def is_bool_index(idx): # Reconstruct the full indices from idx_list and inputs (newaxis handled by __getitem__) inputs = node.inputs[1:] - + full_indices = [] input_idx = 0 - + for entry in self.idx_list: if isinstance(entry, slice): # Reconstruct slice from idx_list and inputs @@ -2758,19 +2814,19 @@ def is_bool_index(idx): input_idx += 1 else: start_val = entry.start - + if entry.stop is not 
None and isinstance(entry.stop, Type): stop_val = inputs[input_idx] input_idx += 1 else: stop_val = entry.stop - + if entry.step is not None and isinstance(entry.step, Type): step_val = inputs[input_idx] input_idx += 1 else: step_val = entry.step - + full_indices.append(slice(start_val, stop_val, step_val)) elif isinstance(entry, Type): # This is a numerical index - get from inputs @@ -2779,19 +2835,23 @@ def is_bool_index(idx): input_idx += 1 else: raise ValueError("Mismatch between idx_list and inputs") - + index_shapes = [] for idx in full_indices: if isinstance(idx, slice): index_shapes.append(idx) - elif hasattr(idx, 'type'): + elif isinstance(idx, Variable) and isinstance(idx.type, SliceType): + index_shapes.append(idx) + elif hasattr(idx, "type"): # Mixed bool indexes are converted to nonzero entries shape0_op = Shape_i(0) if is_bool_index(idx): index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx)) else: # Get ishape for this input - input_shape_idx = inputs.index(idx) + 1 # +1 because ishapes[0] is x + input_shape_idx = ( + inputs.index(idx) + 1 + ) # +1 because ishapes[0] is x index_shapes.append(ishapes[input_shape_idx]) else: index_shapes.append(idx) @@ -2805,7 +2865,7 @@ def is_bool_index(idx): # We must compute the Op to find its shape res_shape[i] = Shape_i(i)(node.out) - adv_indices = [idx for idx in indices if not is_basic_idx(idx)] + adv_indices = [idx for idx in full_indices if not is_basic_idx(idx)] bool_indices = [idx for idx in adv_indices if is_bool_index(idx)] # Special logic when the only advanced index group is of bool type. @@ -2816,7 +2876,7 @@ def is_bool_index(idx): # Because there are no more advanced index groups, there is exactly # one output dim per index variable up to the bool group. # Note: Scalar integer indexing counts as advanced indexing. 
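Editor's note: the special case above (a single boolean advanced-index group) relies on the NumPy rule that a boolean mask contributes one output dimension whose length equals the number of True entries; that is the value `infer_shape` uses for that dimension. A small NumPy check of the rule (illustrative only):

    import numpy as np

    x = np.arange(24).reshape(4, 6)
    mask = x % 5 == 0          # 2-D boolean mask covering both axes of x

    # The boolean group collapses the masked axes into one dim of length mask.sum()
    assert x[mask].shape == (mask.sum(),)

    # With a leading basic index the rule is the same, one axis further in
    y = np.arange(48).reshape(2, 4, 6)
    assert y[:, mask].shape == (2, mask.sum())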
- start_dim = indices.index(bool_index) + start_dim = full_indices.index(bool_index) res_shape[start_dim] = bool_index.sum() assert node.outputs[0].ndim == len(res_shape) @@ -2824,14 +2884,14 @@ def is_bool_index(idx): def perform(self, node, inputs, out_): (out,) = out_ - + # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__) x = inputs[0] tensor_inputs = inputs[1:] - + full_indices = [] input_idx = 0 - + for entry in self.idx_list: if isinstance(entry, slice): # Reconstruct slice from idx_list and inputs @@ -2840,19 +2900,19 @@ def perform(self, node, inputs, out_): input_idx += 1 else: start_val = entry.start - + if entry.stop is not None and isinstance(entry.stop, Type): stop_val = tensor_inputs[input_idx] input_idx += 1 else: stop_val = entry.stop - + if entry.step is not None and isinstance(entry.step, Type): step_val = tensor_inputs[input_idx] input_idx += 1 else: step_val = entry.step - + full_indices.append(slice(start_val, stop_val, step_val)) elif isinstance(entry, Type): # This is a numerical index - get from inputs @@ -2861,14 +2921,35 @@ def perform(self, node, inputs, out_): input_idx += 1 else: raise ValueError("Mismatch between idx_list and inputs") - + check_advanced_indexing_dimensions(x, full_indices) - rval = x.__getitem__(tuple(full_indices)) + + # Handle runtime broadcasting for broadcastable dimensions + broadcastable = node.inputs[0].type.broadcastable + new_full_indices = [] + for i, idx in enumerate(full_indices): + if i < len(broadcastable) and broadcastable[i] and x.shape[i] == 1: + if isinstance(idx, np.ndarray | list | tuple): + # Replace with zeros of same shape to preserve output shape + if isinstance(idx, np.ndarray): + new_full_indices.append(np.zeros_like(idx)) + else: + arr = np.array(idx) + new_full_indices.append(np.zeros_like(arr)) + elif isinstance(idx, int | np.integer): + new_full_indices.append(0) + else: + # Slice or other + new_full_indices.append(idx) + else: + new_full_indices.append(idx) + + rval = x.__getitem__(tuple(new_full_indices)) # When there are no arrays, we are not actually doing advanced # indexing, so __getitem__ will not return a copy. 
# Since no view_map is set, we need to copy the returned value has_tensor_indices = any( - isinstance(entry, Type) and not getattr(entry, 'broadcastable', (False,))[0] + isinstance(entry, Type) and not getattr(entry, "broadcastable", (False,))[0] for entry in self.idx_list ) if not has_tensor_indices: @@ -2927,10 +3008,10 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__) op = node.op tensor_inputs = node.inputs[1:] - + full_indices = [] input_idx = 0 - + for entry in op.idx_list: if isinstance(entry, slice): full_indices.append(slice(None)) # Represent as basic slice @@ -2939,7 +3020,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: if input_idx < len(tensor_inputs): full_indices.append(tensor_inputs[input_idx]) input_idx += 1 - + return _non_consecutive_adv_indexing(full_indices) @@ -2980,17 +3061,52 @@ class AdvancedIncSubtensor(Op): __props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates", "idx_list") def __init__( - self, idx_list, inplace=False, set_instead_of_inc=False, ignore_duplicates=False + self, + idx_list=None, + inplace=False, + set_instead_of_inc=False, + ignore_duplicates=False, ): - self.idx_list = tuple(map(index_vars_to_types, idx_list)) - # Store expected number of tensor inputs for validation - self.expected_inputs_len = len(get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type))) + if idx_list is not None: + self.idx_list = tuple( + index_vars_to_types(idx, allow_advanced=True) for idx in idx_list + ) + # Store expected number of tensor inputs for validation + self.expected_inputs_len = len( + get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)) + ) + else: + self.idx_list = None + self.expected_inputs_len = None + self.set_instead_of_inc = set_instead_of_inc self.inplace = inplace if inplace: self.destroy_map = {0: [0]} self.ignore_duplicates = ignore_duplicates + def __hash__(self): + if self.idx_list is None: + idx_list = None + else: + msg = [] + for entry in self.idx_list: + if isinstance(entry, slice): + msg += [(entry.start, entry.stop, entry.step)] + else: + msg += [entry] + idx_list = tuple(msg) + + return hash( + ( + type(self), + idx_list, + self.inplace, + self.set_instead_of_inc, + self.ignore_duplicates, + ) + ) + def __str__(self): return ( "AdvancedSetSubtensor" @@ -3002,9 +3118,21 @@ def make_node(self, x, y, *inputs): x = as_tensor_variable(x) y = as_tensor_variable(y) + if self.idx_list is None: + # Infer idx_list from inputs + # This handles the case where AdvancedIncSubtensor is initialized without idx_list + # and used as a factory. 
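Editor's note on the factory branch above: with `idx_list=None`, `make_node` infers a concrete `idx_list` from the types of the index inputs and re-dispatches on a copy of the op. A hedged sketch of the two construction styles this is meant to support, assuming the constructor signature introduced in this patch (the symbolic variables below are illustrative):

    import pytensor.tensor as pt
    from pytensor.tensor.subtensor import AdvancedIncSubtensor

    x = pt.matrix("x")
    y = pt.vector("y")
    idx = pt.lvector("idx")

    # Explicit style: the static structure ("index the first axis with a tensor")
    # is baked into the op via idx_list; only idx is a graph input.
    out1 = AdvancedIncSubtensor([idx], set_instead_of_inc=True)(x, y, idx)

    # Factory style: no idx_list at construction time; make_node infers it from
    # the runtime index inputs and re-dispatches on a copy of the op.
    out2 = AdvancedIncSubtensor(set_instead_of_inc=True)(x, y, idx)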
+ idx_list = [inp.type for inp in inputs] + new_op = copy.copy(self) + new_op.idx_list = tuple(idx_list) + new_op.expected_inputs_len = len(inputs) + return new_op.make_node(x, y, *inputs) + # Validate that we have the right number of tensor inputs for our idx_list if len(inputs) != self.expected_inputs_len: - raise ValueError(f"Expected {self.expected_inputs_len} tensor inputs but got {len(inputs)}") + raise ValueError( + f"Expected {self.expected_inputs_len} tensor inputs but got {len(inputs)}" + ) new_inputs = [] for inp in inputs: @@ -3023,7 +3151,7 @@ def perform(self, node, inputs, out_): # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__) full_indices = [] input_idx = 0 - + for entry in self.idx_list: if isinstance(entry, slice): # Reconstruct slice from idx_list and inputs @@ -3032,19 +3160,19 @@ def perform(self, node, inputs, out_): input_idx += 1 else: start_val = entry.start - + if entry.stop is not None and isinstance(entry.stop, Type): stop_val = tensor_inputs[input_idx] input_idx += 1 else: stop_val = entry.stop - + if entry.step is not None and isinstance(entry.step, Type): step_val = tensor_inputs[input_idx] input_idx += 1 else: step_val = entry.step - + full_indices.append(slice(start_val, stop_val, step_val)) elif isinstance(entry, Type): # This is a numerical index - get from inputs @@ -3053,7 +3181,7 @@ def perform(self, node, inputs, out_): input_idx += 1 else: raise ValueError("Mismatch between idx_list and inputs") - + check_advanced_indexing_dimensions(x, full_indices) (out,) = out_ @@ -3097,9 +3225,11 @@ def grad(self, inpt, output_gradients): raise NotImplementedError("No support for complex grad yet") else: if self.set_instead_of_inc: - gx = AdvancedIncSubtensor(self.idx_list, set_instead_of_inc=True).make_node( - outgrad, y.zeros_like(), *idxs - ).outputs[0] + gx = ( + AdvancedIncSubtensor(self.idx_list, set_instead_of_inc=True) + .make_node(outgrad, y.zeros_like(), *idxs) + .outputs[0] + ) else: gx = outgrad gy = AdvancedSubtensor(self.idx_list).make_node(outgrad, *idxs).outputs[0] @@ -3140,10 +3270,10 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__) op = node.op tensor_inputs = node.inputs[2:] # Skip x and y - + full_indices = [] input_idx = 0 - + for entry in op.idx_list: if isinstance(entry, slice): full_indices.append(slice(None)) # Represent as basic slice @@ -3152,107 +3282,133 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: if input_idx < len(tensor_inputs): full_indices.append(tensor_inputs[input_idx]) input_idx += 1 - + return _non_consecutive_adv_indexing(full_indices) def advanced_subtensor(x, *args): """Create an AdvancedSubtensor operation. - - This function converts the arguments to work with the new AdvancedSubtensor + + This function converts the arguments to work with the new AdvancedSubtensor interface that separates slice structure from variable inputs. - + Note: newaxis (None) should be handled by __getitem__ using dimshuffle before calling this function. 
""" # Convert args using as_index_variable (like original AdvancedSubtensor did) processed_args = tuple(map(as_index_variable, args)) - + # Now create idx_list and extract inputs idx_list = [] input_vars = [] - + for arg in processed_args: if isinstance(arg.type, SliceType): - # Handle SliceType - extract components and structure + # Handle SliceType - extract components and structure if isinstance(arg, Constant): # Constant slice idx_list.append(arg.data) elif arg.owner and isinstance(arg.owner.op, MakeSlice): # Variable slice - extract components start, stop, step = arg.owner.inputs - + # Convert components to types for idx_list - start_type = index_vars_to_types(start, False) if not isinstance(start.type, NoneTypeT) else None - stop_type = index_vars_to_types(stop, False) if not isinstance(stop.type, NoneTypeT) else None - step_type = index_vars_to_types(step, False) if not isinstance(step.type, NoneTypeT) else None - + start_type = ( + index_vars_to_types(start, False) + if not isinstance(start.type, NoneTypeT) + else None + ) + stop_type = ( + index_vars_to_types(stop, False) + if not isinstance(stop.type, NoneTypeT) + else None + ) + step_type = ( + index_vars_to_types(step, False) + if not isinstance(step.type, NoneTypeT) + else None + ) + idx_list.append(slice(start_type, stop_type, step_type)) - + # Add variable components to inputs if not isinstance(start.type, NoneTypeT): input_vars.append(start) if not isinstance(stop.type, NoneTypeT): - input_vars.append(stop) + input_vars.append(stop) if not isinstance(step.type, NoneTypeT): input_vars.append(step) else: - # Other slice case - idx_list.append(slice(None)) + # Generic SliceType variable + idx_list.append(arg.type) + input_vars.append(arg) else: # Tensor index (should not be NoneType since newaxis handled in __getitem__) - idx_list.append(index_vars_to_types(arg)) + idx_list.append(index_vars_to_types(arg, allow_advanced=True)) input_vars.append(arg) - - return AdvancedSubtensor(idx_list).make_node(x, *input_vars).outputs[0] + + return AdvancedSubtensor(idx_list)(x, *input_vars) def advanced_inc_subtensor(x, y, *args, **kwargs): """Create an AdvancedIncSubtensor operation for incrementing. - + Note: newaxis (None) should be handled by __getitem__ using dimshuffle before calling this function. 
""" # Convert args using as_index_variable (like original AdvancedIncSubtensor would) processed_args = tuple(map(as_index_variable, args)) - + # Now create idx_list and extract inputs idx_list = [] input_vars = [] - + for arg in processed_args: if isinstance(arg.type, SliceType): - # Handle SliceType - extract components and structure + # Handle SliceType - extract components and structure if isinstance(arg, Constant): # Constant slice idx_list.append(arg.data) elif arg.owner and isinstance(arg.owner.op, MakeSlice): # Variable slice - extract components start, stop, step = arg.owner.inputs - + # Convert components to types for idx_list - start_type = index_vars_to_types(start, False) if not isinstance(start.type, NoneTypeT) else None - stop_type = index_vars_to_types(stop, False) if not isinstance(stop.type, NoneTypeT) else None - step_type = index_vars_to_types(step, False) if not isinstance(step.type, NoneTypeT) else None - + start_type = ( + index_vars_to_types(start, False) + if not isinstance(start.type, NoneTypeT) + else None + ) + stop_type = ( + index_vars_to_types(stop, False) + if not isinstance(stop.type, NoneTypeT) + else None + ) + step_type = ( + index_vars_to_types(step, False) + if not isinstance(step.type, NoneTypeT) + else None + ) + idx_list.append(slice(start_type, stop_type, step_type)) - + # Add variable components to inputs if not isinstance(start.type, NoneTypeT): input_vars.append(start) if not isinstance(stop.type, NoneTypeT): - input_vars.append(stop) + input_vars.append(stop) if not isinstance(step.type, NoneTypeT): input_vars.append(step) else: - # Other slice case - idx_list.append(slice(None)) + # Generic SliceType variable + idx_list.append(arg.type) + input_vars.append(arg) else: # Tensor index (should not be NoneType since newaxis handled in __getitem__) - idx_list.append(index_vars_to_types(arg)) + idx_list.append(index_vars_to_types(arg, allow_advanced=True)) input_vars.append(arg) - - return AdvancedIncSubtensor(idx_list, **kwargs).make_node(x, y, *input_vars).outputs[0] + + return AdvancedIncSubtensor(idx_list, **kwargs)(x, y, *input_vars) def advanced_set_subtensor(x, y, *args, **kwargs): @@ -3457,3 +3613,108 @@ def flip( "slice_at_axis", "take", ] + + +@_vectorize_node.register(AdvancedIncSubtensor) +def vectorize_advanced_inc_subtensor(op: AdvancedIncSubtensor, node, *batch_inputs): + x, y, *idxs = node.inputs + batch_x, batch_y, *batch_idxs = batch_inputs + + x_is_batched = x.type.ndim < batch_x.type.ndim + idxs_are_batched = any( + batch_idx.type.ndim > idx.type.ndim + for batch_idx, idx in zip(batch_idxs, idxs, strict=True) + if isinstance(batch_idx, TensorVariable) + ) + + if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)): + # Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing + # which would put the indexed results to the left of the batch dimensions! 
+ return vectorize_node_fallback(op, node, batch_x, batch_y, *batch_idxs) + # If y is batched more than x, we need to broadcast x to match y's batch dims + x_batch_ndim = batch_x.type.ndim - x.type.ndim + y_batch_ndim = batch_y.type.ndim - y.type.ndim + + # Ensure x has at least as many batch dims as y + if y_batch_ndim > x_batch_ndim: + diff = y_batch_ndim - x_batch_ndim + new_dims = (["x"] * diff) + list(range(batch_x.type.ndim)) + batch_x = batch_x.dimshuffle(new_dims) + x_batch_ndim = y_batch_ndim + + # Broadcast batch_x up to y's batch shape with Alloc. Deciding symbolically + # whether broadcasting is actually needed would require extra nodes, so we + # always allocate; e.g. batch_x of shape (1, 1, 458) and batch_y with batch + # shape (1, 1000) should give (1, 1000, 458). y_batch_ndim is known from the + # types at graph-construction time, so the target shape can be assembled + # explicitly from y's batch dims followed by x's core dims. + if y_batch_ndim > 0: + out_shape = [batch_y.shape[i] for i in range(y_batch_ndim)] + out_shape.extend(batch_x.shape[x_batch_ndim + i] for i in range(x.type.ndim)) + + batch_x = alloc(batch_x, *out_shape) + + # Add a full slice for every new batch dim so the original indices apply to the core dims + empty_slices = (slice(None),) * x_batch_ndim + new_idx_list = empty_slices + op.idx_list + return AdvancedIncSubtensor( + new_idx_list, + inplace=op.inplace, + set_instead_of_inc=op.set_instead_of_inc, + ignore_duplicates=op.ignore_duplicates, + ).make_node(batch_x, batch_y, *batch_idxs) + + +@_vectorize_node.register(AdvancedIncSubtensor1) +def vectorize_advanced_inc_subtensor1(op: AdvancedIncSubtensor1, node, *batch_inputs): + x, y, idx = node.inputs + batch_x, batch_y, batch_idx = batch_inputs + + idx_is_batched = idx.type.ndim < batch_idx.type.ndim + + if idx_is_batched: + return vectorize_node_fallback(op, node, batch_x, batch_y, batch_idx) + + # AdvancedIncSubtensor1 only supports indexing the first dimension. + # If x is batched, we can use AdvancedIncSubtensor which supports indexing any dimension.
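Editor's note: the rewrite above maps a batched first-axis increment onto general advanced indexing by prepending one full slice per batch dimension. In NumPy terms (shapes are illustrative only):

    import numpy as np

    batch_x = np.zeros((2, 5, 3))     # one batch dim in front of the core (5, 3) tensor
    batch_y = np.ones((2, 4, 3))      # matching batch of updates
    idx = np.array([0, 1, 1, 3])      # AdvancedIncSubtensor1-style index into axis 0 of the core tensor

    # On the core tensor the op is x[idx] += y; with a batch dimension it becomes
    # x[:, idx] += y, i.e. a full slice per batch dim followed by the original index.
    np.add.at(batch_x, (slice(None), idx), batch_y)
    assert batch_x[0, 1].tolist() == [2.0, 2.0, 2.0]   # index 1 appears twice, so it accumulates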
+ x_batch_ndim = batch_x.type.ndim - x.type.ndim + y_batch_ndim = batch_y.type.ndim - y.type.ndim + + # Ensure x has at least as many batch dims as y + if y_batch_ndim > x_batch_ndim: + diff = y_batch_ndim - x_batch_ndim + new_dims = (["x"] * diff) + list(range(batch_x.type.ndim)) + batch_x = batch_x.dimshuffle(new_dims) + x_batch_ndim = y_batch_ndim + + # Ensure x is broadcasted to match y's batch shape + if y_batch_ndim > 0: + out_shape = [batch_y.shape[i] for i in range(y_batch_ndim)] + out_shape.extend(batch_x.shape[x_batch_ndim + i] for i in range(x.type.ndim)) + + batch_x = alloc(batch_x, *out_shape) + + empty_slices = (slice(None),) * x_batch_ndim + + # AdvancedIncSubtensor1 takes a single index tensor + new_idx_list = (*empty_slices, batch_idx.type) + + return AdvancedIncSubtensor( + new_idx_list, + inplace=op.inplace, + set_instead_of_inc=op.set_instead_of_inc, + ).make_node(batch_x, batch_y, batch_idx) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 33f0ed3a81..a1d6fd1fc6 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -438,6 +438,65 @@ def trunc(self): def astype(self, dtype): return pt.basic.cast(self, dtype) + def _getitem_with_newaxis(self, args): + """Handle newaxis (None) for both basic and advanced indexing. + + `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new + broadcastable dimension at this location". Since PyTensor adds + new broadcastable dimensions via the `DimShuffle` `Op`, the + following code uses said `Op` to add one of the new axes and + then uses recursion to apply any other indices and add any + remaining new axes. + """ + counter = 0 + pattern = [] + new_args = [] + for arg in args: + if arg is np.newaxis or arg is NoneConst: + pattern.append("x") + new_args.append(slice(None)) + else: + # Check for boolean index which consumes multiple dimensions + consumed_dims = 1 + try: + val = pt.subtensor.as_index_variable(arg) + if ( + hasattr(val, "type") + and isinstance(val.type, TensorType) + and val.type.dtype == "bool" + ): + consumed_dims = val.type.ndim + except Exception: + pass + + pattern.extend(range(counter, counter + consumed_dims)) + counter += consumed_dims + new_args.append(arg) + + pattern.extend(range(counter, self.ndim)) + + view = self.dimshuffle(pattern) + + # Check if we can return the view directly if all new_args are full slices + # We can't do arg == slice(None, None, None) as in + # Python 2.7, this call __lt__ if we have a slice + # with some symbolic variable. + full_slices = True + for arg in new_args: + if not ( + isinstance(arg, slice) + and (arg.start is None or arg.start is NoneConst) + and (arg.stop is None or arg.stop is NoneConst) + and (arg.step is None or arg.step is NoneConst) + ): + full_slices = False + break + + if full_slices: + return view + else: + return view.__getitem__(tuple(new_args)) + def __getitem__(self, args): def includes_bool(args_el): if isinstance(args_el, np.bool_ | bool) or ( @@ -541,44 +600,7 @@ def is_empty_array(val): # Handle newaxis (None) for both basic and advanced indexing if np.newaxis in args or NoneConst in args: - # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new - # broadcastable dimension at this location". Since PyTensor adds - # new broadcastable dimensions via the `DimShuffle` `Op`, the - # following code uses said `Op` to add one of the new axes and - # then uses recursion to apply any other indices and add any - # remaining new axes. 
- - counter = 0 - pattern = [] - new_args = [] - for arg in args: - if arg is np.newaxis or arg is NoneConst: - pattern.append("x") - new_args.append(slice(None, None, None)) - else: - pattern.append(counter) - counter += 1 - new_args.append(arg) - - pattern.extend(list(range(counter, self.ndim))) - - view = self.dimshuffle(pattern) - full_slices = True - for arg in new_args: - # We can't do arg == slice(None, None, None) as in - # Python 2.7, this call __lt__ if we have a slice - # with some symbolic variable. - if not ( - isinstance(arg, slice) - and (arg.start is None or arg.start is NoneConst) - and (arg.stop is None or arg.stop is NoneConst) - and (arg.step is None or arg.step is NoneConst) - ): - full_slices = False - if full_slices: - return view - else: - return view.__getitem__(tuple(new_args)) + return self._getitem_with_newaxis(args) elif advanced: return pt.subtensor.advanced_subtensor(self, *args) else: diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index d8dadf0009..fa5a73805b 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -11,11 +11,10 @@ import pytensor import pytensor.scalar as scal import pytensor.tensor.basic as ptb -from pytensor import function -from pytensor.compile import DeepCopyOp, shared +from pytensor import config, function, shared +from pytensor.compile import DeepCopyOp from pytensor.compile.io import In from pytensor.compile.mode import Mode -from pytensor.configdefaults import config from pytensor.gradient import grad from pytensor.graph import Constant from pytensor.graph.basic import equal_computations @@ -622,7 +621,7 @@ def test_slice_symbol(self): (3, DimShuffle, np.index_exp[..., [0, 2, 3]]), (1, DimShuffle, np.index_exp[np.newaxis, ...]), ( - 1, + 3, AdvancedSubtensor, np.index_exp[..., np.newaxis, [1, 2]], ), @@ -2946,8 +2945,8 @@ def test_index_vars_to_types(): with pytest.raises(AdvancedIndexingError): index_vars_to_types(x) - with pytest.raises(TypeError): - index_vars_to_types(1) + # Integers are now allowed + assert index_vars_to_types(1) == 1 res = index_vars_to_types(iscalar) assert isinstance(res, scal.ScalarType) @@ -3055,7 +3054,6 @@ def core_fn(x, start): (11, 7, 5, 3), (2,), False, - marks=pytest.mark.xfail(raises=NotImplementedError), ), ( (lambda x, idx: x[:, idx, idx, :]), @@ -3071,7 +3069,6 @@ def core_fn(x, start): (11, 7, 5, 3, 5), (2,), True, - marks=pytest.mark.xfail(raises=NotImplementedError), ), # Core x, batched idx ((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (7,), (11, 2), True), @@ -3084,7 +3081,6 @@ def core_fn(x, start): (11, 7, 5, 3), (11, 2), True, - marks=pytest.mark.xfail(raises=NotImplementedError), ), ], )
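Editor's closing note on the data layout this series converges on: `idx_list` stores the static indexing structure (concrete slices and type placeholders for tensor indices), while the node inputs carry only the symbolic index tensors, in order. A minimal, library-free sketch of how the two are zipped back together, mirroring the intent of the `reconstruct_indices` / `indices_from_subtensor` helpers; treating "anything that is not a concrete slice or int" as a placeholder is a simplification for the example:

    import numpy as np


    def rebuild_index_tuple(idx_list, tensor_inputs):
        """Zip the static idx_list back together with the flattened inputs.

        Concrete slices and ints are kept as-is; every other entry stands in
        for the next tensor input, consumed in order.
        """
        full, it = [], iter(tensor_inputs)
        for entry in idx_list:
            if isinstance(entry, (slice, int)):
                full.append(entry)
            else:
                full.append(next(it))
        return tuple(full)


    # Example: x[1:3, idx, :] is stored as (slice(1, 3), <placeholder>, slice(None))
    # plus a single tensor input for the placeholder.
    x = np.arange(60).reshape(4, 5, 3)
    idx = np.array([0, 2])
    idx_list = (slice(1, 3), "tensor-placeholder", slice(None))
    assert x[rebuild_index_tuple(idx_list, [idx])].shape == (2, 2, 3)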