85 changes: 62 additions & 23 deletions torch/_inductor/runtime/triton_heuristics.py
@@ -616,7 +616,10 @@ def _make_launchers(self):
except (OutOfResources, PTXASError, torch.cuda.OutOfMemoryError) as e:
exc = e
if len(launchers) == 0:
raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
raise NoTritonConfigsError(
f"No valid triton configs. {type(exc).__name__}: {exc}"
)

self.launchers = launchers

def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any, Any]:
@@ -2529,14 +2532,22 @@ def pointwise(
)
) # 20% improvement
configs += [
triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where?
triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel
triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
# -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
# triton_poi_fused_index_put_new_zeros_45
triton_config_with_settings(
size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1
), # 20% improvement # .. in where?
triton_config_with_settings(
size_hints, 4096
), # wrt1: better than the max_block for some kernel
triton_config_with_settings(
size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1
),
# -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
# triton_poi_fused_index_put_new_zeros_45
# triton_poi_fused_index_put_new_zeros_49
# triton_poi_fused_index_put_new_zeros_54
triton_config_with_settings(size_hints, 128, num_warps=1, num_stages=1), # wri0: 56 us: triton_poi_fused_cat_mul_sigmoid_view_51
triton_config_with_settings(
size_hints, 128, num_warps=1, num_stages=1
), # wri0: 56 us: triton_poi_fused_cat_mul_sigmoid_view_51
]
if len(size_hints) == 2:
# Only avoiding tuning on TileHint.SQUARE if not on ROCm builds
@@ -2556,7 +2567,7 @@ def pointwise(
size_hints, 64, 32
), # better for some kernels
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
triton_config_with_settings(size_hints, 256, 16),
triton_config_with_settings(size_hints, 256, 16),
triton_config_with_settings(size_hints, 16, 256),
triton_config_with_settings(
size_hints, 128, 16
@@ -2574,7 +2585,10 @@ def pointwise(
"""add 2D tiling configs, but don't use triton_config_with_settings function
as it is buggy and might change the tiling randomly
"""
def addConfig__(xblock:int, yblock:int, num_warps:int, num_stages:int):

def addConfig__(
xblock: int, yblock: int, num_warps: int, num_stages: int
):
# only add a tiling config if size is bigger than the tile
# check also for grid overflow
xgrid = (size_hints["x"] + xblock - 1) // xblock
@@ -2588,12 +2602,27 @@ def addConfig__(xblock:int, yblock:int, num_warps:int, num_stages:int):
if size_hints["y"] < yblock:
return
# all good, add the config
configs.append(Config({"XBLOCK": xblock, "YBLOCK": yblock}, num_warps=num_warps, num_stages=num_stages))
addConfig__(512, 8, 8,1 ) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
addConfig__(32, 128, 4, 1) # wrt2: 570us : triton_poi_fused_add_transpose_view_52
addConfig__(64, 32, 8, 1) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
addConfig__(64, 256, 4, 1) # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19
addConfig__(512, 64, 8, 1) # wri0: 58us: triton_poi_fused_clone_53
configs.append(
Config(
{"XBLOCK": xblock, "YBLOCK": yblock},
num_warps=num_warps,
num_stages=num_stages,
)
)

addConfig__(
512, 8, 8, 1
) # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
addConfig__(
32, 128, 4, 1
) # wrt2: 570us : triton_poi_fused_add_transpose_view_52
addConfig__(
64, 32, 8, 1
) # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
addConfig__(
64, 256, 4, 1
) # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19
addConfig__(512, 64, 8, 1) # wri0: 58us: triton_poi_fused_clone_53

if len(size_hints) == 3:
if disable_pointwise_autotuning(inductor_meta):
@@ -2801,13 +2830,23 @@ def outer_config_opt():
[
make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
]
)
make_config(
128, 4, num_warps=2, num_stages=1, waves_per_eu=1
), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
make_config(
1, 512, num_warps=8, num_stages=1, waves_per_eu=1
), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
make_config(
1, 4096, num_warps=8, num_stages=1, waves_per_eu=1
), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
make_config(
64, 128, num_warps=4, num_stages=1, waves_per_eu=1
), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
make_config(
2, 2048, num_warps=8, num_stages=1, waves_per_eu=1
), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
]
)

return result_configs

@@ -3257,7 +3296,7 @@ def foreach(triton_meta, filename=None, inductor_meta=None):
Compile a triton foreach kernel
"""
configs = []

# Naive autotuning path for num_warps
if disable_pointwise_autotuning(inductor_meta) and not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
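An aside on the new pointwise tiling helper in this file: addConfig__ above only registers a 2D config when the tile fits inside the size hints and the launch grid does not overflow. The snippet below is a minimal standalone sketch of that guard under stated assumptions; MAX_GRID_DIM and the plain-dict return value are placeholders for illustration (Inductor appends a triton Config object, and the collapsed part of the hunk may check a different limit).

# Minimal sketch (illustrative, not the PR's exact code) of the guarded
# 2D-tiling idea behind addConfig__: only emit a tiling config when the tile
# fits inside the size hints and the ceil-divided launch grid stays within a
# per-dimension limit.
from typing import Optional

MAX_GRID_DIM = 2**31 - 1  # assumed per-dimension grid limit


def maybe_tile_config(size_hints: dict[str, int], xblock: int, yblock: int,
                      num_warps: int, num_stages: int) -> Optional[dict]:
    # Skip tiles larger than the problem in either dimension.
    if size_hints["x"] < xblock or size_hints["y"] < yblock:
        return None
    # Skip tiles whose launch grid would overflow the per-dimension limit.
    xgrid = (size_hints["x"] + xblock - 1) // xblock
    ygrid = (size_hints["y"] + yblock - 1) // yblock
    if xgrid > MAX_GRID_DIM or ygrid > MAX_GRID_DIM:
        return None
    return {"XBLOCK": xblock, "YBLOCK": yblock,
            "num_warps": num_warps, "num_stages": num_stages}


# Example: a 512x8 tile is kept only when both hints are at least that large.
print(maybe_tile_config({"x": 4096, "y": 4096}, 512, 8, num_warps=8, num_stages=1))
print(maybe_tile_config({"x": 256, "y": 4}, 512, 8, num_warps=8, num_stages=1))  # None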
30 changes: 29 additions & 1 deletion torch/_inductor/sizevars.py
@@ -74,6 +74,8 @@ def __init__(self, shape_env=None) -> None:
shape_env = ShapeEnv()
self.shape_env = shape_env
self.var_to_val = self.shape_env.var_to_val
# var_to_hint_override may not exist in older PyTorch versions
self.var_to_hint_override = getattr(self.shape_env, "var_to_hint_override", {})
self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
self.unbacked_replacements: Optional[dict[Expr, Expr]] = None
# Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
@@ -573,14 +575,40 @@ def size_hint(
fallback: Optional[int] = None,
hint_override: Optional[int] = None,
) -> int:
out = self.symbolic_hint(expr, hint_override=hint_override)
out = self.symbolic_hint(
expr,
hint_override=hint_override,
)
if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
# Use the provided heuristic fallback hint
unbacked_sym_vrs = {
s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols
}
if all(vr is not None for vr in unbacked_sym_vrs.values()):
hint_vr = bound_sympy(out, unbacked_sym_vrs) # type: ignore[arg-type]
# For expressions like `768*u0`, we need to substitute unbacked symints
# with their hinted values (upper bound clamped by fallback) and evaluate,
# rather than just returning the clamped fallback directly.
# This ensures strides like `768*u0` evaluate to `768*128=98304` rather than `8192`.
unbacked_hints = {}
for s, vr in unbacked_sym_vrs.items():
if vr is not None:
sym_fallback = fallback
if isinstance(vr.lower, (int, sympy.Integer)):
sym_fallback = max(sym_fallback, int(vr.lower))
if isinstance(vr.upper, (int, sympy.Integer)):
sym_fallback = min(sym_fallback, int(vr.upper))
unbacked_hints[s] = sym_fallback

if unbacked_hints:
# Substitute unbacked symints with their hinted values and evaluate
substituted = sympy_subs(out, unbacked_hints)
try:
return int(substituted)
except (TypeError, ValueError):
pass # Fall through to old behavior if substitution didn't work

# Fallback to old behavior: clamp fallback to expression bounds
if isinstance(hint_vr.lower, (int, sympy.Integer)):
fallback = max(fallback, int(hint_vr.lower))
if isinstance(hint_vr.upper, (int, sympy.Integer)):
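For context on the size_hint change above: when an expression such as 768*u0 contains an unbacked symbol, the new code clamps the fallback hint to each symbol's known value range, substitutes, and evaluates, instead of clamping the fallback against the bounds of the whole expression. The sympy sketch below reproduces that arithmetic in isolation; the concrete fallback (8192) and value range ([0, 128]) are illustrative values chosen to match the 768*128 = 98304 example from the in-code comment, not values from a real trace.

import sympy

u0 = sympy.Symbol("u0", integer=True)
expr = 768 * u0          # e.g. a stride expression with an unbacked symint
fallback = 8192          # heuristic fallback hint
lower, upper = 0, 128    # assumed var_to_range entry for u0

# New behaviour: clamp the fallback per symbol, substitute, then evaluate.
hint = max(fallback, lower)
hint = min(hint, upper)
substituted = expr.subs({u0: hint})
print(int(substituted))  # 768 * 128 = 98304

# Old behaviour for comparison: clamp the fallback to the bounds of the whole
# expression (here [0, 98304]), which leaves it at 8192.
print(min(max(fallback, 0), 768 * upper))  # 8192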
31 changes: 28 additions & 3 deletions torch/_inductor/template_heuristics/triton.py
@@ -565,23 +565,48 @@ def _get_exceeding_shared_memory_checker(
Returns a function that checks whether a given configuration exceeds the available shared memory for the device.
If the device does not report available shared memory, returns None.
"""
sm_available = 0

try:
device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)
if not hasattr(props, "shared_memory_per_block_optin"): # for NVidia GPUs
return None
sm_available = int(props.shared_memory_per_block_optin)

if hasattr(props, "shared_memory_per_block_optin"):
sm_available = int(props.shared_memory_per_block_optin)

# AMD/HIP path
elif torch.version.hip:
from ..utils import get_gpu_shared_memory

# Try Triton's driver API first for AMD
try:
sm_available = get_gpu_shared_memory()
except Exception:
pass

# If Triton API fails, fallback to architecture-based detection
# (same logic as torch/cuda/_utils.py)
# TODO : Can expose shared_memory through device props and use it instead of hardcoding
if sm_available <= 0:
# navi, CDNA1-CDNA3 allows a max of 64KB shared memory
# CDNA4 (gfx950) allows a max of 160KB shared memory
gcn_arch = getattr(props, "gcnArchName", "")
sm_available = 65536 if gcn_arch != "gfx950" else 160 * 1024

except Exception:
# If CUDA is not available or properties cannot be queried, return None
return None

if sm_available <= 0:
return None

# TODO make a BaseDeviceConfigHeuristics to handle different device configuration in its own implementation.
def exceeds(gemm_config: BaseConfig, dtype_size: int) -> bool:
shared_mem_accum = dtype_size * (
gemm_config.block_m * gemm_config.block_k
+ gemm_config.block_n * gemm_config.block_k
)
assert sm_available is not None
return shared_mem_accum * gemm_config.num_stages > sm_available

return exceeds
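A condensed sketch of the shared-memory check this hunk builds, assuming the accumulation formula and the hard-coded AMD fallbacks shown in the diff (64 KB for Navi/CDNA1-CDNA3, 160 KB for gfx950). The GemmConfig dataclass and function names below are stand-ins for Inductor's BaseConfig and the returned closure, not its actual API.

from dataclasses import dataclass


@dataclass
class GemmConfig:  # stand-in for Inductor's BaseConfig
    block_m: int
    block_n: int
    block_k: int
    num_stages: int


def amd_shared_memory_fallback(gcn_arch: str) -> int:
    # Architecture-based limit when neither device props nor Triton report one.
    return 160 * 1024 if gcn_arch == "gfx950" else 65536


def exceeds_shared_memory(cfg: GemmConfig, dtype_size: int, sm_available: int) -> bool:
    # Both input tiles are staged per pipeline stage:
    # dtype_size * (BLOCK_M*BLOCK_K + BLOCK_N*BLOCK_K) * num_stages.
    shared_mem_accum = dtype_size * (
        cfg.block_m * cfg.block_k + cfg.block_n * cfg.block_k
    )
    return shared_mem_accum * cfg.num_stages > sm_available


# Example: a 128x128x64 fp16 tile with 3 stages needs 96 KB of shared memory
# and is rejected against the 64 KB fallback limit.
cfg = GemmConfig(block_m=128, block_n=128, block_k=64, num_stages=3)
print(exceeds_shared_memory(cfg, dtype_size=2,
                            sm_available=amd_shared_memory_fallback("gfx90a")))  # True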