Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions test/inductor/test_cat_linear_fused.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Owner(s): ["module: inductor"]
import operator

import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F
from torch._dynamo.utils import counters
from torch._inductor.fx_passes.cat_linear_fused import (
cat_linear_fused_pre_grad_pass,
fuse_cat_linear_in_graph,
MAX_PARTS,
MAX_TOTAL_CAT_WIDTH,
MIN_PIECE_WIDTH,
)
from torch.testing._internal.common_utils import run_tests, TestCase


def _shape_meta(node, shape):
    """Store a tensor of ``shape`` under ``node.meta["val"]`` and return ``node``.

    Hand-built test graphs carry no shape metadata, so an uninitialised
    tensor of the requested shape is attached for shape lookups.
    """
    node.meta.update(val=torch.empty(shape))
    return node


def _build_linear_cat_graph(num_parts, K_per_part, M=4, N=8, cat_dim=-1):
    """Hand-build an FX graph computing ``F.linear(torch.cat(parts, cat_dim), w, b)``.

    Args:
        num_parts: number of ``(M, K_per_part)`` placeholders fed to ``cat``.
        K_per_part: last-dim width of every part.
        M: leading dimension of the parts and of the output.
        N: output features of the linear (rows of ``w``).
        cat_dim: ``dim`` argument passed positionally to ``torch.cat``.

    Returns:
        The ``fx.Graph``; every node except the output carries a
        ``meta["val"]`` tensor. The ``cat`` node is always annotated as
        ``(M, num_parts * K_per_part)``, whatever ``cat_dim`` is.
    """
    graph = fx.Graph()
    cat_width = num_parts * K_per_part

    # Placeholders are created in the order t0..t{n-1}, w, b.
    pieces = [
        _shape_meta(graph.placeholder(f"t{idx}"), (M, K_per_part))
        for idx in range(num_parts)
    ]
    weight = _shape_meta(graph.placeholder("w"), (N, cat_width))
    bias = _shape_meta(graph.placeholder("b"), (N,))

    joined = _shape_meta(
        graph.call_function(torch.cat, args=(pieces, cat_dim)),
        (M, cat_width),
    )
    linear = _shape_meta(
        graph.call_function(F.linear, args=(joined, weight, bias)),
        (M, N),
    )
    graph.output(linear)
    return graph


class TestCatLinearFusedMatcher(TestCase):
    """Matcher-level tests that run the pass on hand-built FX graphs."""

    def _assert_rewrites(self, graph, expected):
        """Run the pass on ``graph`` and check the reported rewrite count."""
        self.assertEqual(fuse_cat_linear_in_graph(graph), expected)

    def test_canonical_pattern_fires(self):
        graph = _build_linear_cat_graph(num_parts=2, K_per_part=16)
        self._assert_rewrites(graph, 1)

    def test_three_parts_fires(self):
        graph = _build_linear_cat_graph(num_parts=3, K_per_part=16)
        self._assert_rewrites(graph, 1)

    def test_rejects_too_many_parts(self):
        graph = _build_linear_cat_graph(num_parts=MAX_PARTS + 1, K_per_part=16)
        self._assert_rewrites(graph, 0)

    def test_rejects_piece_below_min_width(self):
        graph = _build_linear_cat_graph(
            num_parts=2, K_per_part=MIN_PIECE_WIDTH - 1
        )
        self._assert_rewrites(graph, 0)

    def test_rejects_total_above_max_width(self):
        # Each of the two parts is a little wider than half the cap, so the
        # concatenated width exceeds MAX_TOTAL_CAT_WIDTH.
        per_part = MAX_TOTAL_CAT_WIDTH // 2 + 8
        graph = _build_linear_cat_graph(num_parts=2, K_per_part=per_part)
        self._assert_rewrites(graph, 0)

    def test_rejects_mul_parented_part(self):
        graph = fx.Graph()
        lhs = _shape_meta(graph.placeholder("a"), (4, 16))
        rhs = _shape_meta(graph.placeholder("b"), (4, 16))
        # One cat input is produced by a `mul`; the matcher is expected to
        # leave such patterns alone.
        product = _shape_meta(
            graph.call_function(operator.mul, args=(lhs, rhs)), (4, 16)
        )
        other = _shape_meta(graph.placeholder("c"), (4, 16))
        joined = _shape_meta(
            graph.call_function(torch.cat, args=([product, other], -1)), (4, 32)
        )
        weight = _shape_meta(graph.placeholder("w"), (8, 32))
        linear = _shape_meta(
            graph.call_function(F.linear, args=(joined, weight, None)), (4, 8)
        )
        graph.output(linear)
        self._assert_rewrites(graph, 0)

    def test_rejects_non_lastdim_cat(self):
        graph = _build_linear_cat_graph(num_parts=2, K_per_part=16, cat_dim=0)
        self._assert_rewrites(graph, 0)


