Describe the bug
Two mathematically equivalent rank-2 dot_general forms have very different Neuron profiles.
- The compiler-friendly form contracts RHS dimension 0:
lhs[t, k] x rhs[k, n] -> out[t, n]
- The checkpoint-shaped form contracts RHS dimension 1:
lhs[t, k] x rhs[n, k] -> out[t, n] (e.g., the weight layout stored in a Llama model checkpoint)
Both forms compute the same values when rhs[k, n] is the transpose of rhs[n, k].
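The equivalence can be checked on the host. This is a minimal NumPy sketch that is not part of the original reproducer. The `np.einsum` subscripts `"tk,kn->tn"` and `"tk,nk->tn"` mirror the two `dot_general` dimension numbers. The sizes and random inputs are illustrative only.

```python
import numpy as np

# Small illustrative sizes; the issue itself uses t=1, k=n=4096.
t, k, n = 2, 8, 5
rng = np.random.default_rng(0)
lhs = rng.standard_normal((t, k)).astype(np.float32)
rhs_checkpoint = rng.standard_normal((n, k)).astype(np.float32)  # [n, k] layout
rhs_canonical = rhs_checkpoint.T.copy()                          # [k, n] layout

# Contract lhs dim 1 with rhs dim 0, like dimension numbers (((1,), (0,)), ((), ())).
out_canonical = np.einsum("tk,kn->tn", lhs, rhs_canonical)
# Contract lhs dim 1 with rhs dim 1, like dimension numbers (((1,), (1,)), ((), ())).
out_checkpoint = np.einsum("tk,nk->tn", lhs, rhs_checkpoint)

print(out_canonical.shape, np.allclose(out_canonical, out_checkpoint))
```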
On Inf2, the profile of the RHS-dim-1 form shows large TensorE transpose FLOPs, while the total for the RHS-dim-0 form stays close to the useful model FLOPs.
The Neuron compiler should either choose a layout that avoids the extra TensorE transpose work for the RHS-dim-1 dot_general form, or expose a supported public way to mark a static RHS operand as transposable without inserting an explicit graph transpose.
With shapes lhs[t, k] = [1, 4096] and rhs[k, n] = [4096, 4096] (canonical) or rhs[n, k] = [4096, 4096] (checkpoint-shaped), the canonical RHS-dim-0 form previously profiled at around 34.60M adjusted hardware FLOPs, of which about 1.05M were transpose FLOPs. The RHS-dim-1 form profiled at around 4.33G adjusted hardware FLOPs, of which about 4.30G were transpose FLOPs, even though the model FLOPs are only about 33.55M.
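The reported figures are consistent with simple arithmetic. This sketch rests on three assumptions. Model FLOPs are counted as 2·t·k·n. The adjusted totals are model FLOPs plus transpose FLOPs. Transposing the [4096, 4096] RHS costs one 128×128×128 identity matmul per 128×128 tile. The last assumption is about how TensorE transposes are counted and is not confirmed by the profile.

```python
# Shapes from the issue: lhs[t, k] x rhs -> out[t, n].
t, k, n = 1, 4096, 4096

# Useful work: one multiply and one add per multiply-accumulate.
model_flops = 2 * t * k * n

# Assumption: each 128x128 RHS tile is transposed on TensorE as a
# 128x128x128 matmul against an identity matrix (2 * 128**3 FLOPs per tile).
tile = 128
rhs_tiles = (k // tile) * (n // tile)
rhs_transpose_flops = rhs_tiles * 2 * tile**3

# Transpose FLOPs reported for the canonical form (taken from the profile, not derived).
canonical_transpose_flops = 1.05e6

canonical_adjusted = model_flops + canonical_transpose_flops
checkpoint_adjusted = model_flops + rhs_transpose_flops

print(f"model FLOPs:              {model_flops / 1e6:.2f}M")
print(f"canonical adjusted:       {canonical_adjusted / 1e6:.2f}M")
print(f"RHS transpose (estimate): {rhs_transpose_flops / 1e9:.3f}G")
print(f"checkpoint adjusted:      {checkpoint_adjusted / 1e9:.2f}G")
```

This prints 33.55M, 34.60M, 4.295G and 4.33G, which match the quoted numbers. Under the same assumption, the ratio of transpose work to model work is tile / t, or 128× at t = 1. That would make single-token decode the worst case for the RHS-dim-1 form.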
Model Name
N/A. This is a self-contained JAX reproducer.
Describe the workload type
LLM inference.
Instance Type
inf2.8xlarge
Release version
python=3.12.12
numpy=2.4.6
jax=0.7.0
jaxlib=0.7.0
jax-neuronx=0.7.0.1.0.8181+1e892be0
libneuronxla=3.0.2891.0
neuronx-cc=2.25.3371.0
aws-neuronx-runtime-lib=2.32.31.0
aws-neuronx-collectives=2.32.28.0
Reproduction Steps
NEURON_RT_VISIBLE_CORES=0 python3 main.py --jax-platforms=neuron,cpu
main.py:
import argparse
import importlib.metadata
import os
from pathlib import Path
import sys

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--jax-platforms", default=None)
    parser.add_argument("--case", choices=("both", "canonical", "checkpoint"), default="both")
    parser.add_argument("--tokens", type=int, default=1)
    parser.add_argument("--hidden-size", type=int, default=4096)
    parser.add_argument("--output-size", type=int, default=4096)
    parser.add_argument("--dtype", choices=("bf16", "f32"), default="bf16")
    args = parser.parse_args()

    # JAX_PLATFORMS must be set before jax is imported.
    if args.jax_platforms:
        os.environ["JAX_PLATFORMS"] = args.jax_platforms

    import jax
    import jax.numpy as jnp
    from jax import lax
    import numpy as np

    dtype = jnp.bfloat16 if args.dtype == "bf16" else jnp.float32
    reference_dtype = np.dtype("bfloat16" if args.dtype == "bf16" else "float32")

    print(f"python={sys.version.split()[0]}")
    print(f"numpy={np.__version__}")
    print(f"jax={jax.__version__}")
    print(f"jaxlib={importlib.metadata.version('jaxlib')}")
    print(f"jax-neuronx={importlib.metadata.version('jax-neuronx')}")
    print(f"libneuronxla={importlib.metadata.version('libneuronxla')}")
    print(f"neuronx-cc={importlib.metadata.version('neuronx-cc')}")
    print(f"jax_default_backend={jax.default_backend()}")
    print(f"jax_devices={jax.devices()}")

    tokens = args.tokens
    hidden = args.hidden_size
    output = args.output_size

    # Checkpoint-shaped weights are [n, k]; the canonical copy is their [k, n] transpose.
    lhs_np = np.linspace(-1.0, 1.0, tokens * hidden, dtype=np.float32).reshape(tokens, hidden)
    rhs_checkpoint_np = np.linspace(-0.75, 0.75, output * hidden, dtype=np.float32).reshape(output, hidden)
    rhs_canonical_np = rhs_checkpoint_np.T.copy()
    lhs = jnp.asarray(lhs_np, dtype=dtype)
    rhs_canonical = jnp.asarray(rhs_canonical_np, dtype=dtype)
    rhs_checkpoint = jnp.asarray(rhs_checkpoint_np, dtype=dtype)

    # Host reference: round the inputs to the device dtype, then compute in float32.
    lhs_ref = lhs_np.astype(reference_dtype).astype(np.float32)
    rhs_canonical_ref = rhs_canonical_np.astype(reference_dtype).astype(np.float32)

    # lhs[t, k] x rhs[k, n]: contract lhs dim 1 with rhs dim 0.
    def canonical_dot(x, w):
        return lax.dot_general(x, w, (((1,), (0,)), ((), ())))

    # lhs[t, k] x rhs[n, k]: contract lhs dim 1 with rhs dim 1.
    def checkpoint_dot(x, w):
        return lax.dot_general(x, w, (((1,), (1,)), ((), ())))

    def run_case(name, fn, x, w, expected):
        # Lower, print the StableHLO, compile, run, and compare with the host reference.
        lowered = jax.jit(fn).lower(x, w)
        print(f"\n=== {name} StableHLO ===")
        print(lowered.compiler_ir(dialect="stablehlo"))
        compiled = lowered.compile()
        out = np.asarray(jax.device_get(compiled(x, w)), dtype=np.float32)
        abs_diff = np.abs(out - expected)
        max_abs = float(abs_diff.max()) if abs_diff.size else 0.0
        close = np.isclose(out, expected, rtol=0.05, atol=0.2)
        close_fraction = float(close.mean()) if close.size else 1.0
        print(f"{name}_shape={out.shape}")
        print(f"{name}_sum={float(out.sum()):.6f}")
        print(f"{name}_max_abs={max_abs:.6f}")
        print(f"{name}_close_fraction={close_fraction:.6f}")
        if close_fraction < 1.0:
            raise AssertionError(f"{name} output mismatch: close_fraction={close_fraction:.6f}")

    expected = lhs_ref @ rhs_canonical_ref
    if args.case in ("both", "canonical"):
        run_case("canonical_rhs_contracting_axis_0", canonical_dot, lhs, rhs_canonical, expected)
    if args.case in ("both", "checkpoint"):
        run_case("checkpoint_rhs_contracting_axis_1", checkpoint_dot, lhs, rhs_checkpoint, expected)
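`main.py` builds `rhs_canonical` by transposing the checkpoint-shaped array once on the host (`rhs_checkpoint_np.T.copy()`). The same idea applied at weight-load time may work as a user-side workaround until the compiler handles the RHS-dim-1 form. The traced graph would then contain only the RHS-dim-0 form, which the profile numbers in this report show with little transpose work. This is a minimal sketch that assumes static weights are available as NumPy arrays in [n, k] layout. `to_canonical_layout` and the weight names are hypothetical.

```python
import numpy as np

def to_canonical_layout(weights):
    """Hypothetical helper: convert [n, k] (out, in) weights to [k, n] (in, out).

    np.ascontiguousarray materializes the transpose as a new C-contiguous array,
    so the [k, n] layout is stored explicitly rather than as a strided view.
    """
    return {name: np.ascontiguousarray(w.T) for name, w in weights.items()}

# Hypothetical Llama-style projection names with small illustrative shapes.
checkpoint = {
    "q_proj": np.arange(12, dtype=np.float32).reshape(3, 4),  # [n=3, k=4]
    "o_proj": np.arange(8, dtype=np.float32).reshape(2, 4),   # [n=2, k=4]
}
canonical = to_canonical_layout(checkpoint)

x = np.ones((1, 4), dtype=np.float32)  # lhs[t=1, k=4]
# Canonical form x[t, k] @ w[k, n] equals the checkpoint form contracted on w's dim 1.
y_canonical = x @ canonical["q_proj"]
y_checkpoint = np.einsum("tk,nk->tn", x, checkpoint["q_proj"])
print(y_canonical.shape, np.array_equal(y_canonical, y_checkpoint))
```

The cost is one host-side copy per weight at load time. It does not replace a compiler-side fix for checkpoint-shaped weights.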
Regression Issue
Possible Solution
No response
Logs/Context/Additional Information
No response