Skip to content

Commit a3634dd

Browse files
Copilot and ricardoV94
authored and committed
Refactor newaxis handling: move to __getitem__ level, unify with Subtensor approach
Co-authored-by: ricardoV94 <28983449+ricardoV94@users.noreply.github.com>
1 parent 737b8cb commit a3634dd

File tree

2 files changed

+75
-93
lines changed

2 files changed

+75
-93
lines changed

pytensor/tensor/subtensor.py

Lines changed: 28 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2612,16 +2612,12 @@ def make_node(self, x, *inputs):
26122612
if len(inputs) != len(expected_inputs):
26132613
raise ValueError(f"Expected {len(expected_inputs)} inputs but got {len(inputs)}")
26142614

2615-
# Build explicit_indices for shape inference
2615+
# Build explicit_indices for shape inference (newaxis handled by __getitem__)
26162616
explicit_indices = []
2617-
new_axes = []
26182617
input_idx = 0
26192618

26202619
for i, entry in enumerate(idx_list):
2621-
if entry is np.newaxis:
2622-
new_axes.append(len(explicit_indices))
2623-
explicit_indices.append(np.newaxis)
2624-
elif isinstance(entry, slice):
2620+
if isinstance(entry, slice):
26252621
# Reconstruct slice with actual values from inputs
26262622
if entry.start is not None and isinstance(entry.start, Type):
26272623
start_val = inputs[input_idx]
@@ -2655,7 +2651,7 @@ def make_node(self, x, *inputs):
26552651
)
26562652

26572653
# Check static shape aligned
2658-
axis = len(explicit_indices) - len(new_axes)
2654+
axis = len(explicit_indices)
26592655
indexed_shape = x.type.shape[axis : axis + inp.type.ndim]
26602656
for j, (indexed_length, indexer_length) in enumerate(
26612657
zip(indexed_shape, inp.type.shape)
@@ -2681,25 +2677,20 @@ def make_node(self, x, *inputs):
26812677
else:
26822678
raise ValueError(f"Invalid entry in idx_list: {entry}")
26832679

2684-
if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
2680+
if len(explicit_indices) > x.type.ndim:
26852681
raise IndexError(
2686-
f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
2682+
f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices)} were indexed"
26872683
)
26882684

2689-
# Perform basic and advanced indexing shape inference separately
2685+
# Perform basic and advanced indexing shape inference separately (no newaxis)
26902686
basic_group_shape = []
26912687
advanced_indices = []
26922688
adv_group_axis = None
26932689
last_adv_group_axis = None
2694-
expanded_x_shape = tuple(
2695-
np.insert(np.array(x.type.shape, dtype=object), 1, new_axes)
2696-
)
26972690
for i, (idx, dim_length) in enumerate(
2698-
zip_longest(explicit_indices, expanded_x_shape, fillvalue=slice(None))
2691+
zip_longest(explicit_indices, x.type.shape, fillvalue=slice(None))
26992692
):
2700-
if idx is np.newaxis:
2701-
basic_group_shape.append(1) # New-axis
2702-
elif isinstance(idx, slice):
2693+
if isinstance(idx, slice):
27032694
basic_group_shape.append(slice_static_length(idx, dim_length))
27042695
else: # TensorType (advanced index)
27052696
# Keep track of advanced group axis
@@ -2752,16 +2743,14 @@ def is_bool_index(idx):
27522743
or getattr(idx, "dtype", None) == "bool"
27532744
)
27542745

2755-
# Reconstruct the full indices from idx_list and inputs (like perform method)
2746+
# Reconstruct the full indices from idx_list and inputs (newaxis handled by __getitem__)
27562747
inputs = node.inputs[1:]
27572748

27582749
full_indices = []
27592750
input_idx = 0
27602751

27612752
for entry in self.idx_list:
2762-
if entry is np.newaxis:
2763-
full_indices.append(np.newaxis)
2764-
elif isinstance(entry, slice):
2753+
if isinstance(entry, slice):
27652754
# Reconstruct slice from idx_list and inputs
27662755
if entry.start is not None and isinstance(entry.start, Type):
27672756
start_val = inputs[input_idx]
@@ -2794,8 +2783,6 @@ def is_bool_index(idx):
27942783
for idx in full_indices:
27952784
if isinstance(idx, slice):
27962785
index_shapes.append(idx)
2797-
elif idx is np.newaxis:
2798-
index_shapes.append(idx)
27992786
elif hasattr(idx, 'type'):
28002787
# Mixed bool indexes are converted to nonzero entries
28012788
shape0_op = Shape_i(0)
@@ -2837,17 +2824,15 @@ def is_bool_index(idx):
28372824
def perform(self, node, inputs, out_):
28382825
(out,) = out_
28392826

