254 changes: 249 additions & 5 deletions pytensor/tensor/subtensor.py
@@ -2738,10 +2738,42 @@ def is_bool_index(idx):
assert node.outputs[0].ndim == len(res_shape)
return [res_shape]

def _broadcast_indices(self, x, indices):
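"""Remap integer array indices that fall on length-1 dimensions of ``x``.

When the dimension being indexed has runtime length 1 (i.e. it is
broadcastable), every index that was valid for the broadcast shape maps
to position 0, so the index array is replaced by zeros of the same
shape. Slices, ``None`` and boolean masks are passed through unchanged.
"""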
new_indices = []
x_dim = 0
for idx in indices:
if idx is None:
new_indices.append(idx)
continue
if isinstance(idx, slice):
x_dim += 1
new_indices.append(idx)
continue

# Boolean mask: consumes idx.ndim dimensions of x
if hasattr(idx, "dtype") and (idx.dtype == bool or idx.dtype == np.bool_):
x_dim += idx.ndim
new_indices.append(idx)
continue

# Integer array index
if x_dim < x.ndim and x.shape[x_dim] == 1:
# Length-1 (broadcastable) dimension: after broadcasting the only
# valid position is 0, so replace the index array with zeros
new_indices.append(np.zeros_like(idx))
else:
new_indices.append(idx)
x_dim += 1
return tuple(new_indices)

def perform(self, node, inputs, out_):
(out,) = out_
check_advanced_indexing_dimensions(inputs[0], inputs[1:])
rval = inputs[0].__getitem__(tuple(inputs[1:]))
x = inputs[0]
indices = inputs[1:]

indices = self._broadcast_indices(x, indices)

check_advanced_indexing_dimensions(x, indices)
rval = x.__getitem__(tuple(indices))
# When there are no arrays, we are not actually doing advanced
# indexing, so __getitem__ will not return a copy.
# Since no view_map is set, we need to copy the returned value
@@ -2807,6 +2839,97 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
advanced_subtensor = AdvancedSubtensor()


class BatchedSliceType(Type):
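"""Opaque type for slices whose start/stop/step carry batch dimensions.

It only exists so a ``BatchedSlice`` output can flow through the graph
during vectorization; ``filter`` performs no validation.
"""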
def filter(self, x, strict=False, allow_downcast=None):
return x

def __str__(self):
return "BatchedSliceType"


batched_slice_type = BatchedSliceType()


class BatchedSlice(Op):
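"""Placeholder op grouping the (possibly batched) start/stop/step of a slice.

It is never executed: the ``AdvancedIncSubtensor`` vectorization below
unpacks its inputs again, and ``perform`` raises ``NotImplementedError``.
"""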
__props__ = ()

def make_node(self, start, stop, step):
return Apply(self, [start, stop, step], [batched_slice_type()])

def perform(self, node, inp, out_):
raise NotImplementedError("BatchedSlice is a placeholder")


@_vectorize_node.register(MakeSlice)
def vectorize_make_slice(op, node, *batched_inputs):
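"""Vectorize ``MakeSlice``: if any component gained batch dimensions,
emit a ``BatchedSlice`` node instead of a regular slice so downstream
vectorization can handle the batched bounds explicitly."""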
is_batched = False
for orig, batched in zip(node.inputs, batched_inputs):
if hasattr(batched.type, "ndim") and hasattr(orig.type, "ndim"):
if batched.type.ndim > orig.type.ndim:
is_batched = True
break

if is_batched:
return BatchedSlice().make_node(*batched_inputs)
return op.make_node(*batched_inputs)


class AdvancedIncSubtensorExplicit(Op):
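"""Set or increment ``x[indices]`` with indices passed as explicit tensor inputs.

``structure`` records, per index, whether it is a ``"slice"`` or a
``"tensor"``. Each ``"slice"`` entry consumes six flat inputs
``(start, start_is_none, stop, stop_is_none, step, step_is_none)``,
while each ``"tensor"`` entry consumes one. This flat encoding makes
the op usable as the core op of a ``Blockwise`` when slice bounds are
themselves batched.
"""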
__props__ = ("structure", "set_instead_of_inc", "inplace", "ignore_duplicates")

def __init__(
self,
structure,
set_instead_of_inc=False,
inplace=False,
ignore_duplicates=False,
):
self.structure = structure
self.set_instead_of_inc = set_instead_of_inc
self.inplace = inplace
self.ignore_duplicates = ignore_duplicates

def make_node(self, x, y, *inputs):
return Apply(self, [x, y, *inputs], [x.type()])

def perform(self, node, inputs, out_):
x, y, *flat_indices = inputs

indices = []
idx_ptr = 0
for kind in self.structure:
if kind == "slice":
(
start_val,
start_none,
stop_val,
stop_none,
step_val,
step_none,
) = flat_indices[idx_ptr : idx_ptr + 6]
start = None if start_none else start_val
stop = None if stop_none else stop_val
step = None if step_none else step_val
indices.append(slice(start, stop, step))
idx_ptr += 6
else:
indices.append(flat_indices[idx_ptr])
idx_ptr += 1

(out,) = out_
if not self.inplace:
out[0] = x.copy()
else:
out[0] = x

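# NumPy detail: with fancy indexing, ``out[0][indices] += y`` applies at
# most one update per duplicated index, while ``np.add.at`` accumulates
# every occurrence.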
if self.set_instead_of_inc:
out[0][tuple(indices)] = y
elif self.ignore_duplicates:
out[0][tuple(indices)] += y
else:
np.add.at(out[0], tuple(indices), y)


@_vectorize_node.register(AdvancedSubtensor)
def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
x, *idxs = node.inputs
@@ -2967,9 +3090,130 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
advanced_inc_subtensor = AdvancedIncSubtensor()
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True)
advanced_set_subtensor_nodup = AdvancedIncSubtensor(
set_instead_of_inc=True, ignore_duplicates=True
)


@_vectorize_node.register(AdvancedIncSubtensor)
def advanced_inc_subtensor_vectorize_node(op, node, *batched_inputs):
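"""Vectorize ``AdvancedIncSubtensor`` over batched ``x``, ``y`` and indices.

Three cases are handled: batched slice bounds are dispatched to a
``Blockwise`` of ``AdvancedIncSubtensorExplicit``; batched ``x``/``y``
with unbatched indices simply prepend full slices for the batch
dimensions; batched integer indices get orthogonal ``arange`` indices
prepended so each batch entry only updates its own slice of ``x``.
"""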
x = batched_inputs[0]
y = batched_inputs[1]
indices = batched_inputs[2:]

# Check if we have batched slices (BatchedSliceType)
has_batched_slices = any(isinstance(idx.type, BatchedSliceType) for idx in indices)

if has_batched_slices:
from pytensor.tensor.blockwise import Blockwise, safe_signature

structure = []
flat_inputs = []

# Build the flat input list for AdvancedIncSubtensorExplicit.
# x and y are passed through as-is; Blockwise handles their batch dimensions.

for idx in indices:
if isinstance(idx.type, BatchedSliceType):
structure.append("slice")
# Unwrap the components of the BatchedSlice node that produced idx
bs_node = idx.owner
assert isinstance(bs_node.op, BatchedSlice)

for comp in bs_node.inputs:
if isinstance(comp.type, NoneTypeT):
# Pass dummy and True flag
flat_inputs.append(tensor_constant(0, dtype="int8"))
flat_inputs.append(tensor_constant(True, dtype="bool"))
else:
# Pass component and False flag
flat_inputs.append(comp)
flat_inputs.append(tensor_constant(False, dtype="bool"))
else:
structure.append("tensor")
flat_inputs.append(idx)

core_op = AdvancedIncSubtensorExplicit(
structure=tuple(structure),
set_instead_of_inc=op.set_instead_of_inc,
inplace=op.inplace,
ignore_duplicates=op.ignore_duplicates,
)

# Build the Blockwise core signature: the output shares x's core ndim,
# y keeps its own core ndim, slice components are scalars, and tensor
# indices keep the ndim of the original (unbatched) index.

x_core_ndim = node.inputs[0].ndim
y_core_ndim = node.inputs[1].ndim

input_core_ndims = [x_core_ndim, y_core_ndim]

# Core ndims of the index inputs
for i, idx in enumerate(indices):
if isinstance(idx.type, BatchedSliceType):
# 6 components, all scalar (0-d) for slice parameters
input_core_ndims.extend([0] * 6)
else:
# Tensor index
# Core ndim is the ndim of the original index
input_core_ndims.append(node.inputs[2 + i].ndim)

output_core_ndims = [node.outputs[0].ndim]
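# e.g. a matrix x, a vector y, one batched slice (6 scalar inputs) and one
# vector index give core ndims [2, 1, 0, 0, 0, 0, 0, 0, 1] -> [2]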

signature = safe_signature(input_core_ndims, output_core_ndims)

return Blockwise(core_op, signature=signature).make_node(x, y, *flat_inputs)

x_orig = node.inputs[0]
x_batch_ndim = x.ndim - x_orig.ndim

y_orig = node.inputs[1]
y_batch_ndim = y.ndim - y_orig.ndim

batch_ndim = max(x_batch_ndim, y_batch_ndim)

if batch_ndim == 0:
return op.make_node(x, y, *indices)

if x_batch_ndim < batch_ndim:
# Broadcast x to match batch dimensions
# We assume the batch dimensions are the first batch_ndim dimensions
# and that y has them (since y_batch_ndim >= batch_ndim)
from pytensor.tensor import concatenate

batch_shape = y.shape[:batch_ndim]
full_shape = concatenate([batch_shape, x.shape])
x = alloc(x, *full_shape)

# Check if any index is batched
any_batched_index = False
for i, idx in enumerate(indices):
orig_idx = node.inputs[2 + i]
if (
hasattr(idx, "ndim")
and hasattr(orig_idx, "ndim")
and idx.ndim > orig_idx.ndim
):
any_batched_index = True
break

if not any_batched_index:
# Simple case: prepend slice(None) for each batch dim
sl = make_slice(None, None, None)
new_indices = [sl] * batch_ndim + list(indices)
return op.make_node(x, y, *new_indices)

from pytensor.tensor import arange

batch_indices = []
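# Build one orthogonal index per batch dimension: arange(x.shape[d]) is
# reshaped so it broadcasts only along dimension d (np.ix_-style), making
# each batch entry update its own slice of x independently.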
for d in range(batch_ndim):
dim_len = x.shape[d]
idx = arange(dim_len)
pattern = ["x"] * batch_ndim
pattern[d] = 0
idx = idx.dimshuffle(pattern)
batch_indices.append(idx)

new_indices = batch_indices + list(indices)
return op.make_node(x, y, *new_indices)


def take(a, indices, axis=None, mode="raise"):
72 changes: 68 additions & 4 deletions tests/tensor/test_subtensor.py
@@ -17,7 +17,7 @@
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph import Constant
from pytensor.graph import Constant, vectorize_graph
from pytensor.graph.basic import equal_computations
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph
@@ -3047,15 +3047,12 @@ def core_fn(x, start):
(2,),
False,
),
# (this is currently failing because PyTensor tries to vectorize the slice(None) operation,
# due to the exact same None constant being used there and in the np.newaxis)
pytest.param(
(lambda x, idx: x[:, idx, None]),
"(7,5,3),(2)->(7,2,1,3)",
(11, 7, 5, 3),
(2,),
False,
marks=pytest.mark.xfail(raises=NotImplementedError),
),
(
(lambda x, idx: x[:, idx, idx, :]),
@@ -3218,3 +3215,70 @@ def test_advanced_incsubtensor1(self, func, static_shape, gc, benchmark):
)
fn.vm.allow_gc = gc
benchmark(fn, x_values)


class TestAdvancedIncSubtensorVectorization:
def test_vectorize_advanced_inc_subtensor_slice(self):
# Regression test for vectorization of AdvancedIncSubtensor with slice inputs
x = matrix("x")
y = vector("y")
idx = as_tensor([0, 1])

# x[0:2, idx] = y
# This uses AdvancedIncSubtensor because a vector index is mixed with a slice
z = set_subtensor(x[0:2, idx], y)

# Vectorize over a batch dimension
# batched_x: (B, N, M)
# batched_y: (B, 2)
batched_x = tensor3("bx")
batched_y = matrix("by")

out_batched = vectorize_graph(z, replace={x: batched_x, y: batched_y})

f = function([batched_x, batched_y], out_batched)

bx_val = np.zeros((2, 5, 5), dtype=config.floatX)
by_val = np.ones((2, 2), dtype=config.floatX)

res = f(bx_val, by_val)

# Verify result
# For each batch b:
# res[b, 0:2, [0, 1]] should be 1
assert np.all(res[:, 0:2, [0, 1]] == 1)
assert np.all(res[:, 2:, :] == 0)

def test_vectorize_advanced_inc_subtensor_batched_slice(self):
# Regression test for vectorization of AdvancedIncSubtensor with batched slice parameters
x = matrix("x")
s = lscalar("s")
# x[s:, [0, 0]] = 0
out = set_subtensor(x[s:, [0, 0]], 0)

# Vectorize s -> z (vector)
z = lvector("z")

out_batched = vectorize_graph(out, replace={s: z})

f = function([x, z], out_batched)

x_val = np.arange(12).reshape((4, 3)).astype(config.floatX)
z_val = np.array([1, 2], dtype="int64")

# For z=1: x[1:, [0,0]] = 0. Rows 1,2,3. Cols 0.
# For z=2: x[2:, [0,0]] = 0. Rows 2,3. Cols 0.

res = f(x_val, z_val)

# res shape: (2, 4, 3)
assert res.shape == (2, 4, 3)

expected_0 = x_val.copy()
expected_0[1:, [0, 0]] = 0

expected_1 = x_val.copy()
expected_1[2:, [0, 0]] = 0

assert np.allclose(res[0], expected_0)
assert np.allclose(res[1], expected_1)