Benchmark for nvfp4 scaled mm #5737
base: main
Conversation
Review updated until commit 5a40972

Description

Relevant files:
- Bug fix
- Enhancement

PR Reviewer Guide

Here are some key observations to aid the review process:
- 🧪 No relevant tests
- ⚡ Recommended focus areas for review: Missing Tests
Greptile Summary

Adds NVFP4 quantized scaled matmul operations for non-grouped linear layers in the Python benchmark.

Changes
Issues Found
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Benchmark as benchmark_inference.py
participant Quantize as _quantize_llama4()
participant Linear as NVFP4InferenceLinear
participant Op as nvfuser_f16a_nvfp4weight_scaled_mm
participant Thunder as Thunder/nvFuser
participant Alloc as allocations.cpp
Note over Benchmark: Register custom ops
Benchmark->>Thunder: _register_nvfp4_ops()
Thunder->>Thunder: register nvfp4_scaled_mm_symbol
Thunder->>Thunder: register nvfp4_scaled_mm_translator
Note over Benchmark: Model preparation
Benchmark->>Quantize: _quantize_llama4(model)
Quantize->>Quantize: replace GroupedSwiGLU with NVFP4InferenceGroupedSwiGLU
Quantize->>Quantize: replace SwiGLU with NVFP4InferenceSwiGLU
Quantize->>Quantize: find Llama4MoE modules
Quantize->>Linear: gate = NVFP4InferenceLinear.from_linear()
Linear->>Linear: quantize_linear_weight_to_nvfp4()
Linear-->>Quantize: NVFP4InferenceLinear instance
Note over Benchmark: Forward pass
Benchmark->>Linear: forward(hidden_states)
Linear->>Linear: flatten: view(-1, in_features)
Linear->>Op: f16a_nvfp4weight_scaled_mm(activation, fp4_weight, ...)
Op->>Op: dequantize_to_dtype()
Op->>Op: torch.nn.functional.linear()
Op-->>Linear: output (2D, bfloat16)
Linear-->>Benchmark: output (shape not restored)
Note over Thunder: nvFuser translation
Thunder->>Thunder: nvfp4_scaled_mm_translator()
Thunder->>Thunder: nv_block_quantize(activation)
Thunder->>Thunder: scaled_mm(quantized_activation, fp4_weight, ...)
Note over Alloc: Bug fix for padded dimensions
Alloc->>Alloc: handle(Split)
Alloc->>Alloc: check: is_divisible = (in_extent % factor == 0)
alt is_divisible && contiguous
Alloc->>Alloc: tensor_.view(new_shape)
else non-divisible or non-contiguous
Alloc->>Alloc: tensor_.as_strided(tensor_new_shape, strides)
end
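For orientation, here is a minimal self-contained sketch of the eager reference path in the diagram (dequantize the weight, then fall back to a regular linear). The block size, tensor layouts, and the `dequantize_weight_sketch` helper are illustrative assumptions standing in for the PR's actual `dequantize_to_dtype()` and op signature.

```python
import torch
import torch.nn.functional as F

BLOCK = 16  # NVFP4 scales weights in 16-element blocks (assumed here)

def dequantize_weight_sketch(weight_values: torch.Tensor,
                             block_scales: torch.Tensor,
                             global_scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for the PR's dequantize_to_dtype(): apply per-block scaling
    # factors plus one global scale, then cast to bfloat16. A real NVFP4
    # weight is packed 4-bit data; here it is mocked as a float tensor.
    out_f, in_f = weight_values.shape
    w = weight_values.view(out_f, in_f // BLOCK, BLOCK)
    w = w * block_scales.unsqueeze(-1) * global_scale
    return w.view(out_f, in_f).to(torch.bfloat16)

def f16a_nvfp4weight_scaled_mm_sketch(activation, weight_values,
                                      block_scales, global_scale):
    # Dequantize, then do a plain linear, mirroring the
    # dequantize_to_dtype() -> torch.nn.functional.linear() path above.
    weight = dequantize_weight_sketch(weight_values, block_scales, global_scale)
    return F.linear(activation.to(torch.bfloat16), weight)

# 2D activation (tokens, in_features), as produced by the flatten in forward().
act = torch.randn(8, 64)
w_vals = torch.randn(32, 64)                 # mocked "fp4" weight values
scales = torch.ones(32, 64 // BLOCK)         # per-block weight scaling factors
out = f16a_nvfp4weight_scaled_mm_sketch(act, w_vals, scales, torch.tensor(2.0))
print(out.shape)  # torch.Size([8, 32])
```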
3 files reviewed, 1 comment
…to pbasu/nvfp4_linear_bench
!test
3 files reviewed, 1 comment
!test

!test

!test

!test
3 files reviewed, 2 comments
raise ValueError("Expected all inputs to be on the same device.")

a = torch.empty((activation.shape[0], fp4_weight.t().shape[0]), device=activation.device, dtype=torch.bfloat16)
logic: output shape assumes 2D activation but doesn't preserve batch structure. If activation was originally 3D (batch, seq, hidden) and flattened to 2D (batch*seq, hidden) before this call, the output remains 2D (batch*seq, out_features) instead of restoring to 3D (batch, seq, out_features). This is inconsistent with the docstring in the forward method at line 611 which documents input as [batch, seq_len, in_features].
Suggested change:
- a = torch.empty((activation.shape[0], fp4_weight.t().shape[0]), device=activation.device, dtype=torch.bfloat16)
+ output_shape = activation.shape[:-1] + (fp4_weight.size(1),)
+ return torch.empty(output_shape, device=activation.device, dtype=torch.bfloat16)
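For reference, a small standalone sketch of the shape logic the suggestion proposes, showing that it preserves the leading batch dimensions for both 2D and 3D activations. The function name and the weight layout (out_features on dim 1) are illustrative assumptions.

```python
import torch

def nvfp4_scaled_mm_fake_sketch(activation: torch.Tensor,
                                fp4_weight: torch.Tensor) -> torch.Tensor:
    # Shape-only "fake": keep every leading dimension of the activation and
    # swap the last one for the weight's output-feature extent.
    output_shape = activation.shape[:-1] + (fp4_weight.size(1),)
    return torch.empty(output_shape, device=activation.device, dtype=torch.bfloat16)

w = torch.zeros(64, 32)  # assumed (in_features, out_features) layout
print(nvfp4_scaled_mm_fake_sketch(torch.zeros(8, 64), w).shape)     # torch.Size([8, 32])
print(nvfp4_scaled_mm_fake_sketch(torch.zeros(2, 4, 64), w).shape)  # torch.Size([2, 4, 32])
```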
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Forward pass using nvfp4_scaled_mm.

    Args:
        hidden_states: Input tensor of shape [batch, seq_len, in_features]
    """
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    # Use nvfp4_scaled_mm which handles the full computation
    output = torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm(
        hidden_states,
        self.fp4_weight,
        self.weight_scaling_factor,
        self.weight_global_scale,
    )

    return output
logic: flattens input to 2D but doesn't restore original shape before returning. Docstring says input is [batch, seq_len, in_features] but output remains flattened as (batch*seq, out_features). Compare with grouped version's fake (lines 332-336) which preserves input shape.
Need to store the original shape and restore it:
original_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
output = torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm(...)
return output.view(*original_shape[:-1], -1)
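Combining this comment with the quoted forward method, a corrected `NVFP4InferenceLinear.forward` might look like the sketch below. The op call is taken verbatim from the quoted code; only the shape handling is new, and this is a suggestion rather than the PR's actual implementation.

```python
import torch

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Forward pass using nvfp4_scaled_mm.

    Args:
        hidden_states: Input tensor of shape [batch, seq_len, in_features]
    """
    original_shape = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    # Use nvfp4_scaled_mm which handles the full computation
    output = torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm(
        hidden_states,
        self.fp4_weight,
        self.weight_scaling_factor,
        self.weight_global_scale,
    )

    # Restore the leading [batch, seq_len] dimensions dropped by the flatten above.
    return output.view(*original_shape[:-1], -1)
```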
This adds quantized scaled MM ops to our Python benchmark.
This will create/quantize the module to:
There was a small bug fixed: when inferring the output allocation, we no longer call `tensor_.view` when one of the splits is not a divisible split (we fall back to `as_strided` instead). This problem shows up when we pad the inner dimension by 4 and the "padded" outer split dimension is one.
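The allocation-inference fix itself lives in C++ (`allocations.cpp`), but its logic can be illustrated with a small Python model of the two branches shown in the sequence diagram. The function name, shapes, and the contiguity/padding checks below are purely illustrative.

```python
import torch

def split_inner_dim(buf: torch.Tensor, logical_cols: int, factor: int) -> torch.Tensor:
    """Toy model of the fixed allocation inference: split the inner dimension
    of a 2D buffer by `factor`, where the buffer may be padded past the
    logical extent."""
    rows, alloc_cols = buf.shape
    is_divisible = logical_cols % factor == 0
    if is_divisible and buf.is_contiguous() and alloc_cols == logical_cols:
        # Divisible split of an unpadded, contiguous buffer: a plain view works.
        return buf.view(rows, logical_cols // factor, factor)
    # Non-divisible split (e.g. inner dimension padded to a multiple of 4):
    # view() would fail here, so describe the padded layout with as_strided.
    outer = (logical_cols + factor - 1) // factor  # ceil-divide: padded outer extent
    return buf.as_strided((rows, outer, factor), (buf.stride(0), factor, 1))

# The reported case: inner dimension padded to a multiple of 4, and the
# "padded" outer split extent ends up being one.
buf = torch.zeros(3, 4)  # allocation padded to 4 columns
print(split_inner_dim(buf, logical_cols=2, factor=4).shape)  # torch.Size([3, 1, 4])
```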