Conversation
Review updated until commit 5a40972

Description

Relevant files: categorized under Bug fix and Enhancement
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 No relevant tests

⚡ Recommended focus areas for review: Missing Tests
Greptile Summary

Adds NVFP4 quantized scaled matmul operations for non-grouped linear layers in the Python benchmark.

Changes
Issues Found
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Benchmark as benchmark_inference.py
participant Quantize as _quantize_llama4()
participant Linear as NVFP4InferenceLinear
participant Op as nvfuser_f16a_nvfp4weight_scaled_mm
participant Thunder as Thunder/nvFuser
participant Alloc as allocations.cpp
Note over Benchmark: Register custom ops
Benchmark->>Thunder: _register_nvfp4_ops()
Thunder->>Thunder: register nvfp4_scaled_mm_symbol
Thunder->>Thunder: register nvfp4_scaled_mm_translator
Note over Benchmark: Model preparation
Benchmark->>Quantize: _quantize_llama4(model)
Quantize->>Quantize: replace GroupedSwiGLU with NVFP4InferenceGroupedSwiGLU
Quantize->>Quantize: replace SwiGLU with NVFP4InferenceSwiGLU
Quantize->>Quantize: find Llama4MoE modules
Quantize->>Linear: gate = NVFP4InferenceLinear.from_linear()
Linear->>Linear: quantize_linear_weight_to_nvfp4()
Linear-->>Quantize: NVFP4InferenceLinear instance
Note over Benchmark: Forward pass
Benchmark->>Linear: forward(hidden_states)
Linear->>Linear: flatten: view(-1, in_features)
Linear->>Op: f16a_nvfp4weight_scaled_mm(activation, fp4_weight, ...)
Op->>Op: dequantize_to_dtype()
Op->>Op: torch.nn.functional.linear()
Op-->>Linear: output (2D, bfloat16)
Linear-->>Benchmark: output (shape not restored)
Note over Thunder: nvFuser translation
Thunder->>Thunder: nvfp4_scaled_mm_translator()
Thunder->>Thunder: nv_block_quantize(activation)
Thunder->>Thunder: scaled_mm(quantized_activation, fp4_weight, ...)
Note over Alloc: Bug fix for padded dimensions
Alloc->>Alloc: handle(Split)
Alloc->>Alloc: check: is_divisible = (in_extent % factor == 0)
alt is_divisible && contiguous
Alloc->>Alloc: tensor_.view(new_shape)
else non-divisible or non-contiguous
Alloc->>Alloc: tensor_.as_strided(tensor_new_shape, strides)
end
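For readers skimming the "Model preparation" step above, here is a minimal, self-contained sketch of the module-replacement pattern. The class and function names in the sketch (`FakeQuantizedLinear`, `swap_linears`) are illustrative stand-ins, not the benchmark's API; the real code goes through `NVFP4InferenceLinear.from_linear()` and `quantize_linear_weight_to_nvfp4()` and packs the weight to FP4 with per-block scaling factors, none of which is reproduced here.

```python
# Illustrative sketch only: a stand-in wrapper plus a generic swap loop.
from typing import Optional

import torch
import torch.nn as nn


class FakeQuantizedLinear(nn.Module):
    """Stand-in for NVFP4InferenceLinear (keeps a bfloat16 copy instead of packed FP4)."""

    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]):
        super().__init__()
        self.register_buffer("weight", weight.to(torch.bfloat16))
        self.bias = None if bias is None else nn.Parameter(bias.detach().to(torch.bfloat16))

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FakeQuantizedLinear":
        return cls(linear.weight.detach(), linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x.to(torch.bfloat16), self.weight, self.bias)


def swap_linears(model: nn.Module) -> nn.Module:
    """Replace every nn.Linear child with the quantized stand-in, in place."""
    for parent in model.modules():
        for name, child in list(parent.named_children()):
            if isinstance(child, nn.Linear):
                setattr(parent, name, FakeQuantizedLinear.from_linear(child))
    return model


model = swap_linears(nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64)))
print(model(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
```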
…to pbasu/nvfp4_linear_bench
!test

!test

!test

!test

!test
raise ValueError("Expected all inputs to be on the same device.")

a = torch.empty((activation.shape[0], fp4_weight.t().shape[0]), device=activation.device, dtype=torch.bfloat16)
logic: output shape assumes 2D activation but doesn't preserve batch structure. If activation was originally 3D (batch, seq, hidden) and flattened to 2D (batch*seq, hidden) before this call, the output remains 2D (batch*seq, out_features) instead of restoring to 3D (batch, seq, out_features). This is inconsistent with the docstring in the forward method at line 611 which documents input as [batch, seq_len, in_features].
Suggested change:

output_shape = activation.shape[:-1] + (fp4_weight.size(1),)
return torch.empty(output_shape, device=activation.device, dtype=torch.bfloat16)
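To make the suggested shape rule concrete, here is a small runnable sketch, assuming (as the suggestion does) that the output feature count is `fp4_weight.size(1)`. The function name `nvfp4_scaled_mm_fake` and the uint8 packing of the weight below are hypothetical, not the actual registered fake.

```python
import torch


def nvfp4_scaled_mm_fake(activation: torch.Tensor, fp4_weight: torch.Tensor) -> torch.Tensor:
    # Keep every leading (batch, seq, ...) dim; only the last dim becomes out_features.
    output_shape = activation.shape[:-1] + (fp4_weight.size(1),)
    return torch.empty(output_shape, device=activation.device, dtype=torch.bfloat16)


act = torch.randn(2, 16, 4096)                       # [batch, seq_len, in_features]
w = torch.empty(4096 // 2, 5120, dtype=torch.uint8)  # packed FP4 weight (layout assumed)
print(nvfp4_scaled_mm_fake(act, w).shape)            # torch.Size([2, 16, 5120])
```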
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Forward pass using nvfp4_scaled_mm.

    Args:
        hidden_states: Input tensor of shape [batch, seq_len, in_features]
    """
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    # Use nvfp4_scaled_mm which handles the full computation
    output = torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm(
        hidden_states,
        self.fp4_weight,
        self.weight_scaling_factor,
        self.weight_global_scale,
    )

    return output
logic: flattens input to 2D but doesn't restore original shape before returning. Docstring says input is [batch, seq_len, in_features] but output remains flattened as (batch*seq, out_features). Compare with grouped version's fake (lines 332-336) which preserves input shape.
The forward method needs to store the original shape and restore it:

original_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
output = torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm(...)
return output.view(*original_shape[:-1], -1)
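As a sanity check of that pattern, here is a self-contained sketch of the flatten/restore logic, with a plain bfloat16 matmul standing in for `torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm` (the stand-in and function names are illustrative only).

```python
import torch


def scaled_mm_2d(activation_2d: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for the custom op: takes [tokens, in_features], returns [tokens, out_features].
    return activation_2d @ weight.t()


def forward_with_shape_restore(hidden_states: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    original_shape = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])  # [batch*seq, in_features]
    output = scaled_mm_2d(hidden_states, weight)
    return output.view(*original_shape[:-1], -1)  # restore [batch, seq_len, out_features]


x = torch.randn(2, 16, 64, dtype=torch.bfloat16)
w = torch.randn(128, 64, dtype=torch.bfloat16)  # [out_features, in_features]
print(forward_with_shape_restore(x, w).shape)   # torch.Size([2, 16, 128])
```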
This adds quantized scaled MM ops to our Python benchmark.
This will create/quantize the module to:
A small bug was also fixed: when inferring the output allocation, we no longer call tensor_.view when one of the splits is not a divisible split. This problem shows up when we pad the inner dimension by 4 and the "padded" outer split dimension is one.
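For intuition on that fix, here is a small PyTorch analogue of the case described above. It only illustrates why a non-divisible split cannot be expressed as a plain view and needs explicit strides over the padded buffer; it is not the allocations.cpp logic itself.

```python
import torch

logical = torch.arange(3)   # inner dimension with logical extent 3
try:
    logical.view(1, 4)      # non-divisible split: 3 % 4 != 0, so view cannot work
except RuntimeError as err:
    print("view fails:", err)

padded = torch.zeros(4, dtype=logical.dtype)   # allocation padded up to the split factor
padded[:3] = logical
outer_one = padded.as_strided((1, 4), (4, 1))  # "padded" outer split dimension is one
print(outer_one)  # tensor([[0, 1, 2, 0]])
```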