From b4d378f751d70c2bcb451938a12ce24f17844030 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Tue, 23 Dec 2025 15:51:25 +0100 Subject: [PATCH] Fix: init range --- problems/nvidia/nvfp4_dual_gemm/reference.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/problems/nvidia/nvfp4_dual_gemm/reference.py b/problems/nvidia/nvfp4_dual_gemm/reference.py index b378f78..6183237 100644 --- a/problems/nvidia/nvfp4_dual_gemm/reference.py +++ b/problems/nvidia/nvfp4_dual_gemm/reference.py @@ -112,13 +112,13 @@ def generate_input( # Generate uint8 tensor, then convert to float4e2m1fn_x2 data type a_ref = torch.randint( - 0, 2, (l, m, k // 2), dtype=torch.int8, device="cuda" + 0, 4, (l, m, k // 2), dtype=torch.int8, device="cuda" ).permute(1, 2, 0) b1_ref = torch.randint( - 0, 2, (l, n, k // 2), dtype=torch.int8, device="cuda" + 0, 4, (l, n, k // 2), dtype=torch.int8, device="cuda" ).permute(1, 2, 0) b2_ref = torch.randint( - 0, 2, (l, n, k // 2), dtype=torch.int8, device="cuda" + 0, 4, (l, n, k // 2), dtype=torch.int8, device="cuda" ).permute(1, 2, 0) a_ref = a_ref.view(torch.float4_e2m1fn_x2) b1_ref = b1_ref.view(torch.float4_e2m1fn_x2) @@ -137,7 +137,7 @@ def create_scale_factor_tensors(l, mn, sf_k): ref_shape = (l, mn, sf_k) ref_permute_order = (1, 2, 0) # Init with uint8 tensor, then convert to float8_e4m3fn - ref_f8_random_int = torch.randint(-1, 2, ref_shape, dtype=torch.int8, device='cuda') + ref_f8_random_int = torch.randint(0, 3, ref_shape, dtype=torch.int8, device='cuda') ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn) # permute to match ref_permute_order ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order)