@@ -2738,10 +2738,42 @@ def is_bool_index(idx):
27382738 assert node .outputs [0 ].ndim == len (res_shape )
27392739 return [res_shape ]
27402740
2741+ def _broadcast_indices (self , x , indices ):
2742+ new_indices = []
2743+ x_dim = 0
2744+ for idx in indices :
2745+ if idx is None :
2746+ new_indices .append (idx )
2747+ continue
2748+ if isinstance (idx , slice ):
2749+ x_dim += 1
2750+ new_indices .append (idx )
2751+ continue
2752+
2753+ # Check for boolean
2754+ if hasattr (idx , "dtype" ) and (idx .dtype == bool or idx .dtype == np .bool_ ):
2755+ x_dim += idx .ndim
2756+ new_indices .append (idx )
2757+ continue
2758+
2759+ # Integer array
2760+ if x_dim < x .ndim and x .shape [x_dim ] == 1 :
2761+ # Broadcast: replace with zeros
2762+ new_indices .append (np .zeros_like (idx ))
2763+ else :
2764+ new_indices .append (idx )
2765+ x_dim += 1
2766+ return tuple (new_indices )
2767+
27412768 def perform (self , node , inputs , out_ ):
27422769 (out ,) = out_
2743- check_advanced_indexing_dimensions (inputs [0 ], inputs [1 :])
2744- rval = inputs [0 ].__getitem__ (tuple (inputs [1 :]))
2770+ x = inputs [0 ]
2771+ indices = inputs [1 :]
2772+
2773+ indices = self ._broadcast_indices (x , indices )
2774+
2775+ check_advanced_indexing_dimensions (x , indices )
2776+ rval = x .__getitem__ (tuple (indices ))
27452777 # When there are no arrays, we are not actually doing advanced
27462778 # indexing, so __getitem__ will not return a copy.
27472779 # Since no view_map is set, we need to copy the returned value
@@ -2807,6 +2839,97 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
# Module-level instance of AdvancedSubtensor built with its default arguments.
advanced_subtensor = AdvancedSubtensor()
28082840
28092841
class BatchedSliceType(Type):
    """Type of the placeholder variable produced by ``BatchedSlice``.

    The type carries no parameters, so every instance is interchangeable.
    ``__eq__``/``__hash__`` are defined accordingly so that two separately
    created instances compare equal instead of falling back to identity
    comparison.
    """

    def filter(self, x, strict=False, allow_downcast=None):
        """Return ``x`` unchanged; no validation or conversion is applied."""
        return x

    def __str__(self):
        return "BatchedSliceType"

    def __eq__(self, other):
        # Parameter-free type: equality is decided by the class alone.
        return type(self) is type(other)

    def __hash__(self):
        # Must agree with __eq__: equal instances hash alike.
        return hash(type(self))


# Shared instance used to create BatchedSlice output variables.
batched_slice_type = BatchedSliceType()
2851+
2852+
class BatchedSlice(Op):
    """Graph-only placeholder grouping the ``start``/``stop``/``step`` of a slice.

    The node's single output has type ``batched_slice_type``. The op cannot
    be evaluated: ``perform`` always raises.
    """

    __props__ = ()

    def make_node(self, start, stop, step):
        """Build the Apply node wrapping the three slice components."""
        components = [start, stop, step]
        placeholder = batched_slice_type()
        return Apply(self, components, [placeholder])

    def perform(self, node, inp, out_):
        """Always raise: this op has no runtime implementation."""
        raise NotImplementedError("BatchedSlice is a placeholder")
2861+
2862+
@_vectorize_node.register(MakeSlice)
def vectorize_make_slice(op, node, *batched_inputs):
    """Vectorize ``MakeSlice``.

    If any batched input has more dimensions than the matching original
    input, a ``BatchedSlice`` node is returned; otherwise the original op is
    re-applied to the batched inputs. Inputs whose type has no ``ndim``
    attribute are ignored by the comparison.
    """
    gained_dims = any(
        new.type.ndim > old.type.ndim
        for old, new in zip(node.inputs, batched_inputs)
        if hasattr(new.type, "ndim") and hasattr(old.type, "ndim")
    )
    target_op = BatchedSlice() if gained_dims else op
    return target_op.make_node(*batched_inputs)
