Commit 34cc8eb

Implement AdvancedIncSubtensorExplicit
1 parent a5fb911 commit 34cc8eb
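
In brief, this change teaches `vectorize_graph` to handle `AdvancedIncSubtensor` nodes whose slice bounds become batched, lowering them to a `Blockwise` of the new `AdvancedIncSubtensorExplicit` op. A condensed sketch of the use case exercised by the new tests below (variable names and values here are illustrative, not part of the commit):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph import vectorize_graph

    x = pt.matrix("x")
    s = pt.lscalar("s")
    out = pt.set_subtensor(x[s:, [0, 0]], 0)        # AdvancedIncSubtensor with a symbolic slice start

    z = pt.lvector("z")                             # batch of slice starts
    batched = vectorize_graph(out, replace={s: z})  # one result per entry of z

    f = pytensor.function([x, z], batched)
    res = f(np.arange(12.0).reshape(4, 3).astype(x.dtype), np.array([1, 2], dtype="int64"))
    print(res.shape)                                # (2, 4, 3)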

2 files changed: +317 −9 lines changed

pytensor/tensor/subtensor.py

Lines changed: 249 additions & 5 deletions
@@ -2738,10 +2738,42 @@ def is_bool_index(idx):
         assert node.outputs[0].ndim == len(res_shape)
         return [res_shape]
 
+    def _broadcast_indices(self, x, indices):
+        new_indices = []
+        x_dim = 0
+        for idx in indices:
+            if idx is None:
+                new_indices.append(idx)
+                continue
+            if isinstance(idx, slice):
+                x_dim += 1
+                new_indices.append(idx)
+                continue
+
+            # Check for boolean
+            if hasattr(idx, "dtype") and (idx.dtype == bool or idx.dtype == np.bool_):
+                x_dim += idx.ndim
+                new_indices.append(idx)
+                continue
+
+            # Integer array
+            if x_dim < x.ndim and x.shape[x_dim] == 1:
+                # Broadcast: replace with zeros
+                new_indices.append(np.zeros_like(idx))
+            else:
+                new_indices.append(idx)
+            x_dim += 1
+        return tuple(new_indices)
+
     def perform(self, node, inputs, out_):
         (out,) = out_
-        check_advanced_indexing_dimensions(inputs[0], inputs[1:])
-        rval = inputs[0].__getitem__(tuple(inputs[1:]))
+        x = inputs[0]
+        indices = inputs[1:]
+
+        indices = self._broadcast_indices(x, indices)
+
+        check_advanced_indexing_dimensions(x, indices)
+        rval = x.__getitem__(tuple(indices))
         # When there are no arrays, we are not actually doing advanced
         # indexing, so __getitem__ will not return a copy.
         # Since no view_map is set, we need to copy the returned value
@@ -2807,6 +2839,97 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
 advanced_subtensor = AdvancedSubtensor()
 
 
+class BatchedSliceType(Type):
+    def filter(self, x, strict=False, allow_downcast=None):
+        return x
+
+    def __str__(self):
+        return "BatchedSliceType"
+
+
+batched_slice_type = BatchedSliceType()
+
+
+class BatchedSlice(Op):
+    __props__ = ()
+
+    def make_node(self, start, stop, step):
+        return Apply(self, [start, stop, step], [batched_slice_type()])
+
+    def perform(self, node, inp, out_):
+        raise NotImplementedError("BatchedSlice is a placeholder")
+
+
+@_vectorize_node.register(MakeSlice)
+def vectorize_make_slice(op, node, *batched_inputs):
+    is_batched = False
+    for orig, batched in zip(node.inputs, batched_inputs):
+        if hasattr(batched.type, "ndim") and hasattr(orig.type, "ndim"):
+            if batched.type.ndim > orig.type.ndim:
+                is_batched = True
+                break
+
+    if is_batched:
+        return BatchedSlice().make_node(*batched_inputs)
+    return op.make_node(*batched_inputs)
+
+
+class AdvancedIncSubtensorExplicit(Op):
+    __props__ = ("structure", "set_instead_of_inc", "inplace", "ignore_duplicates")
+
+    def __init__(
+        self,
+        structure,
+        set_instead_of_inc=False,
+        inplace=False,
+        ignore_duplicates=False,
+    ):
+        self.structure = structure
+        self.set_instead_of_inc = set_instead_of_inc
+        self.inplace = inplace
+        self.ignore_duplicates = ignore_duplicates
+
+    def make_node(self, x, y, *inputs):
+        return Apply(self, [x, y, *inputs], [x.type()])
+
+    def perform(self, node, inputs, out_):
+        x, y, *flat_indices = inputs
+
+        indices = []
+        idx_ptr = 0
+        for kind in self.structure:
+            if kind == "slice":
+                (
+                    start_val,
+                    start_none,
+                    stop_val,
+                    stop_none,
+                    step_val,
+                    step_none,
+                ) = flat_indices[idx_ptr : idx_ptr + 6]
+                start = None if start_none else start_val
+                stop = None if stop_none else stop_val
+                step = None if step_none else step_val
+                indices.append(slice(start, stop, step))
+                idx_ptr += 6
+            else:
+                indices.append(flat_indices[idx_ptr])
+                idx_ptr += 1
+
+        (out,) = out_
+        if not self.inplace:
+            out[0] = x.copy()
+        else:
+            out[0] = x
+
+        if self.set_instead_of_inc:
+            out[0][tuple(indices)] = y
+        elif self.ignore_duplicates:
+            out[0][tuple(indices)] += y
+        else:
+            np.add.at(out[0], tuple(indices), y)
+
+
 @_vectorize_node.register(AdvancedSubtensor)
 def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
     x, *idxs = node.inputs
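
For reference, a small standalone sketch of the flat encoding that `AdvancedIncSubtensorExplicit.perform` decodes: each slice in `structure` consumes six flat inputs (value/is-None pairs for start, stop, step) and each tensor index consumes one. The concrete values below are made up for illustration:

    import numpy as np

    structure = ("slice", "tensor")
    # slice(1, None, None) encoded as (start_val, start_none, stop_val, stop_none, step_val, step_none)
    flat_indices = [1, False, 0, True, 0, True, np.array([0, 0])]

    indices, ptr = [], 0
    for kind in structure:
        if kind == "slice":
            start_val, start_none, stop_val, stop_none, step_val, step_none = (
                flat_indices[ptr : ptr + 6]
            )
            indices.append(
                slice(
                    None if start_none else start_val,
                    None if stop_none else stop_val,
                    None if step_none else step_val,
                )
            )
            ptr += 6
        else:
            indices.append(flat_indices[ptr])
            ptr += 1

    x = np.zeros((4, 3))
    np.add.at(x, tuple(indices), 1.0)  # increments x[1:, 0] twice (duplicate index accumulated)
    print(indices)                     # [slice(1, None, None), array([0, 0])]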
@@ -2967,9 +3090,130 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
 advanced_inc_subtensor = AdvancedIncSubtensor()
 advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
 advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True)
-advanced_set_subtensor_nodup = AdvancedIncSubtensor(
-    set_instead_of_inc=True, ignore_duplicates=True
-)
+
+
+@_vectorize_node.register(AdvancedIncSubtensor)
+def advanced_inc_subtensor_vectorize_node(op, node, *batched_inputs):
+    x = batched_inputs[0]
+    y = batched_inputs[1]
+    indices = batched_inputs[2:]
+
+    # Check if we have batched slices (BatchedSliceType)
+    has_batched_slices = any(isinstance(idx.type, BatchedSliceType) for idx in indices)
+
+    if has_batched_slices:
+        from pytensor.tensor.blockwise import Blockwise, safe_signature
+
+        structure = []
+        flat_inputs = []
+
+        # We need to construct inputs for AdvancedIncSubtensorExplicit
+        # x and y are passed as is (Blockwise handles them)
+
+        for idx in indices:
+            if isinstance(idx.type, BatchedSliceType):
+                structure.append("slice")
+                # Unwrap BatchedSlice
+                # idx is a Variable output of BatchedSlice node
+                bs_node = idx.owner
+                assert isinstance(bs_node.op, BatchedSlice)
+
+                for comp in bs_node.inputs:
+                    if isinstance(comp.type, NoneTypeT):
+                        # Pass dummy and True flag
+                        flat_inputs.append(tensor_constant(0, dtype="int8"))
+                        flat_inputs.append(tensor_constant(True, dtype="bool"))
+                    else:
+                        # Pass component and False flag
+                        flat_inputs.append(comp)
+                        flat_inputs.append(tensor_constant(False, dtype="bool"))
+            else:
+                structure.append("tensor")
+                flat_inputs.append(idx)
+
+        core_op = AdvancedIncSubtensorExplicit(
+            structure=tuple(structure),
+            set_instead_of_inc=op.set_instead_of_inc,
+            inplace=op.inplace,
+            ignore_duplicates=op.ignore_duplicates,
+        )
+
+        # Signature
+        # x: (n, m, ...), y: (n, m, ...), indices... -> (n, m, ...)
+
+        x_core_ndim = node.inputs[0].ndim
+        y_core_ndim = node.inputs[1].ndim
+
+        input_core_ndims = [x_core_ndim, y_core_ndim]
+
+        # For indices
+        for i, idx in enumerate(indices):
+            if isinstance(idx.type, BatchedSliceType):
+                # 6 components, all scalar (0-d) for slice parameters
+                input_core_ndims.extend([0] * 6)
+            else:
+                # Tensor index
+                # Core ndim is the ndim of the original index
+                input_core_ndims.append(node.inputs[2 + i].ndim)
+
+        output_core_ndims = [node.outputs[0].ndim]
+
+        signature = safe_signature(input_core_ndims, output_core_ndims)
+
+        return Blockwise(core_op, signature=signature).make_node(x, y, *flat_inputs)
+
+    x_orig = node.inputs[0]
+    x_batch_ndim = x.ndim - x_orig.ndim
+
+    y_orig = node.inputs[1]
+    y_batch_ndim = y.ndim - y_orig.ndim
+
+    batch_ndim = max(x_batch_ndim, y_batch_ndim)
+
+    if batch_ndim == 0:
+        return op.make_node(x, y, *indices)
+
+    if x_batch_ndim < batch_ndim:
+        # Broadcast x to match batch dimensions
+        # We assume the batch dimensions are the first batch_ndim dimensions
+        # and that y has them (since y_batch_ndim >= batch_ndim)
+        from pytensor.tensor import concatenate
+
+        batch_shape = y.shape[:batch_ndim]
+        full_shape = concatenate([batch_shape, x.shape])
+        x = alloc(x, *full_shape)
+
+    # Check if any index is batched
+    any_batched_index = False
+    for i, idx in enumerate(indices):
+        orig_idx = node.inputs[2 + i]
+        if (
+            hasattr(idx, "ndim")
+            and hasattr(orig_idx, "ndim")
+            and idx.ndim > orig_idx.ndim
+        ):
+            any_batched_index = True
+            break
+
+    if not any_batched_index:
+        # Simple case: prepend slice(None) for each batch dim
+        sl = make_slice(None, None, None)
+        new_indices = [sl] * batch_ndim + list(indices)
+        return op.make_node(x, y, *new_indices)
+
+    from pytensor.tensor import arange
+
+    batch_indices = []
+    for d in range(batch_ndim):
+        dim_len = x.shape[d]
+        idx = arange(dim_len)
+        pattern = ["x"] * batch_ndim
+        pattern[d] = 0
+        idx = idx.dimshuffle(pattern)
+        batch_indices.append(idx)
+
+    new_indices = batch_indices + list(indices)
+    return op.make_node(x, y, *new_indices)
 
 
 def take(a, indices, axis=None, mode="raise"):
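
The final fallback in `advanced_inc_subtensor_vectorize_node` turns implicit batching into explicit integer indices: one `arange` per batch dimension, reshaped with `dimshuffle` so they broadcast against each other and against the batched index. A rough NumPy analogue of that idea for a single batch dimension (shapes and values are illustrative):

    import numpy as np

    B, N, M = 2, 4, 3
    x = np.zeros((B, N, M))
    row_idx = np.array([[0, 1], [1, 2]])   # a batched integer index, shape (B, 2)

    # Loop form: set the selected rows of each batch entry to 1
    expected = x.copy()
    for b in range(B):
        expected[b, row_idx[b]] = 1.0

    # Vectorized form: prepend an explicit arange index that broadcasts with row_idx
    batch_idx = np.arange(B)[:, None]      # shape (B, 1)
    out = x.copy()
    out[batch_idx, row_idx] = 1.0

    assert np.array_equal(out, expected)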

tests/tensor/test_subtensor.py

Lines changed: 68 additions & 4 deletions
@@ -17,7 +17,7 @@
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad
-from pytensor.graph import Constant
+from pytensor.graph import Constant, vectorize_graph
 from pytensor.graph.basic import equal_computations
 from pytensor.graph.op import get_test_value
 from pytensor.graph.rewriting.utils import is_same_graph
@@ -3047,15 +3047,12 @@ def core_fn(x, start):
             (2,),
             False,
         ),
-        # (this is currently failing because PyTensor tries to vectorize the slice(None) operation,
-        # due to the exact same None constant being used there and in the np.newaxis)
         pytest.param(
             (lambda x, idx: x[:, idx, None]),
             "(7,5,3),(2)->(7,2,1,3)",
             (11, 7, 5, 3),
             (2,),
             False,
-            marks=pytest.mark.xfail(raises=NotImplementedError),
         ),
         (
             (lambda x, idx: x[:, idx, idx, :]),
@@ -3218,3 +3215,70 @@ def test_advanced_incsubtensor1(self, func, static_shape, gc, benchmark):
         )
         fn.vm.allow_gc = gc
         benchmark(fn, x_values)
+
+
+class TestAdvancedIncSubtensorVectorization:
+    def test_vectorize_advanced_inc_subtensor_slice(self):
+        # Regression test for vectorization of AdvancedIncSubtensor with slice inputs
+        x = matrix("x")
+        y = vector("y")
+        idx = as_tensor([0, 1])
+
+        # x[0:2, idx] = y
+        # This uses AdvancedIncSubtensor because of the vector index mixed with slice
+        z = set_subtensor(x[0:2, idx], y)
+
+        # Vectorize over a batch dimension
+        # batched_x: (B, N, M)
+        # batched_y: (B, 2)
+        batched_x = tensor3("bx")
+        batched_y = matrix("by")
+
+        out_batched = vectorize_graph(z, replace={x: batched_x, y: batched_y})
+
+        f = function([batched_x, batched_y], out_batched)
+
+        bx_val = np.zeros((2, 5, 5), dtype=config.floatX)
+        by_val = np.ones((2, 2), dtype=config.floatX)
+
+        res = f(bx_val, by_val)
+
+        # Verify result
+        # For each batch b:
+        # res[b, 0:2, [0, 1]] should be 1
+        assert np.all(res[:, 0:2, [0, 1]] == 1)
+        assert np.all(res[:, 2:, :] == 0)
+
+    def test_vectorize_advanced_inc_subtensor_batched_slice(self):
+        # Regression test for vectorization of AdvancedIncSubtensor with batched slice parameters
+        x = matrix("x")
+        s = lscalar("s")
+        # x[s:, [0, 0]] = 0
+        out = set_subtensor(x[s:, [0, 0]], 0)
+
+        # Vectorize s -> z (vector)
+        z = lvector("z")
+
+        out_batched = vectorize_graph(out, replace={s: z})
+
+        f = function([x, z], out_batched)
+
+        x_val = np.arange(12).reshape((4, 3)).astype(config.floatX)
+        z_val = np.array([1, 2], dtype="int64")
+
+        # For z=1: x[1:, [0,0]] = 0. Rows 1,2,3. Cols 0.
+        # For z=2: x[2:, [0,0]] = 0. Rows 2,3. Cols 0.
+
+        res = f(x_val, z_val)
+
+        # res shape: (2, 4, 3)
+        assert res.shape == (2, 4, 3)
+
+        expected_0 = x_val.copy()
+        expected_0[1:, [0, 0]] = 0
+
+        expected_1 = x_val.copy()
+        expected_1[2:, [0, 0]] = 0
+
+        assert np.allclose(res[0], expected_0)
+        assert np.allclose(res[1], expected_1)