2840-
# Reconstruct the full tuple of indices from idx_list and inputs
2827+
# Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__)
28412828
x = inputs[0]
28422829
tensor_inputs = inputs[1:]
28432830

28442831
full_indices = []
28452832
input_idx = 0
28462833

28472834
for entry in self.idx_list:
2848-
if entry is np.newaxis:
2849-
full_indices.append(np.newaxis)
2850-
elif isinstance(entry, slice):
2835+
if isinstance(entry, slice):
28512836
# Reconstruct slice from idx_list and inputs
28522837
if entry.start is not None and isinstance(entry.start, Type):
28532838
start_val = tensor_inputs[input_idx]
@@ -2938,7 +2923,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
29382923
bool
29392924
True if the advanced indexing is non-consecutive, False otherwise.
29402925
"""
2941-
# Reconstruct the full indices from idx_list and inputs to check consecutivity
2926+
# Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__)
29422927
op = node.op
29432928
tensor_inputs = node.inputs[1:]
29442929

@@ -2948,8 +2933,6 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
29482933
for entry in op.idx_list:
29492934
if isinstance(entry, slice):
29502935
full_indices.append(slice(None)) # Represent as basic slice
2951-
elif entry is np.newaxis:
2952-
full_indices.append(np.newaxis)
29532936
elif isinstance(entry, Type):
29542937
# This is a numerical index - get from inputs
29552938
if input_idx < len(tensor_inputs):
@@ -3035,14 +3018,12 @@ def make_node(self, x, y, *inputs):
30353018
def perform(self, node, inputs, out_):
30363019
x, y, *tensor_inputs = inputs
30373020

3038-
# Reconstruct the full tuple of indices from idx_list and inputs
3021+
# Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__)
30393022
full_indices = []
30403023
input_idx = 0
30413024

30423025
for entry in self.idx_list:
3043-
if entry is np.newaxis:
3044-
full_indices.append(np.newaxis)
3045-
elif isinstance(entry, slice):
3026+
if isinstance(entry, slice):
30463027
# Reconstruct slice from idx_list and inputs
30473028
if entry.start is not None and isinstance(entry.start, Type):
30483029
start_val = tensor_inputs[input_idx]
@@ -3154,7 +3135,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
31543135
bool
31553136
True if the advanced indexing is non-consecutive, False otherwise.
31563137
"""
3157-
# Reconstruct the full indices from idx_list and inputs to check consecutivity
3138+
# Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__)
31583139
op = node.op
31593140
tensor_inputs = node.inputs[2:] # Skip x and y
31603141

