diff --git a/pytensor/link/jax/dispatch/subtensor.py b/pytensor/link/jax/dispatch/subtensor.py index 1c659be29b..3658717e51 100644 --- a/pytensor/link/jax/dispatch/subtensor.py +++ b/pytensor/link/jax/dispatch/subtensor.py @@ -31,11 +31,18 @@ """ +@jax_funcify.register(AdvancedSubtensor1) +def jax_funcify_AdvancedSubtensor1(op, node, **kwargs): + def advanced_subtensor1(x, ilist): + return x[ilist] + + return advanced_subtensor1 + + @jax_funcify.register(Subtensor) @jax_funcify.register(AdvancedSubtensor) -@jax_funcify.register(AdvancedSubtensor1) def jax_funcify_Subtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) + idx_list = op.idx_list def subtensor(x, *ilists): indices = indices_from_subtensor(ilists, idx_list) @@ -47,10 +54,24 @@ def subtensor(x, *ilists): return subtensor -@jax_funcify.register(IncSubtensor) @jax_funcify.register(AdvancedIncSubtensor1) +def jax_funcify_AdvancedIncSubtensor1(op, node, **kwargs): + if getattr(op, "set_instead_of_inc", False): + + def jax_fn(x, y, ilist): + return x.at[ilist].set(y) + + else: + + def jax_fn(x, y, ilist): + return x.at[ilist].add(y) + + return jax_fn + + +@jax_funcify.register(IncSubtensor) def jax_funcify_IncSubtensor(op, node, **kwargs): - idx_list = getattr(op, "idx_list", None) + idx_list = op.idx_list if getattr(op, "set_instead_of_inc", False): @@ -77,6 +98,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): @jax_funcify.register(AdvancedIncSubtensor) def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs): + idx_list = op.idx_list + if getattr(op, "set_instead_of_inc", False): def jax_fn(x, indices, y): @@ -87,8 +110,11 @@ def jax_fn(x, indices, y): def jax_fn(x, indices, y): return x.at[indices].add(y) - def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn): - return jax_fn(x, ilist, y) + def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): + indices = indices_from_subtensor(ilist, idx_list) + if len(indices) == 1: + indices = indices[0] + return jax_fn(x, indices, y) return advancedincsubtensor diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 51787daf41..3d4bc1f185 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -20,7 +20,6 @@ ) from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy from pytensor.tensor import TensorType -from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -29,7 +28,7 @@ IncSubtensor, Subtensor, ) -from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType +from pytensor.tensor.type_other import MakeSlice def slice_new(self, start, stop, step): @@ -239,28 +238,32 @@ def {function_name}({", ".join(input_names)}): @register_funcify_and_cache_key(AdvancedIncSubtensor) def numba_funcify_AdvancedSubtensor(op, node, **kwargs): if isinstance(op, AdvancedSubtensor): - _x, _y, idxs = node.inputs[0], None, node.inputs[1:] + tensor_inputs = node.inputs[1:] else: - _x, _y, *idxs = node.inputs - - basic_idxs = [ - idx - for idx in idxs - if ( - isinstance(idx.type, NoneTypeT) - or (isinstance(idx.type, SliceType) and not is_full_slice(idx)) - ) - ] - adv_idxs = [ - { - "axis": i, - "dtype": idx.type.dtype, - "bcast": idx.type.broadcastable, - "ndim": idx.type.ndim, - } - for i, idx in enumerate(idxs) - if isinstance(idx.type, TensorType) - ] + tensor_inputs = node.inputs[2:] + + # Reconstruct indexing information from idx_list and tensor inputs + basic_idxs = [] + adv_idxs = [] + input_idx = 0 + + for i, entry in enumerate(op.idx_list): + if isinstance(entry, slice): + # Basic slice index + basic_idxs.append(entry) + elif isinstance(entry, Type): + # Advanced tensor index + if input_idx < len(tensor_inputs): + idx_input = tensor_inputs[input_idx] + adv_idxs.append( + { + "axis": i, + "dtype": idx_input.type.dtype, + "bcast": idx_input.type.broadcastable, + "ndim": idx_input.type.ndim, + } + ) + input_idx += 1 # Special implementation for consecutive integer vector indices if ( diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py index 26b4fd0f7f..9a5e4b2ce1 100644 --- a/pytensor/link/pytorch/dispatch/subtensor.py +++ b/pytensor/link/pytorch/dispatch/subtensor.py @@ -9,7 +9,7 @@ Subtensor, indices_from_subtensor, ) -from pytensor.tensor.type_other import MakeSlice, SliceType +from pytensor.tensor.type_other import MakeSlice def check_negative_steps(indices): @@ -63,7 +63,10 @@ def makeslice(start, stop, step): @pytorch_funcify.register(AdvancedSubtensor1) @pytorch_funcify.register(AdvancedSubtensor) def pytorch_funcify_AdvSubtensor(op, node, **kwargs): - def advsubtensor(x, *indices): + idx_list = op.idx_list + + def advsubtensor(x, *flattened_indices): + indices = indices_from_subtensor(flattened_indices, idx_list) check_negative_steps(indices) return x[indices] @@ -102,12 +105,14 @@ def inc_subtensor(x, y, *flattened_indices): @pytorch_funcify.register(AdvancedIncSubtensor) @pytorch_funcify.register(AdvancedIncSubtensor1) def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): + idx_list = op.idx_list inplace = op.inplace ignore_duplicates = getattr(op, "ignore_duplicates", False) if op.set_instead_of_inc: - def adv_set_subtensor(x, y, *indices): + def adv_set_subtensor(x, y, *flattened_indices): + indices = indices_from_subtensor(flattened_indices, idx_list) check_negative_steps(indices) if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(node, x, y, indices) @@ -120,7 +125,8 @@ def adv_set_subtensor(x, y, *indices): elif ignore_duplicates: - def adv_inc_subtensor_no_duplicates(x, y, *indices): + def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices): + indices = indices_from_subtensor(flattened_indices, idx_list) check_negative_steps(indices) if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(node, x, y, indices) @@ -132,13 +138,18 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices): return adv_inc_subtensor_no_duplicates else: - if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]): + # Check if we have slice indexing in idx_list + has_slice_indexing = ( + any(isinstance(entry, slice) for entry in idx_list) if idx_list else False + ) + if has_slice_indexing: raise NotImplementedError( "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch" ) - def adv_inc_subtensor(x, y, *indices): - # Not needed because slices aren't supported + def adv_inc_subtensor(x, y, *flattened_indices): + indices = indices_from_subtensor(flattened_indices, idx_list) + # Not needed because slices aren't supported in this path # check_negative_steps(indices) if not inplace: x = x.clone() diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 60ac79f149..75a0d4d855 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -1317,8 +1317,6 @@ def perform(self, node, inputs, outputs): z[0] = y def grad(self, inputs, gout): - from pytensor.sparse.math import sp_sum - (x, s) = inputs (gz,) = gout return [col_scale(gz, s), sp_sum(x * gz, axis=0)] @@ -1368,8 +1366,6 @@ def perform(self, node, inputs, outputs): z[0] = scipy.sparse.csc_matrix((y_data, indices, indptr), (M, N)) def grad(self, inputs, gout): - from pytensor.sparse.math import sp_sum - (x, s) = inputs (gz,) = gout return [row_scale(gz, s), sp_sum(x * gz, axis=1)] @@ -1435,6 +1431,126 @@ def row_scale(x, s): return col_scale(x.T, s).T +class SpSum(Op): + """ + + WARNING: judgement call... + We are not using the structured in the comparison or hashing + because it doesn't change the perform method therefore, we + *do* want Sums with different structured values to be merged + by the merge optimization and this requires them to compare equal. + """ + + __props__ = ("axis",) + + def __init__(self, axis=None, sparse_grad=True): + super().__init__() + self.axis = axis + self.structured = sparse_grad + if self.axis not in (None, 0, 1): + raise ValueError("Illegal value for self.axis.") + + def make_node(self, x): + x = as_sparse_variable(x) + assert x.format in ("csr", "csc") + + if self.axis is not None: + out_shape = (None,) + else: + out_shape = () + + z = TensorType(dtype=x.dtype, shape=out_shape)() + return Apply(self, [x], [z]) + + def perform(self, node, inputs, outputs): + (x,) = inputs + (z,) = outputs + if self.axis is None: + z[0] = np.asarray(x.sum()) + else: + z[0] = np.asarray(x.sum(self.axis)).ravel() + + def grad(self, inputs, gout): + (x,) = inputs + (gz,) = gout + if x.dtype not in continuous_dtypes: + return [x.zeros_like(dtype=config.floatX)] + if self.structured: + if self.axis is None: + r = gz * sp_ones_like(x) + elif self.axis == 0: + r = col_scale(sp_ones_like(x), gz) + elif self.axis == 1: + r = row_scale(sp_ones_like(x), gz) + else: + raise ValueError("Illegal value for self.axis.") + else: + o_format = x.format + x = dense_from_sparse(x) + if _is_sparse_variable(gz): + gz = dense_from_sparse(gz) + if self.axis is None: + r = ptb.second(x, gz) + else: + ones = ptb.ones_like(x) + if self.axis == 0: + r = specify_broadcastable(gz.dimshuffle("x", 0), 0) * ones + elif self.axis == 1: + r = specify_broadcastable(gz.dimshuffle(0, "x"), 1) * ones + else: + raise ValueError("Illegal value for self.axis.") + r = SparseFromDense(o_format)(r) + return [r] + + def infer_shape(self, fgraph, node, shapes): + r = None + if self.axis is None: + r = [()] + elif self.axis == 0: + r = [(shapes[0][1],)] + else: + r = [(shapes[0][0],)] + return r + + def __str__(self): + return f"{self.__class__.__name__}{{axis={self.axis}}}" + + +def sp_sum(x, axis=None, sparse_grad=False): + """ + Calculate the sum of a sparse matrix along the specified axis. + + It operates a reduction along the specified axis. When `axis` is `None`, + it is applied along all axes. + + Parameters + ---------- + x + Sparse matrix. + axis + Axis along which the sum is applied. Integer or `None`. + sparse_grad : bool + `True` to have a structured grad. + + Returns + ------- + object + The sum of `x` in a dense format. + + Notes + ----- + The grad implementation is controlled with the `sparse_grad` parameter. + `True` will provide a structured grad and `False` will provide a regular + grad. For both choices, the grad returns a sparse matrix having the same + format as `x`. + + This op does not return a sparse matrix, but a dense tensor matrix. + + """ + + return SpSum(axis, sparse_grad)(x) + + class Diag(Op): """Extract the diagonal of a square sparse matrix as a dense vector. diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index 972de80d89..3dffa12265 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -8,11 +8,11 @@ import pytensor.sparse.basic as psb import pytensor.tensor.basic as ptb import pytensor.tensor.math as ptm -from pytensor import config from pytensor.gradient import grad_not_implemented from pytensor.graph import Apply, Op from pytensor.link.c.op import COp -from pytensor.tensor import TensorType, Variable, specify_broadcastable, tensor +from pytensor.sparse.basic import sp_sum +from pytensor.tensor import TensorType, Variable, tensor from pytensor.tensor.type import complex_dtypes @@ -295,126 +295,6 @@ def conjugate(x): structured_conjugate = conjugate -class SpSum(Op): - """ - - WARNING: judgement call... - We are not using the structured in the comparison or hashing - because it doesn't change the perform method therefore, we - *do* want Sums with different structured values to be merged - by the merge optimization and this requires them to compare equal. - """ - - __props__ = ("axis",) - - def __init__(self, axis=None, sparse_grad=True): - super().__init__() - self.axis = axis - self.structured = sparse_grad - if self.axis not in (None, 0, 1): - raise ValueError("Illegal value for self.axis.") - - def make_node(self, x): - x = psb.as_sparse_variable(x) - assert x.format in ("csr", "csc") - - if self.axis is not None: - out_shape = (None,) - else: - out_shape = () - - z = TensorType(dtype=x.dtype, shape=out_shape)() - return Apply(self, [x], [z]) - - def perform(self, node, inputs, outputs): - (x,) = inputs - (z,) = outputs - if self.axis is None: - z[0] = np.asarray(x.sum()) - else: - z[0] = np.asarray(x.sum(self.axis)).ravel() - - def grad(self, inputs, gout): - (x,) = inputs - (gz,) = gout - if x.dtype not in psb.continuous_dtypes: - return [x.zeros_like(dtype=config.floatX)] - if self.structured: - if self.axis is None: - r = gz * psb.sp_ones_like(x) - elif self.axis == 0: - r = psb.col_scale(psb.sp_ones_like(x), gz) - elif self.axis == 1: - r = psb.row_scale(psb.sp_ones_like(x), gz) - else: - raise ValueError("Illegal value for self.axis.") - else: - o_format = x.format - x = psb.dense_from_sparse(x) - if psb._is_sparse_variable(gz): - gz = psb.dense_from_sparse(gz) - if self.axis is None: - r = ptb.second(x, gz) - else: - ones = ptb.ones_like(x) - if self.axis == 0: - r = specify_broadcastable(gz.dimshuffle("x", 0), 0) * ones - elif self.axis == 1: - r = specify_broadcastable(gz.dimshuffle(0, "x"), 1) * ones - else: - raise ValueError("Illegal value for self.axis.") - r = psb.SparseFromDense(o_format)(r) - return [r] - - def infer_shape(self, fgraph, node, shapes): - r = None - if self.axis is None: - r = [()] - elif self.axis == 0: - r = [(shapes[0][1],)] - else: - r = [(shapes[0][0],)] - return r - - def __str__(self): - return f"{self.__class__.__name__}{{axis={self.axis}}}" - - -def sp_sum(x, axis=None, sparse_grad=False): - """ - Calculate the sum of a sparse matrix along the specified axis. - - It operates a reduction along the specified axis. When `axis` is `None`, - it is applied along all axes. - - Parameters - ---------- - x - Sparse matrix. - axis - Axis along which the sum is applied. Integer or `None`. - sparse_grad : bool - `True` to have a structured grad. - - Returns - ------- - object - The sum of `x` in a dense format. - - Notes - ----- - The grad implementation is controlled with the `sparse_grad` parameter. - `True` will provide a structured grad and `False` will provide a regular - grad. For both choices, the grad returns a sparse matrix having the same - format as `x`. - - This op does not return a sparse matrix, but a dense tensor matrix. - - """ - - return SpSum(axis, sparse_grad)(x) - - class AddSS(Op): # add(sparse, sparse). # see the doc of add() for more detail. diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e789659474..9bb31482c4 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1818,6 +1818,33 @@ def do_constant_folding(self, fgraph, node): return True +@_vectorize_node.register(Alloc) +def vectorize_alloc(op: Alloc, node: Apply, batch_val, *batch_shapes): + # batch_shapes are usually not batched (they are scalars for the shape) + # batch_val is the value being allocated. + + # If shapes are batched, we fall back (complex case) + if any( + b_shp.type.ndim > shp.type.ndim + for b_shp, shp in zip(batch_shapes, node.inputs[1:], strict=True) + ): + return vectorize_node_fallback(op, node, batch_val, *batch_shapes) + + # If value is batched, we need to prepend batch dims to the output shape + val = node.inputs[0] + batch_ndim = batch_val.type.ndim - val.type.ndim + + if batch_ndim == 0: + return op.make_node(batch_val, *batch_shapes) + + # We need the size of the batch dimensions + # batch_val has shape (B1, B2, ..., val_dims...) + batch_dims = [batch_val.shape[i] for i in range(batch_ndim)] + + new_shapes = batch_dims + list(batch_shapes) + return op.make_node(batch_val, *new_shapes) + + alloc = Alloc() pprint.assign(alloc, printing.FunctionPrinter(["alloc"])) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index fbe97b9a68..b031d30ae6 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -14,6 +14,7 @@ in2out, node_rewriter, ) +from pytensor.graph.type import Type from pytensor.raise_op import Assert from pytensor.scalar import Add, ScalarConstant, ScalarType from pytensor.scalar import constant as scalar_constant @@ -212,6 +213,20 @@ def get_advsubtensor_axis(indices): return axis +def reconstruct_indices(idx_list, tensor_inputs): + """Reconstruct indices from idx_list and tensor inputs.""" + indices = [] + input_idx = 0 + for entry in idx_list: + if isinstance(entry, slice): + indices.append(entry) + elif isinstance(entry, Type): + if input_idx < len(tensor_inputs): + indices.append(tensor_inputs[input_idx]) + input_idx += 1 + return indices + + @register_specialize @node_rewriter([AdvancedSubtensor]) def local_replace_AdvancedSubtensor(fgraph, node): @@ -228,7 +243,10 @@ def local_replace_AdvancedSubtensor(fgraph, node): return indexed_var = node.inputs[0] - indices = node.inputs[1:] + tensor_inputs = node.inputs[1:] + + # Reconstruct indices from idx_list and tensor inputs + indices = reconstruct_indices(node.op.idx_list, tensor_inputs) axis = get_advsubtensor_axis(indices) @@ -255,7 +273,10 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): res = node.inputs[0] val = node.inputs[1] - indices = node.inputs[2:] + tensor_inputs = node.inputs[2:] + + # Reconstruct indices from idx_list and tensor inputs + indices = reconstruct_indices(node.op.idx_list, tensor_inputs) axis = get_advsubtensor_axis(indices) @@ -1090,6 +1111,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node): def local_inplace_AdvancedIncSubtensor(fgraph, node): if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace: new_op = type(node.op)( + node.op.idx_list, inplace=True, set_instead_of_inc=node.op.set_instead_of_inc, ignore_duplicates=node.op.ignore_duplicates, @@ -1354,6 +1376,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node): z_broad[k] and not same_shape(xi, y, dim_x=k, dim_y=k) and shape_of[y][k] != 1 + and shape_of[xi][k] == 1 ) ] @@ -1751,9 +1774,14 @@ def ravel_multidimensional_bool_idx(fgraph, node): x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape) """ if isinstance(node.op, AdvancedSubtensor): - x, *idxs = node.inputs + x = node.inputs[0] + tensor_inputs = node.inputs[1:] else: - x, y, *idxs = node.inputs + x, y = node.inputs[0], node.inputs[1] + tensor_inputs = node.inputs[2:] + + # Reconstruct indices from idx_list and tensor inputs + idxs = reconstruct_indices(node.op.idx_list, tensor_inputs) if any( ( @@ -1791,12 +1819,41 @@ def ravel_multidimensional_bool_idx(fgraph, node): new_idxs[bool_idx_pos] = raveled_bool_idx if isinstance(node.op, AdvancedSubtensor): - new_out = node.op(raveled_x, *new_idxs) + # Create new AdvancedSubtensor with updated idx_list + new_idx_list = list(node.op.idx_list) + new_tensor_inputs = list(tensor_inputs) + + # Update the idx_list and tensor_inputs for the raveled boolean index + input_idx = 0 + for i, entry in enumerate(node.op.idx_list): + if isinstance(entry, Type): + if input_idx == bool_idx_pos: + new_tensor_inputs[input_idx] = raveled_bool_idx + input_idx += 1 + + new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs) else: + # Create new AdvancedIncSubtensor with updated idx_list + new_idx_list = list(node.op.idx_list) + new_tensor_inputs = list(tensor_inputs) + + # Update the tensor_inputs for the raveled boolean index + input_idx = 0 + for i, entry in enumerate(node.op.idx_list): + if isinstance(entry, Type): + if input_idx == bool_idx_pos: + new_tensor_inputs[input_idx] = raveled_bool_idx + input_idx += 1 + # The dimensions of y that correspond to the boolean indices # must already be raveled in the original graph, so we don't need to do anything to it - new_out = node.op(raveled_x, y, *new_idxs) - # But we must reshape the output to math the original shape + new_out = AdvancedIncSubtensor( + new_idx_list, + inplace=node.op.inplace, + set_instead_of_inc=node.op.set_instead_of_inc, + ignore_duplicates=node.op.ignore_duplicates, + )(raveled_x, y, *new_tensor_inputs) + # But we must reshape the output to match the original shape new_out = new_out.reshape(x_shape) return [copy_stack_trace(node.outputs[0], new_out)] diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index d7fc1bedbc..4467300e66 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -1,3 +1,4 @@ +import copy import logging import sys import warnings @@ -63,7 +64,6 @@ from pytensor.tensor.type_other import ( MakeSlice, NoneConst, - NoneSliceConst, NoneTypeT, SliceConstant, SliceType, @@ -706,7 +706,7 @@ def helper(entry): return ret -def index_vars_to_types(entry, slice_ok=True): +def index_vars_to_types(entry, slice_ok=True, allow_advanced=False): r"""Change references to `Variable`s into references to `Type`s. The `Subtensor.idx_list` field is unique to each `Subtensor` instance. It @@ -717,12 +717,13 @@ def index_vars_to_types(entry, slice_ok=True): when would that happen? """ - if ( - isinstance(entry, np.ndarray | Variable) - and hasattr(entry, "dtype") - and entry.dtype == "bool" - ): - raise AdvancedIndexingError("Invalid index type or slice for Subtensor") + if not allow_advanced: + if ( + isinstance(entry, np.ndarray | Variable) + and hasattr(entry, "dtype") + and entry.dtype == "bool" + ): + raise AdvancedIndexingError("Invalid index type or slice for Subtensor") if isinstance(entry, Variable) and ( entry.type in invalid_scal_types or entry.type in invalid_tensor_types @@ -742,13 +743,37 @@ def index_vars_to_types(entry, slice_ok=True): return ps.get_scalar_type(entry.type.dtype) elif isinstance(entry, Type) and entry in tensor_types and all(entry.broadcastable): return ps.get_scalar_type(entry.dtype) + elif ( + allow_advanced + and isinstance(entry, Variable) + and isinstance(entry.type, TensorType) + ): + return entry.type + elif allow_advanced and isinstance(entry, TensorType): + return entry + elif ( + allow_advanced + and isinstance(entry, Variable) + and isinstance(entry.type, SliceType) + ): + return entry.type + elif allow_advanced and isinstance(entry, SliceType): + return entry + elif ( + allow_advanced + and isinstance(entry, Variable) + and isinstance(entry.type, SliceType) + ): + return entry.type + elif allow_advanced and isinstance(entry, SliceType): + return entry elif slice_ok and isinstance(entry, slice): a = entry.start b = entry.stop c = entry.step if a is not None: - slice_a = index_vars_to_types(a, False) + slice_a = index_vars_to_types(a, False, allow_advanced) else: slice_a = None @@ -756,18 +781,18 @@ def index_vars_to_types(entry, slice_ok=True): # The special "maxsize" case is probably not needed here, # as slices containing maxsize are not generated by # __getslice__ anymore. - slice_b = index_vars_to_types(b, False) + slice_b = index_vars_to_types(b, False, allow_advanced) else: slice_b = None if c is not None: - slice_c = index_vars_to_types(c, False) + slice_c = index_vars_to_types(c, False, allow_advanced) else: slice_c = None return slice(slice_a, slice_b, slice_c) elif isinstance(entry, int | np.integer): - raise TypeError() + return entry else: raise AdvancedIndexingError("Invalid index type or slice for Subtensor") @@ -1564,7 +1589,10 @@ def inc_subtensor( ilist = x.owner.inputs[1] if ignore_duplicates: the_op = AdvancedIncSubtensor( - inplace, set_instead_of_inc=set_instead_of_inc, ignore_duplicates=True + [ilist], + inplace, + set_instead_of_inc=set_instead_of_inc, + ignore_duplicates=True, ) else: the_op = AdvancedIncSubtensor1( @@ -1575,6 +1603,7 @@ def inc_subtensor( real_x = x.owner.inputs[0] ilist = x.owner.inputs[1:] the_op = AdvancedIncSubtensor( + x.owner.op.idx_list, inplace, set_instead_of_inc=set_instead_of_inc, ignore_duplicates=ignore_duplicates, @@ -2576,78 +2605,151 @@ def check_advanced_indexing_dimensions(input, idx_list): class AdvancedSubtensor(Op): """Implements NumPy's advanced indexing.""" - __props__ = () + __props__ = ("idx_list",) + + def __init__(self, idx_list): + """ + Initialize AdvancedSubtensor with index list. - def make_node(self, x, *indices): + Parameters + ---------- + idx_list : tuple + List of indices where slices are stored as-is, + and numerical indices are replaced by their types. + """ + self.idx_list = tuple( + index_vars_to_types(idx, allow_advanced=True) for idx in idx_list + ) + # Store expected number of tensor inputs for validation + self.expected_inputs_len = len( + get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)) + ) + + def __hash__(self): + msg = [] + for entry in self.idx_list: + if isinstance(entry, slice): + msg += [(entry.start, entry.stop, entry.step)] + else: + msg += [entry] + + idx_list = tuple(msg) + return hash((type(self), idx_list)) + + def make_node(self, x, *inputs): + """ + Parameters + ---------- + x + The tensor to take a subtensor of. + inputs + A list of pytensor Scalars and Tensors (numerical indices only). + + """ x = as_tensor_variable(x) - indices = tuple(map(as_index_variable, indices)) + processed_inputs = [] + for a in inputs: + if isinstance(a, Variable) and isinstance(a.type, SliceType): + processed_inputs.append(a) + else: + processed_inputs.append(as_tensor_variable(a)) + inputs = tuple(processed_inputs) + idx_list = list(self.idx_list) + if len(idx_list) > x.type.ndim: + raise IndexError("too many indices for array") + + # Validate input count matches expected from idx_list + if len(inputs) != self.expected_inputs_len: + raise ValueError( + f"Expected {self.expected_inputs_len} inputs but got {len(inputs)}" + ) + + # Build explicit_indices for shape inference explicit_indices = [] - new_axes = [] - for idx in indices: - if isinstance(idx.type, TensorType) and idx.dtype == "bool": - if idx.type.ndim == 0: - raise NotImplementedError( - "Indexing with scalar booleans not supported" - ) + input_idx = 0 - # Check static shape aligned - axis = len(explicit_indices) - len(new_axes) - indexed_shape = x.type.shape[axis : axis + idx.type.ndim] - for j, (indexed_length, indexer_length) in enumerate( - zip(indexed_shape, idx.type.shape) - ): - if ( - indexed_length is not None - and indexer_length is not None - and indexed_length != indexer_length - ): - raise IndexError( - f"boolean index did not match indexed tensor along axis {axis + j};" - f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}" + for i, entry in enumerate(idx_list): + if isinstance(entry, slice): + # Reconstruct slice with actual values from inputs + if entry.start is not None and isinstance(entry.start, Type): + start_val = inputs[input_idx] + input_idx += 1 + else: + start_val = entry.start + + if entry.stop is not None and isinstance(entry.stop, Type): + stop_val = inputs[input_idx] + input_idx += 1 + else: + stop_val = entry.stop + + if entry.step is not None and isinstance(entry.step, Type): + step_val = inputs[input_idx] + input_idx += 1 + else: + step_val = entry.step + + explicit_indices.append(slice(start_val, stop_val, step_val)) + elif isinstance(entry, Type): + # This is a numerical index + inp = inputs[input_idx] + input_idx += 1 + + # Handle boolean indices + if hasattr(inp, "dtype") and inp.dtype == "bool": + if inp.type.ndim == 0: + raise NotImplementedError( + "Indexing with scalar booleans not supported" ) - # Convert boolean indices to integer with nonzero, to reason about static shape next - if isinstance(idx, Constant): - nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()] + + # Check static shape aligned + axis = len(explicit_indices) + indexed_shape = x.type.shape[axis : axis + inp.type.ndim] + for j, (indexed_length, indexer_length) in enumerate( + zip(indexed_shape, inp.type.shape) + ): + if ( + indexed_length is not None + and indexer_length is not None + and indexed_length != indexer_length + ): + raise IndexError( + f"boolean index did not match indexed tensor along axis {axis + j};" + f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}" + ) + # Convert boolean indices to integer with nonzero + if isinstance(inp, Constant): + nonzero_indices = [ + tensor_constant(i) for i in inp.data.nonzero() + ] + else: + nonzero_indices = inp.nonzero() + explicit_indices.extend(nonzero_indices) else: - # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero - # and seeing that other integer indices cannot possible match it - nonzero_indices = idx.nonzero() - explicit_indices.extend(nonzero_indices) + # Regular numerical index + explicit_indices.append(inp) else: - if isinstance(idx.type, NoneTypeT): - new_axes.append(len(explicit_indices)) - explicit_indices.append(idx) + raise ValueError(f"Invalid entry in idx_list: {entry}") - if (len(explicit_indices) - len(new_axes)) > x.type.ndim: + if len(explicit_indices) > x.type.ndim: raise IndexError( - f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed" + f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices)} were indexed" ) - # Perform basic and advanced indexing shape inference separately + # Perform basic and advanced indexing shape inference separately (no newaxis) basic_group_shape = [] advanced_indices = [] adv_group_axis = None last_adv_group_axis = None - expanded_x_shape = tuple( - np.insert(np.array(x.type.shape, dtype=object), 1, new_axes) - ) for i, (idx, dim_length) in enumerate( - zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst) + zip_longest(explicit_indices, x.type.shape, fillvalue=slice(None)) ): - if isinstance(idx.type, NoneTypeT): - basic_group_shape.append(1) # New-axis - elif isinstance(idx.type, SliceType): - if isinstance(idx, Constant): - basic_group_shape.append(slice_static_length(idx.data, dim_length)) - elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice): - basic_group_shape.append( - slice_static_length(slice(*idx.owner.inputs), dim_length) - ) - else: - # Symbolic root slice (owner is None), or slice operation we don't understand - basic_group_shape.append(None) - else: # TensorType + if isinstance(idx, slice): + basic_group_shape.append(slice_static_length(idx, dim_length)) + elif isinstance(idx, Variable) and isinstance(idx.type, SliceType): + basic_group_shape.append(None) + else: # TensorType (advanced index) # Keep track of advanced group axis if adv_group_axis is None: # First time we see an advanced index @@ -2682,7 +2784,7 @@ def make_node(self, x, *indices): return Apply( self, - [x, *indices], + [x, *inputs], [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))], ) @@ -2698,19 +2800,61 @@ def is_bool_index(idx): or getattr(idx, "dtype", None) == "bool" ) - indices = node.inputs[1:] + # Reconstruct the full indices from idx_list and inputs (newaxis handled by __getitem__) + inputs = node.inputs[1:] + + full_indices = [] + input_idx = 0 + + for entry in self.idx_list: + if isinstance(entry, slice): + # Reconstruct slice from idx_list and inputs + if entry.start is not None and isinstance(entry.start, Type): + start_val = inputs[input_idx] + input_idx += 1 + else: + start_val = entry.start + + if entry.stop is not None and isinstance(entry.stop, Type): + stop_val = inputs[input_idx] + input_idx += 1 + else: + stop_val = entry.stop + + if entry.step is not None and isinstance(entry.step, Type): + step_val = inputs[input_idx] + input_idx += 1 + else: + step_val = entry.step + + full_indices.append(slice(start_val, stop_val, step_val)) + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(inputs): + full_indices.append(inputs[input_idx]) + input_idx += 1 + else: + raise ValueError("Mismatch between idx_list and inputs") + index_shapes = [] - for idx, ishape in zip(indices, ishapes[1:], strict=True): - # Mixed bool indexes are converted to nonzero entries - shape0_op = Shape_i(0) - if is_bool_index(idx): - index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx)) - # The `ishapes` entries for `SliceType`s will be None, and - # we need to give `indexed_result_shape` the actual slices. - elif isinstance(getattr(idx, "type", None), SliceType): + for idx in full_indices: + if isinstance(idx, slice): index_shapes.append(idx) + elif isinstance(idx, Variable) and isinstance(idx.type, SliceType): + index_shapes.append(idx) + elif hasattr(idx, "type"): + # Mixed bool indexes are converted to nonzero entries + shape0_op = Shape_i(0) + if is_bool_index(idx): + index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx)) + else: + # Get ishape for this input + input_shape_idx = ( + inputs.index(idx) + 1 + ) # +1 because ishapes[0] is x + index_shapes.append(ishapes[input_shape_idx]) else: - index_shapes.append(ishape) + index_shapes.append(idx) res_shape = list( indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True) @@ -2721,7 +2865,7 @@ def is_bool_index(idx): # We must compute the Op to find its shape res_shape[i] = Shape_i(i)(node.out) - adv_indices = [idx for idx in indices if not is_basic_idx(idx)] + adv_indices = [idx for idx in full_indices if not is_basic_idx(idx)] bool_indices = [idx for idx in adv_indices if is_bool_index(idx)] # Special logic when the only advanced index group is of bool type. @@ -2732,7 +2876,7 @@ def is_bool_index(idx): # Because there are no more advanced index groups, there is exactly # one output dim per index variable up to the bool group. # Note: Scalar integer indexing counts as advanced indexing. - start_dim = indices.index(bool_index) + start_dim = full_indices.index(bool_index) res_shape[start_dim] = bool_index.sum() assert node.outputs[0].ndim == len(res_shape) @@ -2740,14 +2884,75 @@ def is_bool_index(idx): def perform(self, node, inputs, out_): (out,) = out_ - check_advanced_indexing_dimensions(inputs[0], inputs[1:]) - rval = inputs[0].__getitem__(tuple(inputs[1:])) + + # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__) + x = inputs[0] + tensor_inputs = inputs[1:] + + full_indices = [] + input_idx = 0 + + for entry in self.idx_list: + if isinstance(entry, slice): + # Reconstruct slice from idx_list and inputs + if entry.start is not None and isinstance(entry.start, Type): + start_val = tensor_inputs[input_idx] + input_idx += 1 + else: + start_val = entry.start + + if entry.stop is not None and isinstance(entry.stop, Type): + stop_val = tensor_inputs[input_idx] + input_idx += 1 + else: + stop_val = entry.stop + + if entry.step is not None and isinstance(entry.step, Type): + step_val = tensor_inputs[input_idx] + input_idx += 1 + else: + step_val = entry.step + + full_indices.append(slice(start_val, stop_val, step_val)) + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(tensor_inputs): + full_indices.append(tensor_inputs[input_idx]) + input_idx += 1 + else: + raise ValueError("Mismatch between idx_list and inputs") + + check_advanced_indexing_dimensions(x, full_indices) + + # Handle runtime broadcasting for broadcastable dimensions + broadcastable = node.inputs[0].type.broadcastable + new_full_indices = [] + for i, idx in enumerate(full_indices): + if i < len(broadcastable) and broadcastable[i] and x.shape[i] == 1: + if isinstance(idx, np.ndarray | list | tuple): + # Replace with zeros of same shape to preserve output shape + if isinstance(idx, np.ndarray): + new_full_indices.append(np.zeros_like(idx)) + else: + arr = np.array(idx) + new_full_indices.append(np.zeros_like(arr)) + elif isinstance(idx, int | np.integer): + new_full_indices.append(0) + else: + # Slice or other + new_full_indices.append(idx) + else: + new_full_indices.append(idx) + + rval = x.__getitem__(tuple(new_full_indices)) # When there are no arrays, we are not actually doing advanced # indexing, so __getitem__ will not return a copy. # Since no view_map is set, we need to copy the returned value - if not any( - isinstance(v.type, TensorType) and v.ndim > 0 for v in node.inputs[1:] - ): + has_tensor_indices = any( + isinstance(entry, Type) and not getattr(entry, "broadcastable", (False,))[0] + for entry in self.idx_list + ) + if not has_tensor_indices: rval = rval.copy() out[0] = rval @@ -2785,7 +2990,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: This function checks if the advanced indexing is non-consecutive, in which case the advanced index dimensions are placed on the left of the - output array, regardless of their opriginal position. + output array, regardless of their original position. See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing @@ -2800,11 +3005,27 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: bool True if the advanced indexing is non-consecutive, False otherwise. """ - _, *idxs = node.inputs - return _non_consecutive_adv_indexing(idxs) + # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__) + op = node.op + tensor_inputs = node.inputs[1:] + + full_indices = [] + input_idx = 0 + for entry in op.idx_list: + if isinstance(entry, slice): + full_indices.append(slice(None)) # Represent as basic slice + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(tensor_inputs): + full_indices.append(tensor_inputs[input_idx]) + input_idx += 1 -advanced_subtensor = AdvancedSubtensor() + return _non_consecutive_adv_indexing(full_indices) + + +# Note: This is now a factory function since AdvancedSubtensor needs idx_list +# The old global instance approach won't work anymore @_vectorize_node.register(AdvancedSubtensor) @@ -2824,36 +3045,68 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): # which would put the indexed results to the left of the batch dimensions! # TODO: Not all cases must be handled by Blockwise, but the logic is complex - # Blockwise doesn't accept None or Slices types so we raise informative error here - # TODO: Implement these internally, so Blockwise is always a safe fallback - if any(not isinstance(idx, TensorVariable) for idx in idxs): - raise NotImplementedError( - "Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing " - "and slices or newaxis is currently not supported." - ) - else: - return vectorize_node_fallback(op, node, batch_x, *batch_idxs) + # With the new interface, all inputs are tensors, so Blockwise can handle them + return vectorize_node_fallback(op, node, batch_x, *batch_idxs) # Otherwise we just need to add None slices for every new batch dim x_batch_ndim = batch_x.type.ndim - x.type.ndim empty_slices = (slice(None),) * x_batch_ndim - return op.make_node(batch_x, *empty_slices, *batch_idxs) + new_idx_list = empty_slices + op.idx_list + return AdvancedSubtensor(new_idx_list).make_node(batch_x, *batch_idxs) class AdvancedIncSubtensor(Op): """Increments a subtensor using advanced indexing.""" - __props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates") + __props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates", "idx_list") def __init__( - self, inplace=False, set_instead_of_inc=False, ignore_duplicates=False + self, + idx_list=None, + inplace=False, + set_instead_of_inc=False, + ignore_duplicates=False, ): + if idx_list is not None: + self.idx_list = tuple( + index_vars_to_types(idx, allow_advanced=True) for idx in idx_list + ) + # Store expected number of tensor inputs for validation + self.expected_inputs_len = len( + get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type)) + ) + else: + self.idx_list = None + self.expected_inputs_len = None + self.set_instead_of_inc = set_instead_of_inc self.inplace = inplace if inplace: self.destroy_map = {0: [0]} self.ignore_duplicates = ignore_duplicates + def __hash__(self): + if self.idx_list is None: + idx_list = None + else: + msg = [] + for entry in self.idx_list: + if isinstance(entry, slice): + msg += [(entry.start, entry.stop, entry.step)] + else: + msg += [entry] + idx_list = tuple(msg) + + return hash( + ( + type(self), + idx_list, + self.inplace, + self.set_instead_of_inc, + self.ignore_duplicates, + ) + ) + def __str__(self): return ( "AdvancedSetSubtensor" @@ -2865,6 +3118,22 @@ def make_node(self, x, y, *inputs): x = as_tensor_variable(x) y = as_tensor_variable(y) + if self.idx_list is None: + # Infer idx_list from inputs + # This handles the case where AdvancedIncSubtensor is initialized without idx_list + # and used as a factory. + idx_list = [inp.type for inp in inputs] + new_op = copy.copy(self) + new_op.idx_list = tuple(idx_list) + new_op.expected_inputs_len = len(inputs) + return new_op.make_node(x, y, *inputs) + + # Validate that we have the right number of tensor inputs for our idx_list + if len(inputs) != self.expected_inputs_len: + raise ValueError( + f"Expected {self.expected_inputs_len} tensor inputs but got {len(inputs)}" + ) + new_inputs = [] for inp in inputs: if isinstance(inp, list | tuple): @@ -2877,9 +3146,43 @@ def make_node(self, x, y, *inputs): ) def perform(self, node, inputs, out_): - x, y, *indices = inputs + x, y, *tensor_inputs = inputs - check_advanced_indexing_dimensions(x, indices) + # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__) + full_indices = [] + input_idx = 0 + + for entry in self.idx_list: + if isinstance(entry, slice): + # Reconstruct slice from idx_list and inputs + if entry.start is not None and isinstance(entry.start, Type): + start_val = tensor_inputs[input_idx] + input_idx += 1 + else: + start_val = entry.start + + if entry.stop is not None and isinstance(entry.stop, Type): + stop_val = tensor_inputs[input_idx] + input_idx += 1 + else: + stop_val = entry.stop + + if entry.step is not None and isinstance(entry.step, Type): + step_val = tensor_inputs[input_idx] + input_idx += 1 + else: + step_val = entry.step + + full_indices.append(slice(start_val, stop_val, step_val)) + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(tensor_inputs): + full_indices.append(tensor_inputs[input_idx]) + input_idx += 1 + else: + raise ValueError("Mismatch between idx_list and inputs") + + check_advanced_indexing_dimensions(x, full_indices) (out,) = out_ if not self.inplace: @@ -2888,11 +3191,11 @@ def perform(self, node, inputs, out_): out[0] = x if self.set_instead_of_inc: - out[0][tuple(indices)] = y + out[0][tuple(full_indices)] = y elif self.ignore_duplicates: - out[0][tuple(indices)] += y + out[0][tuple(full_indices)] += y else: - np.add.at(out[0], tuple(indices), y) + np.add.at(out[0], tuple(full_indices), y) def infer_shape(self, fgraph, node, ishapes): return [ishapes[0]] @@ -2922,10 +3225,14 @@ def grad(self, inpt, output_gradients): raise NotImplementedError("No support for complex grad yet") else: if self.set_instead_of_inc: - gx = advanced_set_subtensor(outgrad, y.zeros_like(), *idxs) + gx = ( + AdvancedIncSubtensor(self.idx_list, set_instead_of_inc=True) + .make_node(outgrad, y.zeros_like(), *idxs) + .outputs[0] + ) else: gx = outgrad - gy = advanced_subtensor(outgrad, *idxs) + gy = AdvancedSubtensor(self.idx_list).make_node(outgrad, *idxs).outputs[0] # Make sure to sum gy over the dimensions of y that have been # added or broadcasted gy = _sum_grad_over_bcasted_dims(y, gy) @@ -2945,7 +3252,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: This function checks if the advanced indexing is non-consecutive, in which case the advanced index dimensions are placed on the left of the - output array, regardless of their opriginal position. + output array, regardless of their original position. See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing @@ -2960,16 +3267,153 @@ def non_consecutive_adv_indexing(node: Apply) -> bool: bool True if the advanced indexing is non-consecutive, False otherwise. """ - _, _, *idxs = node.inputs - return _non_consecutive_adv_indexing(idxs) + # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__) + op = node.op + tensor_inputs = node.inputs[2:] # Skip x and y + full_indices = [] + input_idx = 0 + + for entry in op.idx_list: + if isinstance(entry, slice): + full_indices.append(slice(None)) # Represent as basic slice + elif isinstance(entry, Type): + # This is a numerical index - get from inputs + if input_idx < len(tensor_inputs): + full_indices.append(tensor_inputs[input_idx]) + input_idx += 1 + + return _non_consecutive_adv_indexing(full_indices) -advanced_inc_subtensor = AdvancedIncSubtensor() -advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True) -advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True) -advanced_set_subtensor_nodup = AdvancedIncSubtensor( - set_instead_of_inc=True, ignore_duplicates=True -) + +def advanced_subtensor(x, *args): + """Create an AdvancedSubtensor operation. + + This function converts the arguments to work with the new AdvancedSubtensor + interface that separates slice structure from variable inputs. + + Note: newaxis (None) should be handled by __getitem__ using dimshuffle + before calling this function. + """ + # Convert args using as_index_variable (like original AdvancedSubtensor did) + processed_args = tuple(map(as_index_variable, args)) + + # Now create idx_list and extract inputs + idx_list = [] + input_vars = [] + + for arg in processed_args: + if isinstance(arg.type, SliceType): + # Handle SliceType - extract components and structure + if isinstance(arg, Constant): + # Constant slice + idx_list.append(arg.data) + elif arg.owner and isinstance(arg.owner.op, MakeSlice): + # Variable slice - extract components + start, stop, step = arg.owner.inputs + + # Convert components to types for idx_list + start_type = ( + index_vars_to_types(start, False) + if not isinstance(start.type, NoneTypeT) + else None + ) + stop_type = ( + index_vars_to_types(stop, False) + if not isinstance(stop.type, NoneTypeT) + else None + ) + step_type = ( + index_vars_to_types(step, False) + if not isinstance(step.type, NoneTypeT) + else None + ) + + idx_list.append(slice(start_type, stop_type, step_type)) + + # Add variable components to inputs + if not isinstance(start.type, NoneTypeT): + input_vars.append(start) + if not isinstance(stop.type, NoneTypeT): + input_vars.append(stop) + if not isinstance(step.type, NoneTypeT): + input_vars.append(step) + else: + # Generic SliceType variable + idx_list.append(arg.type) + input_vars.append(arg) + else: + # Tensor index (should not be NoneType since newaxis handled in __getitem__) + idx_list.append(index_vars_to_types(arg, allow_advanced=True)) + input_vars.append(arg) + + return AdvancedSubtensor(idx_list)(x, *input_vars) + + +def advanced_inc_subtensor(x, y, *args, **kwargs): + """Create an AdvancedIncSubtensor operation for incrementing. + + Note: newaxis (None) should be handled by __getitem__ using dimshuffle + before calling this function. + """ + # Convert args using as_index_variable (like original AdvancedIncSubtensor would) + processed_args = tuple(map(as_index_variable, args)) + + # Now create idx_list and extract inputs + idx_list = [] + input_vars = [] + + for arg in processed_args: + if isinstance(arg.type, SliceType): + # Handle SliceType - extract components and structure + if isinstance(arg, Constant): + # Constant slice + idx_list.append(arg.data) + elif arg.owner and isinstance(arg.owner.op, MakeSlice): + # Variable slice - extract components + start, stop, step = arg.owner.inputs + + # Convert components to types for idx_list + start_type = ( + index_vars_to_types(start, False) + if not isinstance(start.type, NoneTypeT) + else None + ) + stop_type = ( + index_vars_to_types(stop, False) + if not isinstance(stop.type, NoneTypeT) + else None + ) + step_type = ( + index_vars_to_types(step, False) + if not isinstance(step.type, NoneTypeT) + else None + ) + + idx_list.append(slice(start_type, stop_type, step_type)) + + # Add variable components to inputs + if not isinstance(start.type, NoneTypeT): + input_vars.append(start) + if not isinstance(stop.type, NoneTypeT): + input_vars.append(stop) + if not isinstance(step.type, NoneTypeT): + input_vars.append(step) + else: + # Generic SliceType variable + idx_list.append(arg.type) + input_vars.append(arg) + else: + # Tensor index (should not be NoneType since newaxis handled in __getitem__) + idx_list.append(index_vars_to_types(arg, allow_advanced=True)) + input_vars.append(arg) + + return AdvancedIncSubtensor(idx_list, **kwargs)(x, y, *input_vars) + + +def advanced_set_subtensor(x, y, *args, **kwargs): + """Create an AdvancedIncSubtensor operation for setting.""" + return advanced_inc_subtensor(x, y, *args, set_instead_of_inc=True, **kwargs) def take(a, indices, axis=None, mode="raise"): @@ -3169,3 +3613,108 @@ def flip( "slice_at_axis", "take", ] + + +@_vectorize_node.register(AdvancedIncSubtensor) +def vectorize_advanced_inc_subtensor(op: AdvancedIncSubtensor, node, *batch_inputs): + x, y, *idxs = node.inputs + batch_x, batch_y, *batch_idxs = batch_inputs + + x_is_batched = x.type.ndim < batch_x.type.ndim + idxs_are_batched = any( + batch_idx.type.ndim > idx.type.ndim + for batch_idx, idx in zip(batch_idxs, idxs, strict=True) + if isinstance(batch_idx, TensorVariable) + ) + + if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)): + # Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing + # which would put the indexed results to the left of the batch dimensions! + return vectorize_node_fallback(op, node, batch_x, batch_y, *batch_idxs) + # If y is batched more than x, we need to broadcast x to match y's batch dims + x_batch_ndim = batch_x.type.ndim - x.type.ndim + y_batch_ndim = batch_y.type.ndim - y.type.ndim + + # Ensure x has at least as many batch dims as y + if y_batch_ndim > x_batch_ndim: + diff = y_batch_ndim - x_batch_ndim + new_dims = (["x"] * diff) + list(range(batch_x.type.ndim)) + batch_x = batch_x.dimshuffle(new_dims) + x_batch_ndim = y_batch_ndim + + # Ensure x is broadcasted to match y's batch shape + # We use Alloc to broadcast batch_x to the required shape + if y_batch_ndim > 0: + # Optimization: check if broadcasting is needed + # This is hard to do symbolically without adding nodes. + # But we can check broadcastable flags. + + # Let's just use Alloc to be safe. + # batch_x might have shape (1, 1, 458). y has (1, 1000, ...). + # We want (1, 1000, 458). + # We can use alloc(batch_x, y_batch_shape[0], y_batch_shape[1], ..., *x.shape) + + # We need to unpack y_batch_shape. + # Since we don't know y_batch_ndim statically (it's int), we can't unpack easily in python arg list if it was variable. + # But y_batch_ndim is computed from types, so it is known at graph construction time. + + # Actually, we can use pt.broadcast_to if available, or just alloc. + # alloc takes *shape. + + # Let's collect shape tensors. + out_shape = [batch_y.shape[i] for i in range(y_batch_ndim)] + out_shape.extend(batch_x.shape[x_batch_ndim + i] for i in range(x.type.ndim)) + + batch_x = alloc(batch_x, *out_shape) + + # Otherwise we just need to add None slices for every new batch dim + empty_slices = (slice(None),) * x_batch_ndim + new_idx_list = empty_slices + op.idx_list + return AdvancedIncSubtensor( + new_idx_list, + inplace=op.inplace, + set_instead_of_inc=op.set_instead_of_inc, + ignore_duplicates=op.ignore_duplicates, + ).make_node(batch_x, batch_y, *batch_idxs) + + +@_vectorize_node.register(AdvancedIncSubtensor1) +def vectorize_advanced_inc_subtensor1(op: AdvancedIncSubtensor1, node, *batch_inputs): + x, y, idx = node.inputs + batch_x, batch_y, batch_idx = batch_inputs + + # x_is_batched = x.type.ndim < batch_x.type.ndim + idx_is_batched = idx.type.ndim < batch_idx.type.ndim + + if idx_is_batched: + return vectorize_node_fallback(op, node, batch_x, batch_y, batch_idx) + + # AdvancedIncSubtensor1 only supports indexing the first dimension. + # If x is batched, we can use AdvancedIncSubtensor which supports indexing any dimension. + x_batch_ndim = batch_x.type.ndim - x.type.ndim + y_batch_ndim = batch_y.type.ndim - y.type.ndim + + # Ensure x has at least as many batch dims as y + if y_batch_ndim > x_batch_ndim: + diff = y_batch_ndim - x_batch_ndim + new_dims = (["x"] * diff) + list(range(batch_x.type.ndim)) + batch_x = batch_x.dimshuffle(new_dims) + x_batch_ndim = y_batch_ndim + + # Ensure x is broadcasted to match y's batch shape + if y_batch_ndim > 0: + out_shape = [batch_y.shape[i] for i in range(y_batch_ndim)] + out_shape.extend(batch_x.shape[x_batch_ndim + i] for i in range(x.type.ndim)) + + batch_x = alloc(batch_x, *out_shape) + + empty_slices = (slice(None),) * x_batch_ndim + + # AdvancedIncSubtensor1 takes a single index tensor + new_idx_list = (*empty_slices, batch_idx.type) + + return AdvancedIncSubtensor( + new_idx_list, + inplace=op.inplace, + set_instead_of_inc=op.set_instead_of_inc, + ).make_node(batch_x, batch_y, batch_idx) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 31e08fd39b..a1d6fd1fc6 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -438,6 +438,65 @@ def trunc(self): def astype(self, dtype): return pt.basic.cast(self, dtype) + def _getitem_with_newaxis(self, args): + """Handle newaxis (None) for both basic and advanced indexing. + + `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new + broadcastable dimension at this location". Since PyTensor adds + new broadcastable dimensions via the `DimShuffle` `Op`, the + following code uses said `Op` to add one of the new axes and + then uses recursion to apply any other indices and add any + remaining new axes. + """ + counter = 0 + pattern = [] + new_args = [] + for arg in args: + if arg is np.newaxis or arg is NoneConst: + pattern.append("x") + new_args.append(slice(None)) + else: + # Check for boolean index which consumes multiple dimensions + consumed_dims = 1 + try: + val = pt.subtensor.as_index_variable(arg) + if ( + hasattr(val, "type") + and isinstance(val.type, TensorType) + and val.type.dtype == "bool" + ): + consumed_dims = val.type.ndim + except Exception: + pass + + pattern.extend(range(counter, counter + consumed_dims)) + counter += consumed_dims + new_args.append(arg) + + pattern.extend(range(counter, self.ndim)) + + view = self.dimshuffle(pattern) + + # Check if we can return the view directly if all new_args are full slices + # We can't do arg == slice(None, None, None) as in + # Python 2.7, this call __lt__ if we have a slice + # with some symbolic variable. + full_slices = True + for arg in new_args: + if not ( + isinstance(arg, slice) + and (arg.start is None or arg.start is NoneConst) + and (arg.stop is None or arg.stop is NoneConst) + and (arg.step is None or arg.step is NoneConst) + ): + full_slices = False + break + + if full_slices: + return view + else: + return view.__getitem__(tuple(new_args)) + def __getitem__(self, args): def includes_bool(args_el): if isinstance(args_el, np.bool_ | bool) or ( @@ -539,55 +598,18 @@ def is_empty_array(val): else: advanced = True - if advanced: + # Handle newaxis (None) for both basic and advanced indexing + if np.newaxis in args or NoneConst in args: + return self._getitem_with_newaxis(args) + elif advanced: return pt.subtensor.advanced_subtensor(self, *args) else: - if np.newaxis in args or NoneConst in args: - # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new - # broadcastable dimension at this location". Since PyTensor adds - # new broadcastable dimensions via the `DimShuffle` `Op`, the - # following code uses said `Op` to add one of the new axes and - # then uses recursion to apply any other indices and add any - # remaining new axes. - - counter = 0 - pattern = [] - new_args = [] - for arg in args: - if arg is np.newaxis or arg is NoneConst: - pattern.append("x") - new_args.append(slice(None, None, None)) - else: - pattern.append(counter) - counter += 1 - new_args.append(arg) - - pattern.extend(list(range(counter, self.ndim))) - - view = self.dimshuffle(pattern) - full_slices = True - for arg in new_args: - # We can't do arg == slice(None, None, None) as in - # Python 2.7, this call __lt__ if we have a slice - # with some symbolic variable. - if not ( - isinstance(arg, slice) - and (arg.start is None or arg.start is NoneConst) - and (arg.stop is None or arg.stop is NoneConst) - and (arg.step is None or arg.step is NoneConst) - ): - full_slices = False - if full_slices: - return view - else: - return view.__getitem__(tuple(new_args)) - else: - return pt.subtensor.Subtensor(args)( - self, - *pt.subtensor.get_slice_elements( - args, lambda entry: isinstance(entry, Variable) - ), - ) + return pt.subtensor.Subtensor(args)( + self, + *pt.subtensor.get_slice_elements( + args, lambda entry: isinstance(entry, Variable) + ), + ) def __setitem__(self, key, value): raise TypeError( diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index d8dadf0009..fa5a73805b 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -11,11 +11,10 @@ import pytensor import pytensor.scalar as scal import pytensor.tensor.basic as ptb -from pytensor import function -from pytensor.compile import DeepCopyOp, shared +from pytensor import config, function, shared +from pytensor.compile import DeepCopyOp from pytensor.compile.io import In from pytensor.compile.mode import Mode -from pytensor.configdefaults import config from pytensor.gradient import grad from pytensor.graph import Constant from pytensor.graph.basic import equal_computations @@ -622,7 +621,7 @@ def test_slice_symbol(self): (3, DimShuffle, np.index_exp[..., [0, 2, 3]]), (1, DimShuffle, np.index_exp[np.newaxis, ...]), ( - 1, + 3, AdvancedSubtensor, np.index_exp[..., np.newaxis, [1, 2]], ), @@ -2946,8 +2945,8 @@ def test_index_vars_to_types(): with pytest.raises(AdvancedIndexingError): index_vars_to_types(x) - with pytest.raises(TypeError): - index_vars_to_types(1) + # Integers are now allowed + assert index_vars_to_types(1) == 1 res = index_vars_to_types(iscalar) assert isinstance(res, scal.ScalarType) @@ -3055,7 +3054,6 @@ def core_fn(x, start): (11, 7, 5, 3), (2,), False, - marks=pytest.mark.xfail(raises=NotImplementedError), ), ( (lambda x, idx: x[:, idx, idx, :]), @@ -3071,7 +3069,6 @@ def core_fn(x, start): (11, 7, 5, 3, 5), (2,), True, - marks=pytest.mark.xfail(raises=NotImplementedError), ), # Core x, batched idx ((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (7,), (11, 2), True), @@ -3084,7 +3081,6 @@ def core_fn(x, start): (11, 7, 5, 3), (11, 2), True, - marks=pytest.mark.xfail(raises=NotImplementedError), ), ], )