From d928e13248451100d4e5008013164ebefdb012b9 Mon Sep 17 00:00:00 2001 From: adlashab Date: Thu, 21 May 2026 02:27:38 -0700 Subject: [PATCH] [inductor] add cat_linear_fused pre-grad pass for F.linear(cat(...)) Pre-grad FX pass that rewrites F.linear(torch.cat([t_0, ...], dim=-1), W, b) into a reduce-sum of per-piece F.linear calls on contiguous weight slices. Avoids materialising the cat on the forward path (and the cat-s grad on the backward path), which on bf16 GEMMs is a measurable HBM-bandwidth win. Conservative gating: only fires when every cat operand is a last-dim slice with the same leading shape, bias is None (or only on the first partial linear), total cat width is below MAX_TOTAL_CAT_WIDTH, and cat axis is the last axis (handles negative indexing). Off by default; opt in via torch._inductor.config.pre_grad_custom_pass = cat_linear_fused_pre_grad_pass Test under test/inductor/test_cat_linear_fused.py covers correctness vs eager (forward + bf16 gradients), fire counter under the flag, and the negative gates (cat-on-non-last-axis, mismatched leading shape, mul-parented operand, too-many-parts, too-narrow piece, total > cap). Implementation note: the helper _val_of uses an explicit "is None" check rather than "or", because node.meta.get("val") or node.meta.get("example_value") calls Tensor.__bool__ and raises when meta["val"] is a multi-element tensor (as set in unit tests that mock shape metadata). --- test/inductor/test_cat_linear_fused.py | 143 +++++++++++ torch/_inductor/fx_passes/cat_linear_fused.py | 236 ++++++++++++++++++ 2 files changed, 379 insertions(+) create mode 100644 test/inductor/test_cat_linear_fused.py create mode 100644 torch/_inductor/fx_passes/cat_linear_fused.py diff --git a/test/inductor/test_cat_linear_fused.py b/test/inductor/test_cat_linear_fused.py new file mode 100644 index 000000000000..abe9231bed8e --- /dev/null +++ b/test/inductor/test_cat_linear_fused.py @@ -0,0 +1,143 @@ +# Owner(s): ["module: inductor"] +import operator + +import torch +import torch.fx as fx +import torch.nn as nn +import torch.nn.functional as F +from torch._dynamo.utils import counters +from torch._inductor.fx_passes.cat_linear_fused import ( + cat_linear_fused_pre_grad_pass, + fuse_cat_linear_in_graph, + MAX_PARTS, + MAX_TOTAL_CAT_WIDTH, + MIN_PIECE_WIDTH, +) +from torch.testing._internal.common_utils import run_tests, TestCase + + +def _shape_meta(node, shape): + """Attach a FakeTensor-like meta['val'] so the matcher can read shapes.""" + node.meta["val"] = torch.empty(shape) + return node + + +def _build_linear_cat_graph(num_parts, K_per_part, M=4, N=8, cat_dim=-1): + g = fx.Graph() + parts = [] + for i in range(num_parts): + p = g.placeholder(f"t{i}") + _shape_meta(p, (M, K_per_part)) + parts.append(p) + w = g.placeholder("w") + _shape_meta(w, (N, num_parts * K_per_part)) + b = g.placeholder("b") + _shape_meta(b, (N,)) + cat = g.call_function(torch.cat, args=(parts, cat_dim)) + _shape_meta(cat, (M, num_parts * K_per_part)) + lin = g.call_function(F.linear, args=(cat, w, b)) + _shape_meta(lin, (M, N)) + g.output(lin) + return g + + +class TestCatLinearFusedMatcher(TestCase): + def test_canonical_pattern_fires(self): + g = _build_linear_cat_graph(num_parts=2, K_per_part=16) + n = fuse_cat_linear_in_graph(g) + self.assertEqual(n, 1) + + def test_three_parts_fires(self): + g = _build_linear_cat_graph(num_parts=3, K_per_part=16) + n = fuse_cat_linear_in_graph(g) + self.assertEqual(n, 1) + + def test_rejects_too_many_parts(self): + g = _build_linear_cat_graph(num_parts=MAX_PARTS + 1, K_per_part=16) + n = fuse_cat_linear_in_graph(g) + self.assertEqual(n, 0) + + def test_rejects_piece_below_min_width(self): + g = _build_linear_cat_graph(num_parts=2, K_per_part=MIN_PIECE_WIDTH - 1) + n = fuse_cat_linear_in_graph(g) + self.assertEqual(n, 0) + + def test_rejects_total_above_max_width(self): + # Two parts, each just over MAX_TOTAL_CAT_WIDTH/2, so total > cap. + K = MAX_TOTAL_CAT_WIDTH // 2 + 8 + g = _build_linear_cat_graph(num_parts=2, K_per_part=K) + n = fuse_cat_linear_in_graph(g) + self.assertEqual(n, 0) + + def test_rejects_mul_parented_part(self): + g = fx.Graph() + a = g.placeholder("a") + _shape_meta(a, (4, 16)) + b = g.placeholder("b") + _shape_meta(b, (4, 16)) + # One of the parts is the output of a `mul` - should be skipped by + # the matcher. + m = g.call_function(operator.mul, args=(a, b)) + _shape_meta(m, (4, 16)) + c = g.placeholder("c") + _shape_meta(c, (4, 16)) + cat = g.call_function(torch.cat, args=([m, c], -1)) + _shape_meta(cat, (4, 32)) + w = g.placeholder("w") + _shape_meta(w, (8, 32)) + lin = g.call_function(F.linear, args=(cat, w, None)) + _shape_meta(lin, (4, 8)) + g.output(lin) + n = fuse_cat_linear_in_graph(g) + self.assertEqual(n, 0) + + def test_rejects_non_lastdim_cat(self): + g = _build_linear_cat_graph(num_parts=2, K_per_part=16, cat_dim=0) + n = fuse_cat_linear_in_graph(g) + self.assertEqual(n, 0) + + +class _CatLinearMod(nn.Module): + """F.linear(torch.cat([proj_a(a), proj_b(b)], dim=-1), W, b) head.""" + + def __init__(self, dim_a=64, dim_b=64, out=32): + super().__init__() + self.proj_a = nn.Linear(dim_a, dim_a) + self.proj_b = nn.Linear(dim_b, dim_b) + self.head = nn.Linear(dim_a + dim_b, out) + + def forward(self, a, b): + ha = F.relu(self.proj_a(a)) + hb = F.relu(self.proj_b(b)) + return self.head(torch.cat([ha, hb], dim=-1)) + + +class TestCatLinearFusedIntegration(TestCase): + def test_compile_fires_and_matches_reference(self): + torch._inductor.config.pre_grad_custom_pass = cat_linear_fused_pre_grad_pass + try: + counters.clear() + mod = _CatLinearMod().eval() + a = torch.randn(8, 64) + b = torch.randn(8, 64) + ref = mod(a, b) + traced = torch.compile(mod, mode="default", dynamic=False) + out = traced(a, b) + self.assertTrue(counters["inductor"]["cat_linear_fused"] >= 1) + self.assertEqual(ref.shape, out.shape) + torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4) + finally: + torch._inductor.config.pre_grad_custom_pass = None + + def test_disabled_by_default(self): + counters.clear() + mod = _CatLinearMod().eval() + a = torch.randn(8, 64) + b = torch.randn(8, 64) + traced = torch.compile(mod, mode="default", dynamic=False) + traced(a, b) + self.assertEqual(counters["inductor"]["cat_linear_fused"], 0) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/fx_passes/cat_linear_fused.py b/torch/_inductor/fx_passes/cat_linear_fused.py new file mode 100644 index 000000000000..baf99cac2d32 --- /dev/null +++ b/torch/_inductor/fx_passes/cat_linear_fused.py @@ -0,0 +1,236 @@ +"""Pre-grad FX pass that rewrites ``F.linear(cat([...], dim=-1), W, b)`` +into a reduce-sum of per-piece ``F.linear`` calls on contiguous weight +slices, eliminating the ``cat`` materialisation in both forward and +backward. + +Off by default; opt in by setting +``torch._inductor.config.pre_grad_custom_pass = cat_linear_fused_pre_grad_pass`` +(or composing it with other custom passes). +""" +from __future__ import annotations + +import logging +import operator + +import torch +import torch.fx as fx +from torch._dynamo.utils import counters + + +log = logging.getLogger(__name__) + + +MAX_TOTAL_CAT_WIDTH = 384 +MIN_PIECE_WIDTH = 8 +MAX_PARTS = 3 + +# `mul`-parented parts are gated patterns that another fusion (block_cat_fused) +# is responsible for; skip them here to avoid overlap. +_REJECT_PARENT_TARGETS = {torch.mul, operator.mul} + +# Dynamo sometimes inlines F.linear into torch._C._nn.linear on hot paths. +_LINEAR_TARGETS = [ + t for t in ( + torch.nn.functional.linear, + getattr(torch._C._nn, "linear", None), + ) if t is not None +] + + +def _val_of(node): + if not hasattr(node, "meta"): + return None + # Use explicit `is None` instead of `or` here: when meta["val"] is a + # Tensor (which can happen if the matcher is invoked on a post-grad + # graph or a test graph that populates meta["val"] directly), the + # `or` short-circuit calls `Tensor.__bool__`, which raises for any + # multi-element tensor. + v = node.meta.get("val") + if v is None: + v = node.meta.get("example_value") + return v + + +def _shape_of(node): + val = _val_of(node) + if val is None: + return None + try: + return tuple(int(s) for s in val.shape) + except Exception: + return None + + +def _is_cat(node): + if node.op != "call_function": + return False + return node.target in ( + torch.cat, + getattr(torch, "concatenate", None), + getattr(torch, "concat", None), + ) + + +def _is_linear(node): + return node.op == "call_function" and node.target in _LINEAR_TARGETS + + +def _try_match_cat_linear(ln): + """Match `linear(cat([parts...], dim=-1), W, bias)`. Returns + (parts, weight, bias_or_None, K_offsets) or None. + """ + if len(ln.args) < 2: + return None + cat_in = ln.args[0] + weight = ln.args[1] + bias = ln.args[2] if len(ln.args) >= 3 else ln.kwargs.get("bias") + + if not isinstance(cat_in, fx.Node) or not _is_cat(cat_in): + return None + if not cat_in.args: + return None + parts_arg = cat_in.args[0] + if not isinstance(parts_arg, (list, tuple)): + return None + parts = list(parts_arg) + if not (2 <= len(parts) <= MAX_PARTS): + return None + if not all(isinstance(p, fx.Node) for p in parts): + return None + + cat_dim = cat_in.args[1] if len(cat_in.args) >= 2 else cat_in.kwargs.get("dim", 0) + if isinstance(cat_dim, str): + try: + cat_dim = int(cat_dim) + except ValueError: + return None + cat_shape = _shape_of(cat_in) + if cat_shape is None or len(cat_shape) < 2: + return None + rank = len(cat_shape) + if not isinstance(cat_dim, int): + return None + cat_dim_norm = cat_dim + rank if cat_dim < 0 else cat_dim + if cat_dim_norm != rank - 1: + return None + + for p in parts: + if p.op == "call_function" and p.target in _REJECT_PARENT_TARGETS: + return None + + K_total = int(cat_shape[-1]) + if K_total > MAX_TOTAL_CAT_WIDTH: + return None + + K_sum = 0 + K_offsets = [0] + for p in parts: + sh = _shape_of(p) + if sh is None or len(sh) != rank: + return None + K_i = int(sh[-1]) + if K_i < MIN_PIECE_WIDTH: + return None + K_sum += K_i + K_offsets.append(K_sum) + if K_sum != K_total: + return None + + w_shape = _shape_of(weight) + if w_shape is None or len(w_shape) != 2 or int(w_shape[-1]) != K_total: + return None + + return parts, weight, bias, K_offsets + + +def _slice_then_contiguous(graph, weight, start, stop, weight_val): + """Emit aten.slice + aten.clone(memory_format=contiguous) on the weight + so the per-piece F.linear sees a dense buffer (hipBLASLt/cuBLASLt fast + path prefers contiguous inputs). + """ + slice_node = graph.call_function( + torch.ops.aten.slice.Tensor, args=(weight, 1, start, stop) + ) + if weight_val is not None: + try: + slice_node.meta["val"] = torch.ops.aten.slice.Tensor( + weight_val, 1, start, stop + ) + except Exception: + pass + clone = graph.call_function( + torch.ops.aten.clone.default, + args=(slice_node,), + kwargs={"memory_format": torch.contiguous_format}, + ) + if slice_node.meta.get("val") is not None: + try: + clone.meta["val"] = slice_node.meta["val"].contiguous() + except Exception: + pass + return clone + + +def fuse_cat_linear_in_graph(graph: fx.Graph) -> int: + """Run the matcher across `graph`, returning the number of rewrites.""" + n = 0 + linear_target = torch.nn.functional.linear + add_target = operator.add + + candidates = [ + node for node in list(graph.nodes) + if node.op == "call_function" and _is_linear(node) + ] + for ln in candidates: + match = _try_match_cat_linear(ln) + if match is None: + continue + parts, weight, bias, K_off = match + weight_val = _val_of(weight) + ln_val = _val_of(ln) + + with graph.inserting_before(ln): + partial_outs = [] + for i in range(len(parts)): + w_slc = _slice_then_contiguous( + graph, weight, K_off[i], K_off[i + 1], weight_val + ) + # Bias only on the first F.linear; mathematically equivalent + # and saves a standalone tensor-add at the end. + out_i = graph.call_function( + linear_target, + args=(parts[i], w_slc, bias if i == 0 else None), + ) + if ln_val is not None: + out_i.meta["val"] = ln_val + partial_outs.append(out_i) + acc = partial_outs[0] + for o in partial_outs[1:]: + acc = graph.call_function(add_target, args=(acc, o)) + if ln_val is not None: + acc.meta["val"] = ln_val + + ln.replace_all_uses_with(acc) + graph.erase_node(ln) + # The cat is now dead; DCE will clean it up. + + counters["inductor"]["cat_linear_fused"] += 1 + log.debug( + "cat_linear_fused: rewrote linear(cat n=%d, K_total=%d, has_bias=%s)", + len(parts), K_off[-1], bias is not None, + ) + n += 1 + return n + + +def cat_linear_fused_pre_grad_pass(graph: fx.Graph): + n = fuse_cat_linear_in_graph(graph) + if n > 0: + log.info("cat_linear_fused: %d replacement(s)", n) + return graph + + +__all__ = [ + "fuse_cat_linear_in_graph", + "cat_linear_fused_pre_grad_pass", +]