@@ -3164,8 +3145,6 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
31643145
for entry in op.idx_list:
31653146
if isinstance(entry, slice):
31663147
full_indices.append(slice(None)) # Represent as basic slice
3167-
elif entry is np.newaxis:
3168-
full_indices.append(np.newaxis)
31693148
elif isinstance(entry, Type):
31703149
# This is a numerical index - get from inputs
31713150
if input_idx < len(tensor_inputs):
@@ -3180,6 +3159,9 @@ def advanced_subtensor(x, *args):
31803159
31813160
This function converts the arguments to work with the new AdvancedSubtensor
31823161
interface that separates slice structure from variable inputs.
3162+
3163+
Note: newaxis (None) should be handled by __getitem__ using dimshuffle
3164+
before calling this function.
31833165
"""
31843166
# Convert args using as_index_variable (like original AdvancedSubtensor did)
31853167
processed_args = tuple(map(as_index_variable, args))
@@ -3189,9 +3171,7 @@ def advanced_subtensor(x, *args):
31893171
input_vars = []
31903172

31913173
for arg in processed_args:
3192-
if isinstance(arg.type, NoneTypeT):
3193-
idx_list.append(np.newaxis)
3194-
elif isinstance(arg.type, SliceType):
3174+
if isinstance(arg.type, SliceType):
31953175
# Handle SliceType - extract components and structure
31963176
if isinstance(arg, Constant):
31973177
# Constant slice
@@ -3218,15 +3198,19 @@ def advanced_subtensor(x, *args):
32183198
# Other slice case
32193199
idx_list.append(slice(None))
32203200
else:
3221-
# Tensor index
3201+
# Tensor index (should not be NoneType since newaxis handled in __getitem__)
32223202
idx_list.append(index_vars_to_types(arg))
32233203
input_vars.append(arg)
32243204

32253205
return AdvancedSubtensor(idx_list).make_node(x, *input_vars).outputs[0]
32263206

32273207

32283208
def advanced_inc_subtensor(x, y, *args, **kwargs):
3229-
"""Create an AdvancedIncSubtensor operation for incrementing."""
3209+
"""Create an AdvancedIncSubtensor operation for incrementing.
3210+
3211+
Note: newaxis (None) should be handled by __getitem__ using dimshuffle
3212+
before calling this function.
3213+
"""
32303214
# Convert args using as_index_variable (like original AdvancedIncSubtensor would)
32313215
processed_args = tuple(map(as_index_variable, args))
32323216

@@ -3235,9 +3219,7 @@ def advanced_inc_subtensor(x, y, *args, **kwargs):
32353219
input_vars = []
32363220

32373221
for arg in processed_args:
3238-
if isinstance(arg.type, NoneTypeT):
3239-
idx_list.append(np.newaxis)
3240-
elif isinstance(arg.type, SliceType):
3222+
if isinstance(arg.type, SliceType):
32413223
# Handle SliceType - extract components and structure
32423224
if isinstance(arg, Constant):
32433225
# Constant slice
@@ -3264,7 +3246,7 @@ def advanced_inc_subtensor(x, y, *args, **kwargs):
32643246
# Other slice case
32653247
idx_list.append(slice(None))
32663248
else:
3267-
# Tensor index
3249+
# Tensor index (should not be NoneType since newaxis handled in __getitem__)
32683250
idx_list.append(index_vars_to_types(arg))
32693251
input_vars.append(arg)
32703252

pytensor/tensor/variable.py

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -539,55 +539,55 @@ def is_empty_array(val):
539539
else:
540540
advanced = True
541541

542-
if advanced:
543-
return pt.subtensor.advanced_subtensor(self, *args)
544-
else:
545-
if np.newaxis in args or NoneConst in args:
546-
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
547-
# broadcastable dimension at this location". Since PyTensor adds
548-
# new broadcastable dimensions via the `DimShuffle` `Op`, the
549-
# following code uses said `Op` to add one of the new axes and
550-
# then uses recursion to apply any other indices and add any
551-
# remaining new axes.
552-
553-
counter = 0
554-
pattern = []
555-
new_args = []
556-
for arg in args:
557-
if arg is np.newaxis or arg is NoneConst:
558-
pattern.append("x")
559-
new_args.append(slice(None, None, None))
560-
else:
561-
pattern.append(counter)
562-
counter += 1
563-
new_args.append(arg)
564-
565-
pattern.extend(list(range(counter, self.ndim)))
566-
567-
view = self.dimshuffle(pattern)
568-
full_slices = True
569-
for arg in new_args:
570-
# We can't do arg == slice(None, None, None) as in
571-
# Python 2.7, this call __lt__ if we have a slice
572-
# with some symbolic variable.
573-
if not (
574-
isinstance(arg, slice)
575-
and (arg.start is None or arg.start is NoneConst)
576-
and (arg.stop is None or arg.stop is NoneConst)
577-
and (arg.step is None or arg.step is NoneConst)
578-
):
579-
full_slices = False
580-
if full_slices:
581-
return view
542+
# Handle newaxis (None) for both basic and advanced indexing
543+
if np.newaxis in args or NoneConst in args:
544+
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
545+
# broadcastable dimension at this location". Since PyTensor adds
546+
# new broadcastable dimensions via the `DimShuffle` `Op`, the
547+
# following code uses said `Op` to add one of the new axes and
548+
# then uses recursion to apply any other indices and add any
549+
# remaining new axes.
550+
551+
counter = 0
552+
pattern = []
553+
new_args = []
554+
for arg in args:
555+
if arg is np.newaxis or arg is NoneConst:
556+
pattern.append("x")
557+
new_args.append(slice(None, None, None))
582558
else:
583-
return view.__getitem__(tuple(new_args))
559+
pattern.append(counter)
560+
counter += 1
561+
new_args.append(arg)
562+
563+
pattern.extend(list(range(counter, self.ndim)))
564+
565+
view = self.dimshuffle(pattern)
566+
full_slices = True
567+
for arg in new_args:
568+
# We can't do arg == slice(None, None, None) as in
569+
# Python 2.7, this call __lt__ if we have a slice
570+
# with some symbolic variable.
571+
if not (
572+
isinstance(arg, slice)
573+
and (arg.start is None or arg.start is NoneConst)
574+
and (arg.stop is None or arg.stop is NoneConst)
575+
and (arg.step is None or arg.step is NoneConst)
576+
):
577+
full_slices = False
578+
if full_slices:
579+
return view
584580
else:
585-
return pt.subtensor.Subtensor(args)(
586-
self,
587-
*pt.subtensor.get_slice_elements(
588-
args, lambda entry: isinstance(entry, Variable)
589-
),
590-
)
581+
return view.__getitem__(tuple(new_args))
582+
elif advanced:
583+
return pt.subtensor.advanced_subtensor(self, *args)
584+
else:
585+
return pt.subtensor.Subtensor(args)(
586+
self,
587+
*pt.subtensor.get_slice_elements(
588+
args, lambda entry: isinstance(entry, Variable)
589+
),
590+
)
591591

592592
def __setitem__(self, key, value):
593593
raise TypeError(

0 commit comments

Comments (0)