@@ -2612,16 +2612,12 @@ def make_node(self, x, *inputs):
26122612 if len (inputs ) != len (expected_inputs ):
26132613 raise ValueError (f"Expected { len (expected_inputs )} inputs but got { len (inputs )} " )
26142614
2615- # Build explicit_indices for shape inference
2615+ # Build explicit_indices for shape inference (newaxis handled by __getitem__)
26162616 explicit_indices = []
2617- new_axes = []
26182617 input_idx = 0
26192618
26202619 for i , entry in enumerate (idx_list ):
2621- if entry is np .newaxis :
2622- new_axes .append (len (explicit_indices ))
2623- explicit_indices .append (np .newaxis )
2624- elif isinstance (entry , slice ):
2620+ if isinstance (entry , slice ):
26252621 # Reconstruct slice with actual values from inputs
26262622 if entry .start is not None and isinstance (entry .start , Type ):
26272623 start_val = inputs [input_idx ]
@@ -2655,7 +2651,7 @@ def make_node(self, x, *inputs):
26552651 )
26562652
26572653 # Check static shape aligned
2658- axis = len (explicit_indices ) - len ( new_axes )
2654+ axis = len (explicit_indices )
26592655 indexed_shape = x .type .shape [axis : axis + inp .type .ndim ]
26602656 for j , (indexed_length , indexer_length ) in enumerate (
26612657 zip (indexed_shape , inp .type .shape )
@@ -2681,25 +2677,20 @@ def make_node(self, x, *inputs):
26812677 else :
26822678 raise ValueError (f"Invalid entry in idx_list: { entry } " )
26832679
2684- if ( len (explicit_indices ) - len ( new_axes ) ) > x .type .ndim :
2680+ if len (explicit_indices ) > x .type .ndim :
26852681 raise IndexError (
2686- f"too many indices for array: tensor is { x .type .ndim } -dimensional, but { len (explicit_indices ) - len ( new_axes ) } were indexed"
2682+ f"too many indices for array: tensor is { x .type .ndim } -dimensional, but { len (explicit_indices )} were indexed"
26872683 )
26882684
2689- # Perform basic and advanced indexing shape inference separately
2685+ # Perform basic and advanced indexing shape inference separately (no newaxis)
26902686 basic_group_shape = []
26912687 advanced_indices = []
26922688 adv_group_axis = None
26932689 last_adv_group_axis = None
2694- expanded_x_shape = tuple (
2695- np .insert (np .array (x .type .shape , dtype = object ), 1 , new_axes )
2696- )
26972690 for i , (idx , dim_length ) in enumerate (
2698- zip_longest (explicit_indices , expanded_x_shape , fillvalue = slice (None ))
2691+ zip_longest (explicit_indices , x . type . shape , fillvalue = slice (None ))
26992692 ):
2700- if idx is np .newaxis :
2701- basic_group_shape .append (1 ) # New-axis
2702- elif isinstance (idx , slice ):
2693+ if isinstance (idx , slice ):
27032694 basic_group_shape .append (slice_static_length (idx , dim_length ))
27042695 else : # TensorType (advanced index)
27052696 # Keep track of advanced group axis
@@ -2752,16 +2743,14 @@ def is_bool_index(idx):
27522743 or getattr (idx , "dtype" , None ) == "bool"
27532744 )
27542745
2755- # Reconstruct the full indices from idx_list and inputs (like perform method )
2746+ # Reconstruct the full indices from idx_list and inputs (newaxis handled by __getitem__ )
27562747 inputs = node .inputs [1 :]
27572748
27582749 full_indices = []
27592750 input_idx = 0
27602751
27612752 for entry in self .idx_list :
2762- if entry is np .newaxis :
2763- full_indices .append (np .newaxis )
2764- elif isinstance (entry , slice ):
2753+ if isinstance (entry , slice ):
27652754 # Reconstruct slice from idx_list and inputs
27662755 if entry .start is not None and isinstance (entry .start , Type ):
27672756 start_val = inputs [input_idx ]
@@ -2794,8 +2783,6 @@ def is_bool_index(idx):
27942783 for idx in full_indices :
27952784 if isinstance (idx , slice ):
27962785 index_shapes .append (idx )
2797- elif idx is np .newaxis :
2798- index_shapes .append (idx )
27992786 elif hasattr (idx , 'type' ):
28002787 # Mixed bool indexes are converted to nonzero entries
28012788 shape0_op = Shape_i (0 )
@@ -2837,17 +2824,15 @@ def is_bool_index(idx):
28372824 def perform (self , node , inputs , out_ ):
28382825 (out ,) = out_
28392826
2840- # Reconstruct the full tuple of indices from idx_list and inputs
2827+ # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__)
28412828 x = inputs [0 ]
28422829 tensor_inputs = inputs [1 :]
28432830
28442831 full_indices = []
28452832 input_idx = 0
28462833
28472834 for entry in self .idx_list :
2848- if entry is np .newaxis :
2849- full_indices .append (np .newaxis )
2850- elif isinstance (entry , slice ):
2835+ if isinstance (entry , slice ):
28512836 # Reconstruct slice from idx_list and inputs
28522837 if entry .start is not None and isinstance (entry .start , Type ):
28532838 start_val = tensor_inputs [input_idx ]
@@ -2938,7 +2923,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
29382923 bool
29392924 True if the advanced indexing is non-consecutive, False otherwise.
29402925 """
2941- # Reconstruct the full indices from idx_list and inputs to check consecutivity
2926+ # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__)
29422927 op = node .op
29432928 tensor_inputs = node .inputs [1 :]
29442929
@@ -2948,8 +2933,6 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
29482933 for entry in op .idx_list :
29492934 if isinstance (entry , slice ):
29502935 full_indices .append (slice (None )) # Represent as basic slice
2951- elif entry is np .newaxis :
2952- full_indices .append (np .newaxis )
29532936 elif isinstance (entry , Type ):
29542937 # This is a numerical index - get from inputs
29552938 if input_idx < len (tensor_inputs ):
@@ -3035,14 +3018,12 @@ def make_node(self, x, y, *inputs):
30353018 def perform (self , node , inputs , out_ ):
30363019 x , y , * tensor_inputs = inputs
30373020
3038- # Reconstruct the full tuple of indices from idx_list and inputs
3021+ # Reconstruct the full tuple of indices from idx_list and inputs (newaxis handled by __getitem__)
30393022 full_indices = []
30403023 input_idx = 0
30413024
30423025 for entry in self .idx_list :
3043- if entry is np .newaxis :
3044- full_indices .append (np .newaxis )
3045- elif isinstance (entry , slice ):
3026+ if isinstance (entry , slice ):
30463027 # Reconstruct slice from idx_list and inputs
30473028 if entry .start is not None and isinstance (entry .start , Type ):
30483029 start_val = tensor_inputs [input_idx ]
@@ -3154,7 +3135,7 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
31543135 bool
31553136 True if the advanced indexing is non-consecutive, False otherwise.
31563137 """
3157- # Reconstruct the full indices from idx_list and inputs to check consecutivity
3138+ # Reconstruct the full indices from idx_list and inputs to check consecutivity (newaxis handled by __getitem__)
31583139 op = node .op
31593140 tensor_inputs = node .inputs [2 :] # Skip x and y
31603141
@@ -3164,8 +3145,6 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
31643145 for entry in op .idx_list :
31653146 if isinstance (entry , slice ):
31663147 full_indices .append (slice (None )) # Represent as basic slice
3167- elif entry is np .newaxis :
3168- full_indices .append (np .newaxis )
31693148 elif isinstance (entry , Type ):
31703149 # This is a numerical index - get from inputs
31713150 if input_idx < len (tensor_inputs ):
@@ -3180,6 +3159,9 @@ def advanced_subtensor(x, *args):
31803159
31813160 This function converts the arguments to work with the new AdvancedSubtensor
31823161 interface that separates slice structure from variable inputs.
3162+
3163+ Note: newaxis (None) should be handled by __getitem__ using dimshuffle
3164+ before calling this function.
31833165 """
31843166 # Convert args using as_index_variable (like original AdvancedSubtensor did)
31853167 processed_args = tuple (map (as_index_variable , args ))
@@ -3189,9 +3171,7 @@ def advanced_subtensor(x, *args):
31893171 input_vars = []
31903172
31913173 for arg in processed_args :
3192- if isinstance (arg .type , NoneTypeT ):
3193- idx_list .append (np .newaxis )
3194- elif isinstance (arg .type , SliceType ):
3174+ if isinstance (arg .type , SliceType ):
31953175 # Handle SliceType - extract components and structure
31963176 if isinstance (arg , Constant ):
31973177 # Constant slice
@@ -3218,15 +3198,19 @@ def advanced_subtensor(x, *args):
32183198 # Other slice case
32193199 idx_list .append (slice (None ))
32203200 else :
3221- # Tensor index
3201+ # Tensor index (should not be NoneType since newaxis handled in __getitem__)
32223202 idx_list .append (index_vars_to_types (arg ))
32233203 input_vars .append (arg )
32243204
32253205 return AdvancedSubtensor (idx_list ).make_node (x , * input_vars ).outputs [0 ]
32263206
32273207
32283208def advanced_inc_subtensor (x , y , * args , ** kwargs ):
3229- """Create an AdvancedIncSubtensor operation for incrementing."""
3209+ """Create an AdvancedIncSubtensor operation for incrementing.
3210+
3211+ Note: newaxis (None) should be handled by __getitem__ using dimshuffle
3212+ before calling this function.
3213+ """
32303214 # Convert args using as_index_variable (like original AdvancedIncSubtensor would)
32313215 processed_args = tuple (map (as_index_variable , args ))
32323216
@@ -3235,9 +3219,7 @@ def advanced_inc_subtensor(x, y, *args, **kwargs):
32353219 input_vars = []
32363220
32373221 for arg in processed_args :
3238- if isinstance (arg .type , NoneTypeT ):
3239- idx_list .append (np .newaxis )
3240- elif isinstance (arg .type , SliceType ):
3222+ if isinstance (arg .type , SliceType ):
32413223 # Handle SliceType - extract components and structure
32423224 if isinstance (arg , Constant ):
32433225 # Constant slice
@@ -3264,7 +3246,7 @@ def advanced_inc_subtensor(x, y, *args, **kwargs):
32643246 # Other slice case
32653247 idx_list .append (slice (None ))
32663248 else :
3267- # Tensor index
3249+ # Tensor index (should not be NoneType since newaxis handled in __getitem__)
32683250 idx_list .append (index_vars_to_types (arg ))
32693251 input_vars .append (arg )
32703252
0 commit comments