45 changes: 45 additions & 0 deletions tests/pytorch/nvfp4/test_nvfp4_post_rht_amax_estimation_sanity.py
@@ -0,0 +1,45 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.pytorch import NVFP4Quantizer


recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for NVFP4 quantization")
def test_nvfp4_post_rht_amax_estimation_sanity() -> None:
"""Sanity: when using post-RHT amax estimation, columnwise amax is scaled pre-RHT amax."""

torch.manual_seed(0)
torch.cuda.manual_seed(0)

# Shape must satisfy both the NVFP4 and RHT kernel constraints:
# rows % 64 == 0 and cols % 128 == 0 trigger the fast RHT-cast fusion path.
M, N = 128, 128
x = torch.randn((M, N), device="cuda", dtype=torch.bfloat16)

scale = 2.0
q = NVFP4Quantizer(
rowwise=True,
columnwise=True,
with_rht=True,
# The estimation path requires the post-RHT amax kernel to be disabled.
with_post_rht_amax=False,
amax_estimation_scale=scale,
stochastic_rounding=False,
)

y = q(x)
assert y._amax_rowwise is not None
assert y._amax_columnwise is not None

amax_pre = torch.max(torch.abs(x)).to(torch.float32).view(1)
torch.testing.assert_close(y._amax_rowwise, amax_pre, atol=0.0, rtol=0.0)
torch.testing.assert_close(y._amax_columnwise, amax_pre * scale, atol=0.0, rtol=0.0)
17 changes: 17 additions & 0 deletions transformer_engine/common/include/transformer_engine/recipe.h
@@ -99,6 +99,23 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig config, cudaStream_t stream);

/*! \brief Scale a tensor's amax by a scalar.
*
* This is a lightweight utility intended for cases where the amax is
* derived/estimated from another amax value (e.g., post-transform amax
* estimated from pre-transform amax via a linear scale factor).
*
* If `columnwise` is true, scales `tensor.columnwise_amax` if present.
* Otherwise, scales `tensor.amax` if present. If the selected amax pointer
* is null, this function is a no-op.
*
* \param[in,out] tensor Tensor that owns the amax buffer(s).
* \param[in] columnwise Whether to scale columnwise amax (true) or rowwise amax (false).
* \param[in] scale Scalar multiplier applied to the amax value.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scale_amax(NVTETensor tensor, bool columnwise, float scale, cudaStream_t stream);

/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
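For readers skimming the header above, the following is a minimal PyTorch sketch (not part of the diff) of the semantics of `nvte_scale_amax` on the selected one-element FP32 amax buffer; `scale_amax` here is a hypothetical stand-in, not an existing binding:

```python
from typing import Optional

import torch


def scale_amax(amax: Optional[torch.Tensor], scale: float) -> None:
    """In-place sketch of what nvte_scale_amax does to one amax buffer."""
    # Mirror the C API: no-op when the buffer is absent or scale == 1.0.
    if amax is None or scale == 1.0:
        return
    # The C API requires a positive scale and a 1-element FP32 buffer.
    assert scale > 0.0
    assert amax.numel() == 1 and amax.dtype == torch.float32
    amax.mul_(scale)  # amax *= scale, in place
```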
43 changes: 42 additions & 1 deletion transformer_engine/common/recipe/__init__.py
@@ -65,21 +65,26 @@ class QParams:
amax_epsilon: optional minimum value of abs max
random_hadamard_transform: whether to use random hadamard transform
stochastic_rounding: whether to use stochastic rounding
amax_estimation_scale: scale factor for estimating post-RHT amax from pre-RHT amax.
When None, the true post-RHT amax is computed (default behavior).
When set to a float, post-RHT amax is estimated as pre_rht_amax * amax_estimation_scale.
"""

power_2_scale: bool = False
amax_epsilon: float = 0.0
random_hadamard_transform: bool = False
stochastic_rounding: bool = False
fp4_2d_quantization: bool = False
amax_estimation_scale: Optional[float] = None

def __repr__(self) -> str:
return (
f"Qparams(\npower_2_scale={self.power_2_scale},\n"
f"amax_epsilon={self.amax_epsilon},\n"
f"random_hadamard_transform={self.random_hadamard_transform},\n"
f"stochastic_rounding={self.stochastic_rounding},\n"
f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
f"fp4_2d_quantization={self.fp4_2d_quantization},\n"
f"amax_estimation_scale={self.amax_estimation_scale}\n)"
)


@@ -428,6 +433,16 @@ class NVFP4BlockScaling(Recipe):
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default = False
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
use_post_rht_amax_estimation : bool, default = False
**EXPERIMENTAL**: If set to `True`, post-RHT amax is estimated from pre-RHT amax
instead of being computed by a separate RHT+amax kernel. This can reduce the
number of kernel launches but may affect numerical accuracy.
post_rht_amax_estimation_scale_fwd_inp : float, default = 2.0
Scale factor for estimating post-RHT amax for forward input activations.
Only used when `use_post_rht_amax_estimation=True`.
post_rht_amax_estimation_scale_bwd_grad : float, default = 1.0
Scale factor for estimating post-RHT amax for backward gradients.
Only used when `use_post_rht_amax_estimation=True`.
"""

# Configuration envvars
@@ -444,17 +459,41 @@
fp8_dpa: bool = False
fp8_mha: bool = False

# Experimental: Post-RHT amax estimation
use_post_rht_amax_estimation: bool = (
os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION", "0") == "1"
)
post_rht_amax_estimation_scale_fwd_inp: float = float(
os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_X_SCALE", "2.0")
)
post_rht_amax_estimation_scale_bwd_grad: float = float(
os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_G_SCALE", "1.0")
)

def __post_init__(self) -> None:
assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling"
assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling"

# Determine amax estimation scales (None = use true post-RHT amax)
amax_scale_fwd_inp = (
self.post_rht_amax_estimation_scale_fwd_inp
if self.use_post_rht_amax_estimation
else None
)
amax_scale_bwd_grad = (
self.post_rht_amax_estimation_scale_bwd_grad
if self.use_post_rht_amax_estimation
else None
)

# Quantization params
# Note: RHT is currently only applied to column-wise usage so that
# it can be used for wgrad GEMM.
self.fp4_quant_fwd_inp = QParams(
random_hadamard_transform=not self.disable_rht,
stochastic_rounding=False,
fp4_2d_quantization=False,
amax_estimation_scale=amax_scale_fwd_inp,
)
self.fp4_quant_fwd_weight = QParams(
random_hadamard_transform=False,
@@ -465,6 +504,7 @@ def __post_init__(self) -> None:
random_hadamard_transform=not self.disable_rht,
stochastic_rounding=not self.disable_stochastic_rounding,
fp4_2d_quantization=False,
amax_estimation_scale=amax_scale_bwd_grad,
)

def __repr__(self) -> str:
@@ -477,6 +517,7 @@ def __repr__(self) -> str:
f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, "
f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
f"use_post_rht_amax_estimation={self.use_post_rht_amax_estimation}, "
)


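To make the configuration above concrete, a hedged usage sketch (not part of the diff) follows. It assumes the env-var defaults shown above are read when `transformer_engine.common.recipe` is imported, so the variables are set beforehand:

```python
# Hedged sketch: enable the experimental post-RHT amax estimation via the
# environment variables introduced above. The class defaults are evaluated
# at import time, so set the variables before importing transformer_engine.
import os

os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION"] = "1"            # turn estimation on
os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_X_SCALE"] = "2.0"  # fwd-input scale
os.environ["NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_G_SCALE"] = "1.0"  # bwd-grad scale

from transformer_engine.common.recipe import NVFP4BlockScaling

recipe = NVFP4BlockScaling()
# __post_init__ propagates the scales into the QParams; with estimation
# disabled they stay None and the true post-RHT amax kernel is used instead.
print(recipe.fp4_quant_fwd_inp.amax_estimation_scale)   # 2.0
print(recipe.fp4_quant_bwd_grad.amax_estimation_scale)  # 1.0
```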
40 changes: 40 additions & 0 deletions transformer_engine/common/recipe/current_scaling.cu
@@ -188,6 +188,46 @@ void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor out
compute_amax_impl(input_, output_, stream, config_);
}

namespace {

__global__ void scale_amax_kernel(float *amax_ptr, float scale) { amax_ptr[0] *= scale; }

} // namespace

void nvte_scale_amax(NVTETensor tensor_, bool columnwise, float scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_scale_amax);
NVTE_CHECK(tensor_ != nullptr, "Invalid tensor (got NULL)");
auto &tensor = *transformer_engine::convertNVTETensorCheck(tensor_);

// Pick amax pointer
void *amax_dptr = nullptr;
if (columnwise) {
amax_dptr = tensor.columnwise_amax.dptr;
} else {
amax_dptr = tensor.amax.dptr;
}
if (amax_dptr == nullptr) {
return;
}
NVTE_CHECK((!columnwise && tensor.amax.numel() == 1) ||
(columnwise && tensor.columnwise_amax.numel() == 1),
"Invalid amax buffer (expected 1 element)");
NVTE_CHECK(
(!columnwise && tensor.amax.dtype == transformer_engine::DType::kFloat32) ||
(columnwise && tensor.columnwise_amax.dtype == transformer_engine::DType::kFloat32),
"Invalid amax dtype (expected FP32)");

// No-op for scale==1 to save a launch
if (scale == 1.0f) {
return;
}
// Scale should be positive for amax estimation use-cases
NVTE_CHECK(scale > 0.0f, "nvte_scale_amax requires scale > 0 (got ", scale, ")");

scale_amax_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<float *>(amax_dptr), scale);
NVTE_CHECK_CUDA(cudaGetLastError());
}

namespace transformer_engine {
namespace {

3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -293,6 +293,9 @@ class NVFP4Quantizer : public Quantizer {
// random hadamard transform
bool with_rht;
bool with_post_rht_amax;
// Optional: estimate post-RHT amax from pre-RHT amax using a linear scale
bool with_amax_estimation;
float amax_estimation_scale;
// 2D block scaling
bool with_2d_quantization;
bool stochastic_rounding;
10 changes: 6 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -42,8 +42,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!nvfp4_quantizer_cpp->with_amax_estimation) {
// True post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else {
impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
@@ -154,8 +155,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!nvfp4_quantizer_cpp->with_amax_estimation) {
// True post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else {
impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
5 changes: 3 additions & 2 deletions transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -152,8 +152,9 @@ std::vector<py::object> dact_dbias(
} else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!nvfp4_quantizer_cpp->with_amax_estimation) {
// True post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else {
impl = Impl::FUSED_DACT_AMAX_NVFP4;
34 changes: 33 additions & 1 deletion transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -17,6 +17,7 @@
#include "../extensions.h"
#include "common.h"
#include "pybind.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {
@@ -827,8 +828,39 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
split_sections.data(), num_tensors, 0, quantizer.rht_matrix_random_sign_mask_t, stream);
});
} else if (quantizer.with_amax_estimation) {
// Compute the pre-RHT amax, then estimate the post-RHT amax from it by scaling
NVTE_SCOPED_GIL_RELEASE({
for (size_t i = 0; i < num_tensors; ++i) {
if (input_list[i].numel() == 0) {
continue;
}
nvte_compute_amax_with_config(input_list[i].data(), output_list[i].data(),
quant_config_list[i], stream);

auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr;
auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr;
void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr;
if (amax_ptr != nullptr) {
if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, stream));
}
if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float),
cudaMemcpyDeviceToDevice, stream));
}
}

// Estimate post-RHT amax for columnwise path via scaling.
if (quantizer.with_amax_estimation && quantizer.columnwise_usage) {
nvte_scale_amax(output_list[i].data(), /*columnwise=*/true,
quantizer.amax_estimation_scale, stream);
}
}
});
} else {
// RHT is enabled, but amax is pre-RHT amax
// RHT enabled, but neither post-RHT amax nor amax estimation is in use
NVTE_ERROR("NVFP4 split-quantize does not yet support pre-RHT amax");
}

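A plain-PyTorch sketch (not part of the diff) of the numerics in the estimation branch above: a single pre-RHT amax is computed, mirrored into both amax buffers, and only the columnwise copy, the one consumed by the RHT'd wgrad path, is scaled:

```python
import torch

x = torch.randn(128, 128)
amax_estimation_scale = 2.0

pre_rht_amax = x.abs().amax()                 # single pre-RHT amax for the tensor
rowwise_amax = pre_rht_amax.clone()           # rowwise buffer keeps the pre-RHT value
columnwise_amax = pre_rht_amax * amax_estimation_scale  # estimated post-RHT amax
```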
10 changes: 6 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/normalization.cpp
@@ -126,8 +126,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!nvfp4_quantizer_cpp->with_amax_estimation) {
// True post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
// TE kernel supports amax output
@@ -355,8 +356,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!nvfp4_quantizer_cpp->with_amax_estimation) {
// True post-RHT amax is handled within NVFP4 quantizer
impl = Impl::UNFUSED;
} else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
// TE kernel supports amax output