2875+
2876+
class AdvancedIncSubtensorExplicit(Op):
    """Increment/set ``x[indices]`` with slices passed as explicit inputs.

    Inputs are ``x``, ``y`` and a flat list of index inputs laid out
    according to ``structure``:

    * ``"slice"`` consumes six inputs, in order ``start, start_is_none,
      stop, stop_is_none, step, step_is_none``. A truthy ``*_is_none`` flag
      means the matching value is ignored and ``None`` is used instead.
    * any other entry consumes a single input that is used as the index
      as-is.

    Parameters
    ----------
    structure
        Sequence of kind tags, one per index. Stored as a tuple so the op
        props stay hashable.
    set_instead_of_inc
        If true, assign ``y`` to the indexed region instead of adding to it.
    inplace
        If true, write into ``x`` itself instead of a copy.
    ignore_duplicates
        If true, increment with ``+=`` (repeated index entries are not
        accumulated); otherwise ``np.add.at`` is used.
    """

    __props__ = ("structure", "set_instead_of_inc", "inplace", "ignore_duplicates")

    def __init__(
        self,
        structure,
        set_instead_of_inc=False,
        inplace=False,
        ignore_duplicates=False,
    ):
        self.structure = tuple(structure)
        self.set_instead_of_inc = set_instead_of_inc
        self.inplace = inplace
        self.ignore_duplicates = ignore_duplicates
        if inplace:
            # perform() overwrites input 0 and returns it as output 0, which
            # must be declared so the input is not reused after this node.
            self.destroy_map = {0: [0]}

    def make_node(self, x, y, *inputs):
        """Build the Apply node; the output has the same type as ``x``."""
        return Apply(self, [x, y, *inputs], [x.type()])

    def _rebuild_indices(self, flat_indices):
        """Reassemble the index tuple from the flat runtime inputs."""
        indices = []
        pos = 0
        for kind in self.structure:
            if kind == "slice":
                (
                    start_val,
                    start_none,
                    stop_val,
                    stop_none,
                    step_val,
                    step_none,
                ) = flat_indices[pos : pos + 6]
                indices.append(
                    slice(
                        None if start_none else start_val,
                        None if stop_none else stop_val,
                        None if step_none else step_val,
                    )
                )
                pos += 6
            else:
                indices.append(flat_indices[pos])
                pos += 1
        return tuple(indices)

    def perform(self, node, inputs, out_):
        """Write the updated array into ``out_[0][0]``."""
        x, y, *flat_indices = inputs
        indices = self._rebuild_indices(flat_indices)

        (out,) = out_
        out[0] = x if self.inplace else x.copy()

        if self.set_instead_of_inc:
            out[0][indices] = y
        elif self.ignore_duplicates:
            out[0][indices] += y
        else:
            # Unbuffered add, so repeated index entries accumulate.
            np.add.at(out[0], indices, y)
2931+
2932+
28102933@_vectorize_node .register (AdvancedSubtensor )
28112934def vectorize_advanced_subtensor (op : AdvancedSubtensor , node , * batch_inputs ):
28122935 x , * idxs = node .inputs
@@ -2967,9 +3090,130 @@ def non_consecutive_adv_indexing(node: Apply) -> bool:
# Module-level instances of AdvancedIncSubtensor for each flag combination.
advanced_inc_subtensor = AdvancedIncSubtensor()
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True)
# Retained: this is a public module-level name, so dropping it would break
# any code that still imports it.
advanced_set_subtensor_nodup = AdvancedIncSubtensor(
    set_instead_of_inc=True, ignore_duplicates=True
)
3093+
3094+
def _advanced_indices_are_split(indices):
    """Return True if a non-tensor index sits between two tensor indices.

    Indices exposing ``ndim`` are treated as advanced (tensor) indices; all
    others (slices, newaxis) as basic ones.
    """
    seen_tensor = False
    gap_after_tensor = False
    for idx in indices:
        if hasattr(idx, "ndim"):
            if gap_after_tensor:
                return True
            seen_tensor = True
        elif seen_tensor:
            gap_after_tensor = True
    return False


