
Commit 048568f

Fix NVFP4 QAT convert path
**Summary:** The previous convert path through `QATConfig` did not swap `NVFP4FakeQuantizedLinear` back to `torch.nn.Linear`. The numerics tests still passed because this fake quantized linear happens to match the PTQ numerics exactly.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
```
1 parent aa21b80 commit 048568f
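
For context, here is a minimal sketch of the prepare → convert flow this commit fixes, mirroring the updated test below. The toy model, its shapes, and the CUDA placement are illustrative assumptions, not part of this change:

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.prototype.qat import NVFP4FakeQuantizeConfig
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

# Illustrative toy model; the real test uses its own `M` module on CUDA.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda()

# Prepare: swap nn.Linear -> NVFP4FakeQuantizedLinear for fake-quantized training
quantize_(
    model,
    QATConfig(
        activation_config=NVFP4FakeQuantizeConfig(False),
        weight_config=NVFP4FakeQuantizeConfig(False),
        step="prepare",
    ),
)

# ... fine-tune the model ...

# Convert: swap NVFP4FakeQuantizedLinear back to nn.Linear, then apply the base
# PTQ config so the weights become NVFP4Tensor (the path fixed by this commit)
base_config = NVFP4DynamicActivationNVFP4WeightConfig(use_dynamic_per_tensor_scale=False)
quantize_(model, QATConfig(base_config, step="convert"))
```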

File tree

3 files changed: +54 -18 lines

test/quantization/test_qat.py

Lines changed: 23 additions & 6 deletions
```diff
@@ -2106,23 +2106,28 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         Test QAT with `NVFP4FakeQuantizeConfig`.
         """
         from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
-        from torchao.prototype.qat import NVFP4FakeQuantizeConfig
+        from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+        from torchao.prototype.qat import (
+            NVFP4FakeQuantizeConfig,
+            NVFP4FakeQuantizedLinear,
+        )
 
         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
-        quantize_(
-            baseline_model,
-            NVFP4DynamicActivationNVFP4WeightConfig(
-                use_dynamic_per_tensor_scale=use_per_tensor_scale
-            ),
+        base_config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_dynamic_per_tensor_scale=use_per_tensor_scale
         )
+        quantize_(baseline_model, base_config)
         qat_config = QATConfig(
             activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             step="prepare",
         )
         quantize_(m, qat_config)
+        self.assertEqual(type(m.linear1), NVFP4FakeQuantizedLinear)
+        self.assertEqual(type(m.linear2), NVFP4FakeQuantizedLinear)
+        self.assertEqual(type(m.sub.linear), NVFP4FakeQuantizedLinear)
 
         # Compare prepared values
         torch.manual_seed(self.SEED)
@@ -2132,6 +2137,18 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         sqnr = compute_error(out, baseline_out).item()
         self.assertGreaterEqual(sqnr, float("inf"))
 
+        # Compare converted values
+        quantize_(m, QATConfig(base_config, step="convert"))
+        self.assertEqual(type(m.linear1), torch.nn.Linear)
+        self.assertEqual(type(m.linear2), torch.nn.Linear)
+        self.assertEqual(type(m.sub.linear), torch.nn.Linear)
+        self.assertEqual(type(m.linear1.weight), NVFP4Tensor)
+        self.assertEqual(type(m.linear2.weight), NVFP4Tensor)
+        self.assertEqual(type(m.sub.linear.weight), NVFP4Tensor)
+        out = m(*x)
+        sqnr = compute_error(out, baseline_out).item()
+        self.assertGreaterEqual(sqnr, float("inf"))
+
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
         not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
```

torchao/prototype/qat/nvfp4.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -162,6 +162,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         else:
             return fq
 
+    def to_linear(self) -> torch.nn.Linear:
+        new_linear = torch.nn.Linear(
+            self.in_features,
+            self.out_features,
+            self.bias is not None,
+            device=self.weight.device,
+            dtype=self.weight.dtype,
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if self.weight.device != torch.device("meta"):
+            new_linear.weight = self.weight
+            new_linear.bias = self.bias
+        return new_linear
+
     @classmethod
     def from_linear(
         cls,
```
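
A note on the meta-device guard in the new `to_linear`: parameters created on the `meta` device carry only shape and dtype metadata with no backing storage, so there is nothing to copy. A small, self-contained illustration (independent of torchao):

```python
import torch

# Meta tensors have no storage, which is why `to_linear` above only reassigns
# the weight and bias when the weight is not on the "meta" device.
w = torch.nn.Parameter(torch.empty(8, 8, device="meta"))
print(w.is_meta)   # True
print(w.device)    # meta
# Materializing such a module later (e.g. via Module.to_empty()) allocates real
# storage; copying data out of a meta tensor directly would raise an error.
```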

torchao/quantization/qat/api.py

Lines changed: 15 additions & 12 deletions
```diff
@@ -198,6 +198,13 @@ def _qat_config_transform(
     modules to the corresponding built-in `torch.nn.Module`s, then apply the
     base config directly to quantize the module.
     """
+    # TODO: rewrite this using a registration API so
+    # specific quantization schemes do not leak here
+    from torchao.prototype.qat import (
+        NVFP4FakeQuantizeConfig,
+        NVFP4FakeQuantizedLinear,
+    )
+
     # Prepare step
     # Swap nn.Linear -> FakeQuantizedLinear
     # Swap nn.Embedding -> FakeQuantizedEmbedding
@@ -210,13 +217,6 @@ def _qat_config_transform(
         act_config = config.activation_config
         weight_config = config.weight_config
         if isinstance(module, torch.nn.Linear):
-            # TODO: rewrite this using a registration API so
-            # specific quantization schemes do not leak here
-            from torchao.prototype.qat import (
-                NVFP4FakeQuantizeConfig,
-                NVFP4FakeQuantizedLinear,
-            )
-
             if isinstance(weight_config, NVFP4FakeQuantizeConfig):
                 assert act_config is None or isinstance(
                     act_config, NVFP4FakeQuantizeConfig
@@ -245,17 +245,20 @@ def _qat_config_transform(
         assert config.weight_config is None, "unexpected `weight_config`"
 
         # Ignore unrelated modules
-        if not isinstance(module, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
+        if not isinstance(
+            module,
+            (FakeQuantizedLinear, FakeQuantizedEmbedding, NVFP4FakeQuantizedLinear),
+        ):
             return module
 
         # Optionally pass custom scales and zero points to base config handler
         # This is only for range learning and only applies to weights
         kwargs = {}
        has_custom_scale_and_zero_point = False
-        weight_config = module.weight_fake_quantizer.config
        if (
-            isinstance(weight_config, IntxFakeQuantizeConfig)
-            and weight_config.range_learning
+            hasattr(module, "weight_fake_quantizer")
+            and isinstance(module.weight_fake_quantizer.config, IntxFakeQuantizeConfig)
+            and module.weight_fake_quantizer.config.range_learning
        ):
            kwargs["custom_scale"] = module.weight_fake_quantizer.scale
            kwargs["custom_zero_point"] = module.weight_fake_quantizer.zero_point
@@ -265,7 +268,7 @@ def _qat_config_transform(
         # Swap FakeQuantizedEmbedding -> nn.Embedding
         # Then apply the base config's transform function to quantize the model
         # If there is no base config, then simply perform the module swap
-        if isinstance(module, FakeQuantizedLinear):
+        if isinstance(module, (FakeQuantizedLinear, NVFP4FakeQuantizedLinear)):
             module = module.to_linear()
         elif isinstance(module, FakeQuantizedEmbedding):
             module = module.to_embedding()
```
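
The TODO moved to the top of `_qat_config_transform` points at a registration API so scheme-specific module swaps stop leaking into this function. One hypothetical shape for that idea, with all names below purely illustrative and not an existing torchao API:

```python
from typing import Callable, Dict, Type

import torch

# Hypothetical sketch only: none of these names exist in torchao today. A registry
# mapping fake-quantized module types to convert-step handlers would let each
# quantization scheme register its own swap instead of being imported here.
_QAT_CONVERT_HANDLERS: Dict[Type[torch.nn.Module], Callable[[torch.nn.Module], torch.nn.Module]] = {}


def register_qat_convert_handler(fake_quantized_cls: Type[torch.nn.Module]):
    """Register a function that swaps a fake-quantized module back to a built-in one."""
    def decorator(fn: Callable[[torch.nn.Module], torch.nn.Module]):
        _QAT_CONVERT_HANDLERS[fake_quantized_cls] = fn
        return fn
    return decorator


def swap_back(module: torch.nn.Module) -> torch.nn.Module:
    """Dispatch on module type instead of hard-coding isinstance checks."""
    handler = _QAT_CONVERT_HANDLERS.get(type(module))
    return handler(module) if handler is not None else module


# e.g. the NVFP4 prototype could then register its own swap:
# @register_qat_convert_handler(NVFP4FakeQuantizedLinear)
# def _nvfp4_to_linear(module):
#     return module.to_linear()
```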
