LigerRMSNorm kernel produces non-exact outputs (max diff 0.125 at bfloat16) #105

@AmitMY

Description

Summary

When using kernelize on a Llama-based model, only the LigerRMSNorm kernel causes numerical differences. The SiLU kernel produces identical outputs.

Results

Config                    Time       Speedup   Max Logit Diff   Argmax Match
Baseline (no kernelize)   107.83ms   1.00x     0                —
Only RMSNorm              97.69ms    1.10x     0.125            ✓
Only SiLU                 104.64ms   1.03x     0                ✓
Both                      94.40ms    1.14x     0.125            ✓

Key Findings

  1. LigerRMSNorm (from kernels-community/liger_kernels):
    • 10% speedup
    • ⚠️ Max logit diff: 0.125
    • ✓ Argmax still matches (generation unaffected)
  2. SiLU (from kernels-community/activation):
    • 3% speedup
    • Exact match (0 diff)

Environment

  • Model: sign/utf8-lm-tiny (Llama-based, ~70M params)
  • dtype: torch.bfloat16
  • GPU: NVIDIA GB10 (CUDA 12.1)

Minimal Reproduction

from transformers import AutoModelForCausalLM
import torch
from utf8_tokenizer import UTF8Tokenizer
from kernels import Mode
from kernels.layer.layer import kernelize_layer
from kernels.layer.device import Device
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.activations import SiLUActivation

model_id = "sign/utf8-lm-tiny"
device = "cuda"
dtype = torch.bfloat16

tokenizer = UTF8Tokenizer()
prompt = "Hello world! " * 9  # ~118 tokens

inputs = tokenizer([prompt], return_tensors="pt", padding=True, add_special_tokens=True)
inputs["input_ids"] = inputs["input_ids"].to(torch.long)[:, :-1].to(device)
inputs["attention_mask"] = inputs["attention_mask"][:, :-1].to(device)

# Baseline
model_base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device).eval()
with torch.no_grad():
    logits_base = model_base(**inputs).logits

# Only RMSNorm kernelized
model_rms = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
for _, module in model_rms.named_modules():
    if isinstance(module, LlamaRMSNorm):
        kernelize_layer(module, mode=Mode.INFERENCE, device_type=Device(type="cuda"), use_fallback=True)
model_rms.eval()
with torch.no_grad():
    logits_rms = model_rms(**inputs).logits

# Only SiLU kernelized
model_silu = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
for _, module in model_silu.named_modules():
    if isinstance(module, SiLUActivation):
        kernelize_layer(module, mode=Mode.INFERENCE, device_type=Device(type="cuda"), use_fallback=True)
model_silu.eval()
with torch.no_grad():
    logits_silu = model_silu(**inputs).logits

print(f"RMSNorm diff: {(logits_base - logits_rms).abs().max().item()}")  # 0.125
print(f"SiLU diff: {(logits_base - logits_silu).abs().max().item()}")    # 0.0

Analysis

The LigerRMSNorm kernel uses a different numerical implementation than the original LlamaRMSNorm:

  • It likely fuses the normalization into a single kernel with a different reduction order or intermediate precision.
  • At bfloat16 precision, the resulting rounding differences are small per layer but accumulate through 9 norm layers.

Note that 0.125 is exactly one bfloat16 ULP for values in [16, 32), so the observed maximum diff is consistent with a single last-bit rounding difference, assuming the affected logits fall in that range.

The differences are small enough that argmax predictions are unaffected, so generation results remain identical.
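A minimal sketch of the effect (not the Liger kernel's actual implementation; the shapes and eps below are illustrative): two mathematically identical RMSNorm formulations, one reducing in float32 roughly as LlamaRMSNorm does and one reducing in bfloat16, already disagree on random data:

import torch

torch.manual_seed(0)
x = torch.randn(4, 118, 512, dtype=torch.bfloat16)
weight = torch.randn(512, dtype=torch.bfloat16)
eps = 1e-6

# Path A (mirrors LlamaRMSNorm): upcast to float32 for the reduction, then cast back
x32 = x.float()
a = (x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)).to(torch.bfloat16) * weight

# Path B: keep the reduction in bfloat16; same formula, different rounding sequence
b = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)) * weight

print((a.float() - b.float()).abs().max())  # nonzero: rounding differs, the math does not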

Recommendation

  • SiLU kernel: bit-exact in this test, safe to use for a ~3% speedup.
  • RMSNorm kernel: introduces small numerical differences; avoid it if exact reproducibility is required. One option is to kernelize only the activation layers, as in the sketch below.
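Reusing the per-module pattern from the reproduction above, only the SiLU modules are kernelized while every LlamaRMSNorm stays on the reference implementation (same imports, model_id, and helper as the script above):

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device).eval()
for _, module in model.named_modules():
    if isinstance(module, SiLUActivation):  # deliberately skip LlamaRMSNorm
        kernelize_layer(module, mode=Mode.INFERENCE, device_type=Device(type="cuda"), use_fallback=True)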