class _CatLinearMod(nn.Module):
    """Two ReLU-activated projections concatenated on the last dim, then a linear head.

    Computes ``head(cat([relu(proj_a(a)), relu(proj_b(b))], dim=-1))``.
    """

    def __init__(self, dim_a=64, dim_b=64, out=32):
        super().__init__()
        # Construction order (proj_a, proj_b, head) fixes the order in which
        # parameters are initialised.
        self.proj_a = nn.Linear(dim_a, dim_a)
        self.proj_b = nn.Linear(dim_b, dim_b)
        self.head = nn.Linear(dim_a + dim_b, out)

    def forward(self, a, b):
        branches = ((self.proj_a, a), (self.proj_b, b))
        hidden = [F.relu(proj(inp)) for proj, inp in branches]
        return self.head(torch.cat(hidden, dim=-1))


class TestCatLinearFusedIntegration(TestCase):
    """End-to-end checks of the pass when driven through ``torch.compile``."""

    def setUp(self):
        super().setUp()
        # Start every test from a clean compiler state: compiled artifacts
        # cached by an earlier test could otherwise be reused, leaving the
        # counter assertions dependent on test order.
        torch._dynamo.reset()
        counters.clear()

    def test_compile_fires_and_matches_reference(self):
        # `config.patch` restores whatever value was set before the test,
        # instead of unconditionally resetting the hook to None.
        with torch._inductor.config.patch(
            pre_grad_custom_pass=cat_linear_fused_pre_grad_pass
        ):
            mod = _CatLinearMod().eval()
            a = torch.randn(8, 64)
            b = torch.randn(8, 64)
            ref = mod(a, b)
            traced = torch.compile(mod, mode="default", dynamic=False)
            out = traced(a, b)
            self.assertGreaterEqual(counters["inductor"]["cat_linear_fused"], 1)
            self.assertEqual(ref.shape, out.shape)
            torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

    def test_disabled_by_default(self):
        # Without the opt-in hook the pass must not run at all.
        mod = _CatLinearMod().eval()
        a = torch.randn(8, 64)
        b = torch.randn(8, 64)
        traced = torch.compile(mod, mode="default", dynamic=False)
        traced(a, b)
        self.assertEqual(counters["inductor"]["cat_linear_fused"], 0)


# Run the tests through the PyTorch test runner when this file is executed
# directly.
if __name__ == "__main__":
    run_tests()
236 changes: 236 additions & 0 deletions torch/_inductor/fx_passes/cat_linear_fused.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
"""Pre-grad FX pass that rewrites ``F.linear(cat([...], dim=-1), W, b)``
into a reduce-sum of per-piece ``F.linear`` calls on contiguous weight
slices, eliminating the ``cat`` materialisation in both forward and
backward.

Off by default; opt in by setting
``torch._inductor.config.pre_grad_custom_pass = cat_linear_fused_pre_grad_pass``
(or composing it with other custom passes).
"""
from __future__ import annotations

import logging
import operator

import torch
import torch.fx as fx
from torch._dynamo.utils import counters


log = logging.getLogger(__name__)


# Matcher thresholds:
#   * a concatenated last-dim width above MAX_TOTAL_CAT_WIDTH is rejected,
#   * any piece narrower than MIN_PIECE_WIDTH is rejected,
#   * the cat must have between 2 and MAX_PARTS inputs.
MAX_TOTAL_CAT_WIDTH = 384
MIN_PIECE_WIDTH = 8
MAX_PARTS = 3

# Parts produced by a `mul` belong to gated patterns that another fusion
# (block_cat_fused) is responsible for; they are skipped here to avoid overlap.
_REJECT_PARENT_TARGETS = {operator.mul, torch.mul}

# Callables recognised as "linear". Dynamo sometimes inlines F.linear into
# torch._C._nn.linear on hot paths, so both spellings are accepted when present.
_LINEAR_CANDIDATES = (
    torch.nn.functional.linear,
    getattr(torch._C._nn, "linear", None),
)
_LINEAR_TARGETS = [fn for fn in _LINEAR_CANDIDATES if fn is not None]


def _val_of(node):
    """Return ``node.meta["val"]``, else ``node.meta["example_value"]``, else None.

    Objects without a ``meta`` attribute yield None.
    """
    if not hasattr(node, "meta"):
        return None
    meta = node.meta
    primary = meta.get("val")
    # The fallback is selected with an explicit `is None` test rather than
    # `or`: meta["val"] may be a real Tensor (post-grad graphs, or test graphs
    # that fill it in directly), and truth-testing a multi-element tensor
    # raises.
    return meta.get("example_value") if primary is None else primary


def _shape_of(node):
    """Return the node's recorded shape as a tuple of ints, or None.

    None is returned when the node has no recorded value, or when reading
    and converting its ``shape`` fails for any reason.
    """
    meta_val = _val_of(node)
    if meta_val is None:
        return None
    try:
        dims = [int(dim) for dim in meta_val.shape]
    except Exception:
        return None
    return tuple(dims)


def _is_cat(node):
    """Return True if ``node`` is a ``call_function`` targeting a torch cat alias."""
    # torch.cat plus its aliases; an alias missing from this torch build
    # resolves to None.
    cat_like = tuple(
        getattr(torch, alias, None) for alias in ("cat", "concatenate", "concat")
    )
    return node.op == "call_function" and node.target in cat_like


def _is_linear(node):
    """Return True if ``node`` is a ``call_function`` whose target is in ``_LINEAR_TARGETS``."""
    if node.op != "call_function":
        return False
    return node.target in _LINEAR_TARGETS


def _try_match_cat_linear(ln):
    """Try to match ``linear(cat([parts...], dim=-1), W, bias)`` rooted at ``ln``.

    Returns:
        ``(parts, weight, bias_or_None, K_offsets)`` on a match, where
        ``K_offsets`` holds the cumulative last-dim widths of the parts
        (starting at 0, ending at the total width); None otherwise.
    """
    linear_args = ln.args
    if len(linear_args) < 2:
        return None
    cat_node, weight = linear_args[0], linear_args[1]
    # Bias may be given positionally or as the `bias` keyword.
    bias = linear_args[2] if len(linear_args) >= 3 else ln.kwargs.get("bias")

    # The input must be a cat node whose first positional argument is a
    # list/tuple of 2..MAX_PARTS fx nodes.
    if not (isinstance(cat_node, fx.Node) and _is_cat(cat_node)):
        return None
    cat_args = cat_node.args
    if not cat_args or not isinstance(cat_args[0], (list, tuple)):
        return None
    parts = list(cat_args[0])
    if len(parts) < 2 or len(parts) > MAX_PARTS:
        return None
    if any(not isinstance(piece, fx.Node) for piece in parts):
        return None

    # Resolve the cat dimension (positional first, then keyword, default 0).
    if len(cat_args) >= 2:
        cat_dim = cat_args[1]
    else:
        cat_dim = cat_node.kwargs.get("dim", 0)
    if isinstance(cat_dim, str):
        try:
            cat_dim = int(cat_dim)
        except ValueError:
            return None

    cat_shape = _shape_of(cat_node)
    if cat_shape is None or len(cat_shape) < 2:
        return None
    rank = len(cat_shape)
    if not isinstance(cat_dim, int):
        return None
    # Only the last dimension qualifies, spelled either as -1 or as rank - 1.
    if cat_dim not in (-1, rank - 1):
        return None

    # Parts produced by a rejected parent op (mul) are left untouched.
    if any(
        piece.op == "call_function" and piece.target in _REJECT_PARENT_TARGETS
        for piece in parts
    ):
        return None

    total_width = cat_shape[-1]
    if total_width > MAX_TOTAL_CAT_WIDTH:
        return None

    # Every part needs a known shape of the cat's rank and a wide-enough last
    # dim; the running sum of widths gives the weight-column offsets.
    offsets = [0]
    for piece in parts:
        piece_shape = _shape_of(piece)
        if piece_shape is None or len(piece_shape) != rank:
            return None
        width = piece_shape[-1]
        if width < MIN_PIECE_WIDTH:
            return None
        offsets.append(offsets[-1] + width)
    if offsets[-1] != total_width:
        return None

    # The weight must be 2-D with as many columns as the concatenated width.
    weight_shape = _shape_of(weight)
    if (
        weight_shape is None
        or len(weight_shape) != 2
        or weight_shape[-1] != total_width
    ):
        return None

    return parts, weight, bias, offsets


