1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -60,6 +60,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py"
# Disable autotuning to make unit tests faster. Also disable the TF32 path to fully align with the PyTorch reference implementation's precision
NVTE_DISABLE_TRITON_AUTOTUNING=1 NVIDIA_TF32_OVERRIDE=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mhc.xml $TE_PATH/tests/pytorch/test_mhc.py || test_fail "test_mhc.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_DISABLE_TRITON_AUTOTUNING=1 NVIDIA_TF32_OVERRIDE=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mhc_deterministic.xml $TE_PATH/tests/pytorch/test_mhc.py || test_fail "test_mhc.py (deterministic)"

if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
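
A note on the new deterministic invocation above: the second test_mhc.py run sets NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, which is presumably what the is_deterministic_enforced() helper imported in tests/pytorch/test_mhc.py (below) keys off. A hypothetical sketch of such a guard, for illustration only; the real helper ships with the library:

import os

def is_deterministic_enforced_sketch() -> bool:
    # Hypothetical stand-in for the real is_deterministic_enforced():
    # treat determinism as enforced when the env var set by test.sh is "0".
    return os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0"
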
195 changes: 152 additions & 43 deletions tests/pytorch/test_mhc.py
@@ -14,13 +14,15 @@
mhc_fused_aggregate,
mhc_fused_expand_combine,
mhc_fused_projection,
mhc_generate_mix_and_aggregate,
is_deterministic_enforced,
)

# Disable TF32 for matmul to ensure consistency between the fused and reference implementations
torch.backends.cuda.matmul.allow_tf32 = False


def mhc_projection_ref(x, phi):
def mhc_projection_ref(x, phi, norm_weight):
"""
Reference operator for mHC's projection building operation.

@@ -29,19 +31,20 @@ def mhc_projection_ref(x, phi):
- phi_pre: (n, nC)
- phi_post: (n, nC)
- phi_res: (n^2, nC)
norm_weight: (nC,) or None; if not None, applied as an element-wise scale to phi before the projection
n: number of Hyper Connection streams
C: hidden dimension per stream
"""
x_dtype = x.dtype
x = x.to(torch.float32)
phi = phi.to(torch.float32)

Hs = x @ phi.T # (M, 2n + n^2)

x_fp32 = x.to(torch.float32) # Use fp32 for better numerical stability in variance calculation
x_fp32 = x.to(torch.float32)
ms = (x_fp32 * x_fp32).mean(dim=1)

return Hs.to(x_dtype), ms
phi_fp32 = phi.to(torch.float32)
if norm_weight is not None:
phi_fp32 = phi_fp32 * norm_weight.to(torch.float32)[None, :]
Hs = x_fp32 @ phi_fp32.T # (M, 2n + n^2)

return Hs, ms

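A side note on why the reference folds norm_weight into phi rather than scaling x: for an element-wise scale w, (x * w) @ phi.T equals x @ (phi * w).T exactly in linear algebra, so the two formulations agree up to floating-point rounding. A minimal standalone check (illustrative only, not part of the test suite):

import torch

x = torch.randn(5, 8, dtype=torch.float64)
phi = torch.randn(3, 8, dtype=torch.float64)
w = torch.randn(8, dtype=torch.float64)
# Scaling the input columns is equivalent to scaling the projection columns
torch.testing.assert_close((x * w) @ phi.T, x @ (phi * w[None, :]).T)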

def mhc_scale_ref(H, alpha, beta, ms, n):
@@ -139,9 +142,9 @@ def mhc_aggregate_ref(x, H_pre, n):
s, b, C, n = x.shape
H_pre = H_pre.view(s, b, n, 1)

out = (x @ H_pre).view(s, b, C)
out = (x.to(H_pre.dtype) @ H_pre).view(s, b, C)

return out
return out.to(x.dtype)

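The x.to(H_pre.dtype) upcast added above is functional, not cosmetic: torch.matmul rejects mixed-dtype operands, and the new parametrizations feed this reference a bf16 x with an fp32 H_pre. A minimal illustration (shapes are arbitrary, for demonstration only):

import torch

x = torch.randn(4, 3, dtype=torch.bfloat16)
h = torch.randn(3, 1, dtype=torch.float32)
# x @ h would raise a dtype-mismatch RuntimeError; upcast, multiply, cast back:
out = (x.to(h.dtype) @ h).to(x.dtype)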

def mhc_expand_combine_ref(f, bias, H_post, x, H_res, n):
@@ -267,25 +270,42 @@ def get_tols(dtype):


@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
def test_mhc_projection(cfg: MHCConfig, dtype):
@pytest.mark.parametrize(
"dtypes",
[
(torch.float32, torch.float32),
(torch.bfloat16, torch.bfloat16),
(torch.bfloat16, torch.float32),
],
ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"],
)
@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"])
def test_mhc_projection(cfg: MHCConfig, dtypes, has_norm_weight):
reset_rng_states()

s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n
nC = n * C
N = 2 * n + n * n

tols = get_tols(dtype)
x_dtype = dtypes[0]
phi_dtype = dtypes[1]
tols = get_tols(x_dtype)
use_tf32 = False

x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype)
phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda")

x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype)
phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda")
x_ref = x.detach().clone().requires_grad_(True)
phi_ref = phi.detach().clone().requires_grad_(True)

ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref)
fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(x, phi, use_tf32)
if has_norm_weight:
norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype)
norm_weight_ref = norm_weight.detach().clone().requires_grad_(True)
else:
norm_weight = None
norm_weight_ref = None

ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref)
fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(x, phi, norm_weight, use_tf32)
fused_out_Hs = fused_out_Hs_padded[:, :N]

torch.testing.assert_close(fused_out_Hs, ref_out_Hs, **tols)
@@ -295,10 +315,12 @@ def test_mhc_projection(cfg: MHCConfig, dtype):

torch.testing.assert_close(x.grad, x_ref.grad, **tols)
torch.testing.assert_close(phi.grad, phi_ref.grad, **tols)
if has_norm_weight:
torch.testing.assert_close(norm_weight.grad, norm_weight_ref.grad, **tols)


@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
def test_mhc_scale(cfg: MHCConfig, dtype):
reset_rng_states()

@@ -329,37 +351,55 @@ def test_mhc_scale(cfg: MHCConfig, dtype):
torch.cat([fused_out[i] for i in range(3)], dim=-1).sum().backward()

torch.testing.assert_close(H_padded.grad[:, :N], H_ref.grad, **tols)
torch.testing.assert_close(ms.grad, ms_ref.grad, **tols)
torch.testing.assert_close(alpha.grad, alpha_ref.grad, **tols)
torch.testing.assert_close(beta.grad, beta_ref.grad, **tols)
torch.testing.assert_close(ms.grad, ms_ref.grad, **tols)


@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
def test_mhc_combined(cfg: MHCConfig, dtype):
@pytest.mark.parametrize(
"dtypes",
[
(torch.float32, torch.float32),
(torch.bfloat16, torch.bfloat16),
(torch.bfloat16, torch.float32),
],
ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"],
)
@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"])
def test_mhc_rmsnorm(cfg: MHCConfig, dtypes, has_norm_weight):
# Verify that the fused kernel is equivalent to applying RMSNorm in the normal order
reset_rng_states()

s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n
N = 2 * n + n * n
nC = n * C

tols = get_tols(dtype)
x_dtype = dtypes[0]
phi_dtype = dtypes[1]
tols = get_tols(x_dtype)
use_tf32 = False

x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype)
phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda")

alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype)
beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype)
x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype)
phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda")
alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=phi_dtype)
beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=phi_dtype)

x_ref = x.detach().clone().requires_grad_(True)
phi_ref = phi.detach().clone().requires_grad_(True)

alpha_ref = alpha.detach().clone().requires_grad_(True)
beta_ref = beta.detach().clone().requires_grad_(True)

ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref)
fused_out_H_padded, fused_out_r = mhc_fused_projection(x, phi, use_tf32)
if has_norm_weight:
norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype)
norm_weight_ref = norm_weight.detach().clone().requires_grad_(True)
else:
norm_weight = None
norm_weight_ref = None

ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref)
fused_out_H_padded, fused_out_r = mhc_fused_projection(x, phi, norm_weight, use_tf32)

ref_H_pre, ref_H_post, ref_H_res = mhc_scale_ref(
ref_out_H[:, :N], alpha_ref, beta_ref, ref_out_r, n
@@ -368,17 +408,19 @@ def test_mhc_combined(cfg: MHCConfig, dtype):
fused_out_H_padded, alpha, beta, fused_out_r, n
)

def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref):
dtype = x_ref.dtype
x_ref = x_ref.to(torch.float32)
phi_ref = phi_ref.to(torch.float32)
alpha_ref = alpha_ref.to(torch.float32)
beta_ref = beta_ref.to(torch.float32)

def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref):
# Check if, after splitting RMSNorm into two steps in projection and scaling,
# the result is close to applying RMSNorm in the correct order
x_rmsnorm = F.rms_norm(x_ref, normalized_shape=(nC,))
H = x_rmsnorm @ phi_ref.T
# the result is close to applying RMSNorm in the correct order.
# Run RMSNorm in fp32 so the bf16 case has the same precision pattern as the
# kernel/ref (F.rms_norm on bf16 input would round x_rmsnorm back to bf16).
eps = torch.finfo(torch.float32).eps
norm_weight_fp32 = (
norm_weight_ref.to(torch.float32) if norm_weight_ref is not None else None
)
x_rmsnorm = F.rms_norm(
x_ref.to(torch.float32), normalized_shape=(nC,), weight=norm_weight_fp32, eps=eps
)
H = x_rmsnorm @ phi_ref.T.to(torch.float32)
H_pre = H[:, :n]
H_post = H[:, n : 2 * n]
H_res = H[:, 2 * n :]
@@ -391,21 +433,88 @@ def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref):
out_post = 2 * out_post.sigmoid()
out_res = out_res  # kept for symmetry; no activation on the residual branch

return out_pre.to(dtype), out_post.to(dtype), out_res.to(dtype)
return out_pre, out_post, out_res # Return in FP32 to match the kernel's behavior

combined_H_pre, combined_H_post, combined_H_res = mhc_combined(
x_ref, phi_ref, alpha_ref, beta_ref
x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref
)

torch.testing.assert_close(combined_H_pre, ref_H_pre, **tols)
torch.testing.assert_close(combined_H_post, ref_H_post, **tols)
torch.testing.assert_close(combined_H_res, ref_H_res, **tols)

torch.testing.assert_close(ref_H_pre, fused_H_pre, **tols)
torch.testing.assert_close(ref_H_post, fused_H_post, **tols)
torch.testing.assert_close(ref_H_res, fused_H_res, **tols)

torch.testing.assert_close(combined_H_pre, fused_H_pre, **tols)
torch.testing.assert_close(combined_H_post, fused_H_post, **tols)
torch.testing.assert_close(combined_H_res, fused_H_res, **tols)

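The fp32 RMSNorm note inside mhc_combined can be checked in isolation: F.rms_norm preserves its input dtype, so a bf16 input is rounded back to bf16 before the matmul, while upcasting first keeps full fp32 precision. A minimal standalone illustration (assumes PyTorch >= 2.4, which provides F.rms_norm):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, dtype=torch.bfloat16)
y_bf16 = F.rms_norm(x, normalized_shape=(8,))                    # stays bf16
y_fp32 = F.rms_norm(x.to(torch.float32), normalized_shape=(8,))  # full fp32
assert y_bf16.dtype == torch.bfloat16 and y_fp32.dtype == torch.float32
print((y_fp32 - y_bf16.to(torch.float32)).abs().max())  # the rounding gap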

@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"])
def test_mhc_fuse_grad_acc(cfg: MHCConfig, dtype):
# Skip bf16 tests since in the unfused path we accumulate 3 bf16 gradients, whereas in the fused path
# we accumulate 3 fp32 gradients and then cast to bf16 at the end, which causes the two paths to have different precision patterns

if not is_deterministic_enforced():
pytest.skip(
"This test needs to be tested under deterministic mode to avoid discrepancies from"
" running two non-deterministic implementations twice"
)

reset_rng_states()

s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n
N = 2 * n + n * n
nC = n * C

# Since we run exactly the same path twice with only the gradient-accumulation logic differing, we can use tighter tolerances here
tols = {"atol": 1e-6, "rtol": 1e-6}
use_tf32 = False

x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype)
phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda")

alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype)
beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype)
x_ref = x.detach().clone().requires_grad_(True)
phi_ref = phi.detach().clone().requires_grad_(True)

alpha_ref = alpha.detach().clone().requires_grad_(True)
beta_ref = beta.detach().clone().requires_grad_(True)

def end_to_end(x, phi, alpha, beta, fused_grad_x_acc):
aggregated, H_post, H_res = mhc_generate_mix_and_aggregate(
x, phi, alpha, beta, None, use_tf32, fused_grad_x_acc
)
H_res = mhc_fused_sinkhorn(H_res.view(s, b, n, n), n).view(s * b, n * n)
expanded_combined = mhc_fused_expand_combine(
aggregated,
None,
H_post,
x,
H_res,
False,
fused_grad_x_acc,
)

return expanded_combined

expanded_combined_no_fuse_grad = end_to_end(x_ref, phi_ref, alpha_ref, beta_ref, False)
expanded_combined_fuse_grad = end_to_end(x, phi, alpha, beta, True)

grad_output = torch.randn_like(expanded_combined_fuse_grad)
expanded_combined_fuse_grad.backward(grad_output)
expanded_combined_no_fuse_grad.backward(grad_output)

torch.testing.assert_close(x.grad, x_ref.grad, **tols)
torch.testing.assert_close(phi.grad, phi_ref.grad, **tols)
torch.testing.assert_close(alpha.grad, alpha_ref.grad, **tols)
torch.testing.assert_close(beta.grad, beta_ref.grad, **tols)

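The precision-pattern argument in the comment at the top of this test can be reproduced in a few lines: accumulating in bf16 rounds after every addition, while the fused path accumulates in fp32 and rounds once at the end. An illustrative sketch, not tied to the kernels under test:

import torch

gs = [torch.randn(1024) for _ in range(3)]
acc_unfused = sum(g.to(torch.bfloat16) for g in gs)  # rounds after each add
acc_fused = sum(gs).to(torch.bfloat16)               # fp32 sum, round once
print((acc_unfused.float() - acc_fused.float()).abs().max())  # typically nonzero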

@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
@pytest.mark.parametrize("recompute", [False, True], ids=["no_recompute", "recompute"])
@@ -482,7 +591,7 @@ def test_mhc_expand_combine(cfg: MHCConfig, dtype, with_bias):
H_res_ref = H_res.detach().clone().requires_grad_(True)

ref_out = mhc_expand_combine_ref(f_ref, bias_ref, H_post_ref, x_ref, H_res_ref, n)
fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, n, False)
fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, False, False)

torch.testing.assert_close(fused_out, ref_out, **tols)