def _vectorize_inc_subtensor_batched_slices(op, node, x, y, indices):
    """Vectorize through ``Blockwise(AdvancedIncSubtensorExplicit)``.

    Each ``BatchedSlice`` index is unpacked into six scalar-core inputs
    (value + "is None" flag for start, stop and step); tensor indices are
    passed through with the core ndim of the original index.
    """
    from pytensor.tensor.blockwise import Blockwise, safe_signature

    structure = []
    flat_inputs = []
    input_core_ndims = [node.inputs[0].ndim, node.inputs[1].ndim]

    for orig_idx, idx in zip(node.inputs[2:], indices):
        if isinstance(idx.type, BatchedSliceType):
            owner = idx.owner
            if owner is None or not isinstance(owner.op, BatchedSlice):
                raise NotImplementedError(
                    "Cannot vectorize AdvancedIncSubtensor: batched slice "
                    "is not the output of a BatchedSlice node."
                )
            structure.append("slice")
            for comp in owner.inputs:
                if isinstance(comp.type, NoneTypeT):
                    # Dummy value plus a True "is None" flag.
                    flat_inputs.append(tensor_constant(0, dtype="int8"))
                    flat_inputs.append(tensor_constant(True, dtype="bool"))
                else:
                    flat_inputs.append(comp)
                    flat_inputs.append(tensor_constant(False, dtype="bool"))
            # Six inputs per slice, each scalar in the core graph.
            input_core_ndims.extend([0] * 6)
        elif hasattr(orig_idx, "ndim"):
            structure.append("tensor")
            flat_inputs.append(idx)
            input_core_ndims.append(orig_idx.ndim)
        else:
            # Plain slices / newaxis have no ndim and cannot be Blockwise
            # inputs; fail with a clear message instead of AttributeError.
            raise NotImplementedError(
                "Cannot vectorize AdvancedIncSubtensor mixing batched slices "
                "with non-batched slices or newaxis."
            )

    core_op = AdvancedIncSubtensorExplicit(
        structure=tuple(structure),
        set_instead_of_inc=op.set_instead_of_inc,
        inplace=op.inplace,
        ignore_duplicates=op.ignore_duplicates,
    )
    signature = safe_signature(input_core_ndims, [node.outputs[0].ndim])
    return Blockwise(core_op, signature=signature).make_node(x, y, *flat_inputs)


@_vectorize_node.register(AdvancedIncSubtensor)
def advanced_inc_subtensor_vectorize_node(op, node, *batched_inputs):
    """Vectorize ``AdvancedIncSubtensor``.

    Strategy, in order:

    1. Any ``BatchedSliceType`` index: wrap ``AdvancedIncSubtensorExplicit``
       in ``Blockwise``.
    2. Any tensor index gained dimensions: wrap the original op in
       ``Blockwise``. This needs every index to be a tensor.
    3. Nothing is batched: re-apply the op unchanged.
    4. Only ``x`` and/or ``y`` are batched: prepend one full slice per batch
       dimension, first expanding ``x`` if it lacks batch dimensions.

    Raises
    ------
    NotImplementedError
        For index combinations that cannot be expressed by the strategies
        above.
    """
    x, y, *indices = batched_inputs
    x_orig, y_orig, *orig_indices = node.inputs

    if any(isinstance(idx.type, BatchedSliceType) for idx in indices):
        return _vectorize_inc_subtensor_batched_slices(op, node, x, y, indices)

    # Must be checked before looking at x/y: indices can be the only batched
    # inputs, and treating them as core indices would change the result.
    any_batched_index = any(
        hasattr(idx, "ndim")
        and hasattr(orig_idx, "ndim")
        and idx.ndim > orig_idx.ndim
        for orig_idx, idx in zip(orig_indices, indices)
    )
    if any_batched_index:
        if not all(hasattr(orig_idx, "ndim") for orig_idx in orig_indices):
            raise NotImplementedError(
                "Vectorized AdvancedIncSubtensor with batched indices and "
                "slices or newaxis is not supported."
            )
        from pytensor.tensor.blockwise import Blockwise, safe_signature

        signature = safe_signature(
            [inp.ndim for inp in node.inputs], [node.outputs[0].ndim]
        )
        return Blockwise(op, signature=signature).make_node(x, y, *indices)

    x_batch_ndim = x.ndim - x_orig.ndim
    y_batch_ndim = y.ndim - y_orig.ndim
    batch_ndim = max(x_batch_ndim, y_batch_ndim)

    if batch_ndim == 0:
        return op.make_node(x, y, *indices)

    if y.ndim > 0 and _advanced_indices_are_split(orig_indices):
        # NumPy moves the dimensions of non-adjacent advanced indices to the
        # front, ahead of the prepended batch slices, so y would no longer
        # line up with the indexed region.
        raise NotImplementedError(
            "Vectorized AdvancedIncSubtensor with advanced indices separated "
            "by slices or newaxis is not supported."
        )

    if x_batch_ndim < batch_ndim:
        # y has all batch_ndim batch dims here. Only the missing leading
        # dims are taken from y; x keeps the dims it already has.
        missing = batch_ndim - x_batch_ndim
        lead_dims = [y.shape[d] for d in range(missing)]
        own_dims = [x.shape[d] for d in range(x.ndim)]
        x = alloc(x, *lead_dims, *own_dims)

    # NOTE(review): this path assumes a batched y has as many core dims as
    # the indexed region, and that x's batch dims are not length-1 where y's
    # are larger -- confirm or route those cases through Blockwise.
    full_slice = make_slice(None, None, None)
    new_indices = [full_slice] * batch_ndim + list(indices)
    return op.make_node(x, y, *new_indices)
29733217
29743218
29753219def take (a , indices , axis = None , mode = "raise" ):
0 commit comments