def _slice_then_contiguous(graph, weight, start, stop, weight_val):
    """Emit ``aten.slice(weight, 1, start, stop)`` followed by a contiguous clone.

    The clone gives the per-piece F.linear a dense buffer (hipBLASLt/cuBLASLt
    fast path prefers contiguous inputs).

    Args:
        graph: graph to append the two nodes to (at its current insert point).
        weight: node producing the full weight.
        start, stop: column range to slice out along dim 1.
        weight_val: example value of ``weight`` used to derive ``meta["val"]``
            for the new nodes, or None to skip metadata.

    Returns:
        The clone node.
    """
    sliced = graph.call_function(
        torch.ops.aten.slice.Tensor, args=(weight, 1, start, stop)
    )
    # Metadata propagation is best-effort: a failure leaves "val" unset.
    if weight_val is not None:
        try:
            sliced.meta["val"] = torch.ops.aten.slice.Tensor(
                weight_val, 1, start, stop
            )
        except Exception:
            pass

    dense = graph.call_function(
        torch.ops.aten.clone.default,
        args=(sliced,),
        kwargs={"memory_format": torch.contiguous_format},
    )
    sliced_val = sliced.meta.get("val")
    if sliced_val is not None:
        try:
            dense.meta["val"] = sliced_val.contiguous()
        except Exception:
            pass
    return dense


def fuse_cat_linear_in_graph(graph: fx.Graph) -> int:
    """Rewrite every matching ``linear(cat(...))`` in ``graph`` in place.

    Each match is replaced by one ``F.linear`` per cat input, applied to the
    corresponding column slice of the weight, with the partial results summed
    by ``operator.add``. The original linear is erased, and so is the ``cat``
    node once nothing else uses it.

    Args:
        graph: the FX graph to mutate.

    Returns:
        The number of rewrites performed.
    """
    n = 0
    linear_target = torch.nn.functional.linear
    add_target = operator.add

    # Snapshot the candidates before mutating the graph, so the per-piece
    # linears inserted below are never revisited.
    candidates = [node for node in graph.nodes if _is_linear(node)]
    for ln in candidates:
        match = _try_match_cat_linear(ln)
        if match is None:
            continue
        parts, weight, bias, K_off = match
        cat_node = ln.args[0]  # the matcher guarantees this is the cat node
        weight_val = _val_of(weight)
        ln_val = _val_of(ln)

        with graph.inserting_before(ln):
            partial_outs = []
            # Consecutive offsets delimit each part's columns in the weight.
            for idx, (part, start, stop) in enumerate(
                zip(parts, K_off, K_off[1:])
            ):
                w_slc = _slice_then_contiguous(
                    graph, weight, start, stop, weight_val
                )
                # Bias only on the first F.linear; mathematically equivalent
                # and saves a standalone tensor-add at the end.
                out_i = graph.call_function(
                    linear_target,
                    args=(part, w_slc, bias if idx == 0 else None),
                )
                # Every partial result has the shape of the original output.
                if ln_val is not None:
                    out_i.meta["val"] = ln_val
                partial_outs.append(out_i)

            acc = partial_outs[0]
            for partial in partial_outs[1:]:
                acc = graph.call_function(add_target, args=(acc, partial))
                if ln_val is not None:
                    acc.meta["val"] = ln_val

        ln.replace_all_uses_with(acc)
        graph.erase_node(ln)
        # Drop the cat as soon as it is dead instead of relying on a later
        # DCE; it is kept when other nodes still consume it.
        if not cat_node.users:
            graph.erase_node(cat_node)

        counters["inductor"]["cat_linear_fused"] += 1
        log.debug(
            "cat_linear_fused: rewrote linear(cat n=%d, K_total=%d, has_bias=%s)",
            len(parts), K_off[-1], bias is not None,
        )
        n += 1
    return n


def cat_linear_fused_pre_grad_pass(graph: fx.Graph):
    """Pre-grad custom-pass entry point: fuse cat+linear patterns in ``graph``.

    Logs the replacement count at INFO level when at least one rewrite
    happened, and returns the same (mutated) graph object.
    """
    replaced = fuse_cat_linear_in_graph(graph)
    if replaced <= 0:
        return graph
    log.info("cat_linear_fused: %d replacement(s)", replaced)
    return graph


# Public surface of this module: the two pass entry points plus the matcher
# thresholds, which callers (e.g. tests) import to build boundary cases.
__all__ = [
    "MAX_PARTS",
    "MAX_TOTAL_CAT_WIDTH",
    "MIN_PIECE_WIDTH",
    "cat_linear_fused_pre_grad_pass",
    "fuse_cat_linear_in_graph",
]