Skip to content

Silent incorrectness in fused_moe with cloned FP4 weights/scales #2194

@msaroufim

Description

@msaroufim

Bug description

fused_moe produces numerically incorrect, non-deterministic results for identical FP4 weights when the weights/scales are cloned. I suspect this is a bug in how the kernel handles tensor storage/layout rather than tensor values.

This came up on gpumode.com.

Dependencies

https://github.com/gpu-mode/kernelbot/blob/main/docker/amd-docker.Dockerfile

In particular, this uses aiter pinned to the following recent commit:

RUN git clone --recursive https://github.com/ROCm/aiter.git \
    && cd aiter \
    && git checkout f3be04a12a0cfd6b5e2c7a94edc774f1bc24460d \

Repro

  # Repro script: fused_moe returns materially different outputs for
  # byte-identical FP4 weights/scales depending only on whether the
  # tensors are the originals or .clone()'d copies (same values, fresh
  # storage). Requires a ROCm GPU and the aiter package.
  import torch, math
  import aiter
  from aiter import ActivationType, QuantType, dtypes
  from aiter.fused_moe import fused_moe
  from aiter.utility import fp4_utils
  from aiter.ops.shuffle import shuffle_weight

  # Problem sizes: E experts, hidden/expert dims, top-k routing, batch size.
  E, d_hidden, d_expert, top_k, bs = 257, 4096, 1024, 9, 8
  d_hidden_pad = d_expert_pad = 1024  # already 256-aligned

  gen = torch.Generator(device="cuda"); gen.manual_seed(42)
  hidden = torch.randn((bs, d_hidden), device="cuda", dtype=torch.bfloat16, generator=gen)
  # Pass the seeded generator here too so the repro itself is deterministic
  # run-to-run (the original draw omitted generator=gen for these two tensors,
  # which made the routing vary between runs even with the seed above).
  topk_ids = torch.randint(0, E, (bs, top_k), device="cuda", dtype=torch.int32, generator=gen)
  topk_weights = torch.randn((bs, top_k), device="cuda", dtype=torch.float32, generator=gen).softmax(dim=-1)

  # Random bf16 gate-up and down projection weights for every expert.
  gu_bf16 = torch.randn((E, 2*d_expert_pad, d_hidden_pad), device="cuda", dtype=torch.bfloat16, generator=gen)
  dn_bf16 = torch.randn((E, d_hidden_pad, d_expert_pad), device="cuda", dtype=torch.bfloat16, generator=gen)

  # Quantize to packed FP4 (two values per byte) with per-1x32 block scales.
  torch_quant = aiter.get_torch_quant(QuantType.per_1x32)
  gu_w, gu_s = torch_quant(gu_bf16, quant_dtype=dtypes.fp4x2)
  dn_w, dn_s = torch_quant(dn_bf16, quant_dtype=dtypes.fp4x2)
  gu_w = gu_w.view(E, 2*d_expert_pad, d_hidden_pad//2)
  dn_w = dn_w.view(E, d_hidden_pad, d_expert_pad//2)

  # Shuffle weights/scales into the kernel's expected 16x16 tiled layout.
  gu_shuf = shuffle_weight(gu_w, layout=(16,16))
  dn_shuf = shuffle_weight(dn_w, layout=(16,16))
  gu_s_shuf = fp4_utils.e8m0_shuffle(gu_s)
  dn_s_shuf = fp4_utils.e8m0_shuffle(dn_s)

  kwargs = dict(expert_mask=None, activation=ActivationType.Silu,
      quant_type=QuantType.per_1x32, doweight_stage1=False,
      w1_scale=gu_s_shuf, w2_scale=dn_s_shuf,
      a1_scale=None, a2_scale=None, hidden_pad=0, intermediate_pad=0)

  # Run with original tensors
  torch.cuda.synchronize()
  out_orig = fused_moe(hidden, gu_shuf, dn_shuf, topk_weights, topk_ids, **kwargs)
  torch.cuda.synchronize()

  # Run with cloned weights (same values, different memory). clone() preserves
  # values, dtype, and shape but allocates new contiguous storage.
  torch.cuda.synchronize()
  out_clone = fused_moe(hidden, gu_shuf.clone(), dn_shuf.clone(), topk_weights, topk_ids,
      **{**kwargs, "w1_scale": gu_s_shuf.clone(), "w2_scale": dn_s_shuf.clone()})
  torch.cuda.synchronize()

  # Outputs should agree to quantization noise; they do not.
  diff = (out_orig - out_clone).abs()
  print(f"Max diff: {diff.max().item()}")  # Expected: ~0.01. Actual: ~4.3
  print(f"Mismatched elements: {(diff > 0.05).sum().item()}/{diff.numel()}")
  # BUG: cloned weights produce completely different results

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions