Rank-2 dot with RHS contracting dimension 1 lowers to large transpose work #1336

@hugomano

Describe the bug

Two mathematically equivalent rank-2 dot_general forms have very different Neuron profiles.

  • The compiler-friendly form contracts RHS dimension 0: lhs[t, k] x rhs[k, n] -> out[t, n]
  • The checkpoint-shaped form contracts RHS dimension 1: lhs[t, k] x rhs[n, k] -> out[t, n] (e.g., the weight layout of a Llama checkpoint)

Both forms compute the same values when rhs[k, n] is the transpose of rhs[n, k].
On Inf2, the RHS-dim-1 form profiles with large TensorE transpose FLOPs, while the RHS-dim-0 form stays close to the useful model FLOPs.
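
The equivalence of the two forms can be checked independently of Neuron. The following is a minimal sketch, assuming a CPU backend and small illustrative sizes (not the 4096-wide shapes from this report); it only uses `lax.dot_general` with the same dimension numbers as the reproducer below.

```python
import os

# Assumption: run the check on CPU so it does not depend on a Neuron device.
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax.numpy as jnp
import numpy as np
from jax import lax

# Small illustrative sizes, chosen only to keep the check fast.
t, k, n = 2, 8, 4
rng = np.random.default_rng(0)
lhs = jnp.asarray(rng.standard_normal((t, k)), dtype=jnp.float32)
rhs_nk = jnp.asarray(rng.standard_normal((n, k)), dtype=jnp.float32)  # checkpoint-shaped
rhs_kn = rhs_nk.T                                                      # canonical layout

# RHS contracting dimension 0: lhs[t, k] x rhs[k, n] -> out[t, n]
out_dim0 = lax.dot_general(lhs, rhs_kn, (((1,), (0,)), ((), ())))
# RHS contracting dimension 1: lhs[t, k] x rhs[n, k] -> out[t, n]
out_dim1 = lax.dot_general(lhs, rhs_nk, (((1,), (1,)), ((), ())))

shapes_match = out_dim0.shape == out_dim1.shape == (t, n)
values_match = bool(jnp.allclose(out_dim0, out_dim1, rtol=1e-5, atol=1e-5))
print(f"shapes_match={shapes_match} values_match={values_match}")
```

Since both calls describe the same contraction, any difference in the Neuron profile comes from how the operand layout is lowered, not from the math.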

The Neuron compiler should either choose a layout that avoids the extra TensorE transpose work for the RHS-dim-1 dot_general form, or expose a supported public way to mark a static RHS operand as transposable without inserting an explicit graph transpose.
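
For context, the only way I know to avoid a graph transpose today is to re-lay out the weight on the host before it reaches the compiled function, which is what the reproducer does with `rhs_checkpoint_np.T.copy()`. A minimal NumPy sketch of that idea follows; `to_canonical_layout` is a hypothetical helper name, and whether this is practical for a full checkpoint is an assumption, not something measured here.

```python
import numpy as np


def to_canonical_layout(weight_nk: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert a checkpoint-shaped [n, k] weight into a
    contiguous [k, n] array on the host, outside any jitted graph."""
    return np.ascontiguousarray(weight_nk.T)


# Tiny illustrative weight: n=3 output features, k=4 input features.
w_nk = np.arange(12, dtype=np.float32).reshape(3, 4)
w_kn = to_canonical_layout(w_nk)

# The re-laid-out weight gives the same product as the checkpoint form.
x = np.ones((1, 4), dtype=np.float32)
same = bool(np.allclose(x @ w_kn, x @ w_nk.T))
print(w_kn.shape, w_kn.flags["C_CONTIGUOUS"], same)
```

This costs a one-time host copy per weight and means the in-memory layout no longer matches the checkpoint, which is why a compiler-side or public-API solution would be preferable.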

With shapes lhs[1,4096] and rhs[4096,4096] (the RHS is 4096 x 4096 in both layouts), the canonical RHS-dim-0 form previously profiled around 34.60M adjusted hardware FLOPs with about 1.05M transpose FLOPs. The RHS-dim-1 form profiled around 4.33G adjusted hardware FLOPs with about 4.30G transpose FLOPs, even though the model FLOPs are only about 33.55M.
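
To put those numbers side by side, here is a small arithmetic sketch. It assumes the usual convention of counting one multiply-add as 2 FLOPs, which reproduces the 33.55M model-FLOP figure; the profiled values are the rounded numbers quoted above, not fresh measurements.

```python
# Shapes from the report: lhs[t, k] x rhs -> out[t, n]
tokens, hidden, output = 1, 4096, 4096

# Assumption: one multiply-add counts as 2 FLOPs.
model_flops = 2 * tokens * hidden * output

# Rounded profiler figures quoted in this issue.
canonical_hw_flops = 34.60e6   # RHS-dim-0 form, adjusted hardware FLOPs
checkpoint_hw_flops = 4.33e9   # RHS-dim-1 form, adjusted hardware FLOPs

canonical_ratio = canonical_hw_flops / model_flops
checkpoint_ratio = checkpoint_hw_flops / model_flops

print(f"model_flops={model_flops}")
print(f"canonical hardware/model ratio  = {canonical_ratio:.2f}")
print(f"checkpoint hardware/model ratio = {checkpoint_ratio:.2f}")
```

Under these assumptions the canonical form spends about 3% more than the useful work, while the checkpoint-shaped form spends roughly 129 times the useful work, almost all of it in transposes.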

Model Name

N/A. This is a self-contained JAX reproducer.

Describe the workload type

LLM inference.

Instance Type

inf2.8xlarge

Release version

python=3.12.12
numpy=2.4.6
jax=0.7.0
jaxlib=0.7.0
jax-neuronx=0.7.0.1.0.8181+1e892be0
libneuronxla=3.0.2891.0
neuronx-cc=2.25.3371.0
aws-neuronx-runtime-lib=2.32.31.0
aws-neuronx-collectives=2.32.28.0

Reproduction Steps

NEURON_RT_VISIBLE_CORES=0 python3 main.py --jax-platforms=neuron,cpu

main.py:

import argparse
import importlib.metadata
import os
from pathlib import Path
import sys

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--jax-platforms", default=None)
    parser.add_argument("--case", choices=("both", "canonical", "checkpoint"), default="both")
    parser.add_argument("--tokens", type=int, default=1)
    parser.add_argument("--hidden-size", type=int, default=4096)
    parser.add_argument("--output-size", type=int, default=4096)
    parser.add_argument("--dtype", choices=("bf16", "f32"), default="bf16")
    args = parser.parse_args()
    if args.jax_platforms:
        os.environ["JAX_PLATFORMS"] = args.jax_platforms

    import jax
    import jax.numpy as jnp
    from jax import lax
    import numpy as np

    dtype = jnp.bfloat16 if args.dtype == "bf16" else jnp.float32
    reference_dtype = np.dtype("bfloat16" if args.dtype == "bf16" else "float32")

    print(f"python={sys.version.split()[0]}")
    print(f"numpy={np.__version__}")
    print(f"jax={jax.__version__}")
    print(f"jaxlib={importlib.metadata.version('jaxlib')}")
    print(f"jax-neuronx={importlib.metadata.version('jax-neuronx')}")
    print(f"libneuronxla={importlib.metadata.version('libneuronxla')}")
    print(f"neuronx-cc={importlib.metadata.version('neuronx-cc')}")
    print(f"jax_default_backend={jax.default_backend()}")
    print(f"jax_devices={jax.devices()}")

    tokens = args.tokens
    hidden = args.hidden_size
    output = args.output_size

    lhs_np = np.linspace(-1.0, 1.0, tokens * hidden, dtype=np.float32).reshape(tokens, hidden)
    rhs_checkpoint_np = np.linspace(-0.75, 0.75, output * hidden, dtype=np.float32).reshape(output, hidden)
    rhs_canonical_np = rhs_checkpoint_np.T.copy()

    lhs = jnp.asarray(lhs_np, dtype=dtype)
    rhs_canonical = jnp.asarray(rhs_canonical_np, dtype=dtype)
    rhs_checkpoint = jnp.asarray(rhs_checkpoint_np, dtype=dtype)

    lhs_ref = lhs_np.astype(reference_dtype).astype(np.float32)
    rhs_canonical_ref = rhs_canonical_np.astype(reference_dtype).astype(np.float32)

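    # dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).
    # Canonical form: contract lhs axis 1 with rhs axis 0, i.e. x[t, k] x w[k, n].
    def canonical_dot(x, w):
        return lax.dot_general(x, w, (((1,), (0,)), ((), ())))

    # Checkpoint form: contract lhs axis 1 with rhs axis 1, i.e. x[t, k] x w[n, k].
    def checkpoint_dot(x, w):
        return lax.dot_general(x, w, (((1,), (1,)), ((), ())))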

    def run_case(name, fn, x, w, expected):
        lowered = jax.jit(fn).lower(x, w)
        print(f"\n=== {name} StableHLO ===")
        print(lowered.compiler_ir(dialect="stablehlo"))

        compiled = lowered.compile()
        out = np.asarray(jax.device_get(compiled(x, w)), dtype=np.float32)

        abs_diff = np.abs(out - expected)
        max_abs = float(abs_diff.max()) if abs_diff.size else 0.0
        close = np.isclose(out, expected, rtol=0.05, atol=0.2)
        close_fraction = float(close.mean()) if close.size else 1.0

        print(f"{name}_shape={out.shape}")
        print(f"{name}_sum={float(out.sum()):.6f}")
        print(f"{name}_max_abs={max_abs:.6f}")
        print(f"{name}_close_fraction={close_fraction:.6f}")

        if close_fraction < 1.0:
            raise AssertionError(f"{name} output mismatch: close_fraction={close_fraction:.6f}")

    expected = lhs_ref @ rhs_canonical_ref

    if args.case in ("both", "canonical"):
        run_case("canonical_rhs_contracting_axis_0", canonical_dot, lhs, rhs_canonical, expected)
    if args.case in ("both", "checkpoint"):
        run_case("checkpoint_rhs_contracting_axis_1", checkpoint_dot, lhs, rhs_checkpoint, expected)

Regression Issue

  • Select this option if this issue appears to be a regression.

Possible Solution

No response

Logs/Context/Additional Information

No response

Labels

  • Inf2
  • bug (Something isn't working)
