Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions problems/nvidia/nvfp4_dual_gemm/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ def generate_input(

# Generate uint8 tensor, then convert to float4e2m1fn_x2 data type
a_ref = torch.randint(
0, 2, (l, m, k // 2), dtype=torch.int8, device="cuda"
0, 4, (l, m, k // 2), dtype=torch.int8, device="cuda"
).permute(1, 2, 0)
b1_ref = torch.randint(
0, 2, (l, n, k // 2), dtype=torch.int8, device="cuda"
0, 4, (l, n, k // 2), dtype=torch.int8, device="cuda"
).permute(1, 2, 0)
b2_ref = torch.randint(
0, 2, (l, n, k // 2), dtype=torch.int8, device="cuda"
0, 4, (l, n, k // 2), dtype=torch.int8, device="cuda"
).permute(1, 2, 0)
a_ref = a_ref.view(torch.float4_e2m1fn_x2)
b1_ref = b1_ref.view(torch.float4_e2m1fn_x2)
Expand All @@ -137,7 +137,7 @@ def create_scale_factor_tensors(l, mn, sf_k):
ref_shape = (l, mn, sf_k)
ref_permute_order = (1, 2, 0)
# Init with uint8 tensor, then convert to float8_e4m3fn
ref_f8_random_int = torch.randint(-1, 2, ref_shape, dtype=torch.int8, device='cuda')
ref_f8_random_int = torch.randint(0, 3, ref_shape, dtype=torch.int8, device='cuda')
ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn)
# permute to match ref_permute_order
ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order)
Expand Down