diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 2bd7c372f396f..08b7e2c7df305 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -616,7 +616,10 @@ def _make_launchers(self):
                 except (OutOfResources, PTXASError, torch.cuda.OutOfMemoryError) as e:
                     exc = e
         if len(launchers) == 0:
-            raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
+            raise NoTritonConfigsError(
+                f"No valid triton configs. {type(exc).__name__}: {exc}"
+            )
+
         self.launchers = launchers
 
     def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any, Any]:
@@ -2529,14 +2532,22 @@ def pointwise(
                 )
             )  # 20% improvement
             configs += [
-                triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where?
-                triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel
-                triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
-                # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
-                # triton_poi_fused_index_put_new_zeros_45
+                triton_config_with_settings(
+                    size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1
+                ),  # 20% improvement # .. in where?
+                triton_config_with_settings(
+                    size_hints, 4096
+                ),  # wrt1: better than the max_block for some kernel
+                triton_config_with_settings(
+                    size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1
+                ),
+                # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
+                # triton_poi_fused_index_put_new_zeros_45
                 # triton_poi_fused_index_put_new_zeros_49
                 # triton_poi_fused_index_put_new_zeros_54
-                triton_config_with_settings(size_hints, 128, num_warps=1, num_stages=1), # wri0: 56 us: triton_poi_fused_cat_mul_sigmoid_view_51
+                triton_config_with_settings(
+                    size_hints, 128, num_warps=1, num_stages=1
+                ),  # wri0: 56 us: triton_poi_fused_cat_mul_sigmoid_view_51
             ]
     if len(size_hints) == 2:
         # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
@@ -2556,7 +2567,7 @@ def pointwise(
                     size_hints, 64, 32
                 ),  # better for some kernels
                 triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
-                triton_config_with_settings(size_hints, 256, 16), 
+                triton_config_with_settings(size_hints, 256, 16),
                 triton_config_with_settings(size_hints, 16, 256),
                 triton_config_with_settings(
                     size_hints, 128, 16
                 )
@@ -2574,7 +2585,10 @@ def pointwise(
             """add 2D tiling configs, but don't use triton_config_with_settings
             function as it is buggy and might change the tiling randomly
             """
-            def addConfig__(xblock:int, yblock:int, num_warps:int, num_stages:int):
+
+            def addConfig__(
+                xblock: int, yblock: int, num_warps: int, num_stages: int
+            ):
                 # only add a tiling config if size is bigger than the tile
                 # check also for grid overflow
                 xgrid = (size_hints["x"] + xblock - 1) // xblock
@@ -2588,12 +2602,27 @@ def addConfig__(xblock:int, yblock:int, num_warps:int, num_stages:int):
                 if size_hints["y"] < yblock:
                     return
                 # all good, add the config
-                configs.append(Config({"XBLOCK": xblock, "YBLOCK": yblock}, num_warps=num_warps, num_stages=num_stages))
-            addConfig__(512, 8, 8,1 ) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
-            addConfig__(32, 128, 4, 1) # wrt2: 570us : triton_poi_fused_add_transpose_view_52
-            addConfig__(64, 32, 8, 1) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
-            addConfig__(64, 256, 4, 1) # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19
-            addConfig__(512, 64, 8, 1) # wri0: 58us: triton_poi_fused_clone_53
+                configs.append(
+                    Config(
+                        {"XBLOCK": xblock, "YBLOCK": yblock},
+                        num_warps=num_warps,
+                        num_stages=num_stages,
+                    )
+                )
+
+            addConfig__(
+                512, 8, 8, 1
+            )  # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
+            addConfig__(
+                32, 128, 4, 1
+            )  # wrt2: 570us : triton_poi_fused_add_transpose_view_52
+            addConfig__(
+                64, 32, 8, 1
+            )  # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
+            addConfig__(
+                64, 256, 4, 1
+            )  # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19
+            addConfig__(512, 64, 8, 1)  # wri0: 58us: triton_poi_fused_clone_53
 
     if len(size_hints) == 3:
         if disable_pointwise_autotuning(inductor_meta):
@@ -2801,13 +2830,23 @@ def outer_config_opt():
         [
             make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
             make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
-            make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
-            make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
-            make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
-            make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
-            make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
-            ]
-        )
+            make_config(
+                128, 4, num_warps=2, num_stages=1, waves_per_eu=1
+            ),  # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
+            make_config(
+                1, 512, num_warps=8, num_stages=1, waves_per_eu=1
+            ),  # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
+            make_config(
+                1, 4096, num_warps=8, num_stages=1, waves_per_eu=1
+            ),  # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
+            make_config(
+                64, 128, num_warps=4, num_stages=1, waves_per_eu=1
+            ),  # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
+            make_config(
+                2, 2048, num_warps=8, num_stages=1, waves_per_eu=1
+            ),  # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
+        ]
+    )
 
     return result_configs
 
@@ -3257,7 +3296,7 @@ def foreach(triton_meta, filename=None, inductor_meta=None):
     """
    Compile a triton foreach kernel
     """
     configs = []
-
+    # Naive autotuning path for num_warps
     if disable_pointwise_autotuning(inductor_meta) and not (
         inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index 8727777b562b2..6b7b32c77472e 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -74,6 +74,8 @@ def __init__(self, shape_env=None) -> None:
             shape_env = ShapeEnv()
         self.shape_env = shape_env
         self.var_to_val = self.shape_env.var_to_val
+        # var_to_hint_override may not exist in older PyTorch versions
+        self.var_to_hint_override = getattr(self.shape_env, "var_to_hint_override", {})
         self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
         self.unbacked_replacements: Optional[dict[Expr, Expr]] = None
         # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
@@ -573,7 +575,10 @@ def size_hint(
         fallback: Optional[int] = None,
         hint_override: Optional[int] = None,
     ) -> int:
-        out = self.symbolic_hint(expr, hint_override=hint_override)
+        out = self.symbolic_hint(
+            expr,
+            hint_override=hint_override,
+        )
         if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
             # Use the provided heuristic fallback hint
             unbacked_sym_vrs = {
@@ -581,6 +586,29 @@ def size_hint(
             }
             if all(vr is not None for vr in unbacked_sym_vrs.values()):
                 hint_vr = bound_sympy(out, unbacked_sym_vrs)  # type: ignore[arg-type]
+                # For expressions like `768*u0`, we need to substitute unbacked symints
+                # with their hinted values (upper bound clamped by fallback) and evaluate,
+                # rather than just returning the clamped fallback directly.
+                # This ensures strides like `768*u0` evaluate to `768*128=98304` rather than `8192`.
+                unbacked_hints = {}
+                for s, vr in unbacked_sym_vrs.items():
+                    if vr is not None:
+                        sym_fallback = fallback
+                        if isinstance(vr.lower, (int, sympy.Integer)):
+                            sym_fallback = max(sym_fallback, int(vr.lower))
+                        if isinstance(vr.upper, (int, sympy.Integer)):
+                            sym_fallback = min(sym_fallback, int(vr.upper))
+                        unbacked_hints[s] = sym_fallback
+
+                if unbacked_hints:
+                    # Substitute unbacked symints with their hinted values and evaluate
+                    substituted = sympy_subs(out, unbacked_hints)
+                    try:
+                        return int(substituted)
+                    except (TypeError, ValueError):
+                        pass  # Fall through to old behavior if substitution didn't work
+
+                # Fallback to old behavior: clamp fallback to expression bounds
                 if isinstance(hint_vr.lower, (int, sympy.Integer)):
                     fallback = max(fallback, int(hint_vr.lower))
                 if isinstance(hint_vr.upper, (int, sympy.Integer)):
diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py
index a7f4d9f5763ff..189739ce1532c 100644
--- a/torch/_inductor/template_heuristics/triton.py
+++ b/torch/_inductor/template_heuristics/triton.py
@@ -565,23 +565,48 @@ def _get_exceeding_shared_memory_checker(
     Returns a function that checks whether a given configuration exceeds the available
     shared memory for the device. If the device does not report available shared memory, returns None.
     """
+    sm_available = 0
     try:
         device = torch.cuda.current_device()
         props = torch.cuda.get_device_properties(device)
-        if not hasattr(props, "shared_memory_per_block_optin"):  # for NVidia GPUs
-            return None
-        sm_available = int(props.shared_memory_per_block_optin)
+
+        if hasattr(props, "shared_memory_per_block_optin"):
+            sm_available = int(props.shared_memory_per_block_optin)
+
+        # AMD/HIP path
+        elif torch.version.hip:
+            from ..utils import get_gpu_shared_memory
+
+            # Try Triton's driver API first for AMD
+            try:
+                sm_available = get_gpu_shared_memory()
+            except Exception:
+                pass
+
+            # If Triton API fails, fallback to architecture-based detection
+            # (same logic as torch/cuda/_utils.py)
+            # TODO : Can expose shared_memory through device props and use it instead of hardcoding
+            if sm_available <= 0:
+                # navi, CDNA1-CDNA3 allows a max of 64KB shared memory
+                # CDNA4 (gfx950) allows a max of 160KB shared memory
+                gcn_arch = getattr(props, "gcnArchName", "")
+                sm_available = 65536 if gcn_arch != "gfx950" else 160 * 1024
+
     except Exception:
         # If CUDA is not available or properties cannot be queried, return None
         return None
 
+    if sm_available <= 0:
+        return None
+
     # TODO make a BaseDeviceConfigHeuristics to handle different device configuration in its own implementation.
     def exceeds(gemm_config: BaseConfig, dtype_size: int) -> bool:
         shared_mem_accum = dtype_size * (
             gemm_config.block_m * gemm_config.block_k
             + gemm_config.block_n * gemm_config.block_k
         )
 
+        assert sm_available is not None
         return shared_mem_accum * gemm_config.num_stages > sm_available
 
     return exceeds