diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index 866d6aa..f4c10aa 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -14,9 +14,11 @@ """Implement FP8 linear module to be loaded via FMS.""" # Standard +from importlib.metadata import version from typing import Any, Mapping # Third Party +from packaging.version import Version import torch # Local @@ -27,6 +29,9 @@ # torch.nn.functional.linear not recognized as callable # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 +TORCH_VERSION = Version(torch.__version__.split("+")[0]) +SUPPORTS_CPU_PER_CHANNEL_FP8 = Version("2.10") > TORCH_VERSION + # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: # Third Party @@ -213,7 +218,11 @@ def _construct_qweight_structure(self) -> "AffineQuantizedTensor": def forward(self, x: torch.Tensor) -> torch.Tensor: """If input quantization is active, compute FP8xFP8 addmm leveraging torchao - functionalities. Otherwise compute non-quantized addmm.""" + functionalities. Otherwise compute non-quantized addmm. + + In PyTorch 2.10, torch._scaled_mm only supports FP8 on CPU when + quantization is per-tensor. In this case, we perform a mock FP8xFP8 matmul. 
+ """ # fp8 weight tensor for torchao qweight: AffineQuantizedTensor = self._construct_qweight_structure() @@ -234,6 +243,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) + # Check if we need CPU fallback for per-channel quantization + is_cpu = qx.device.type == "cpu" + is_per_tensor = ( + self.linear_config["weights"]["strategy"] == "tensor" + and self.linear_config["input_activations"]["strategy"] == "tensor" + ) + + # Perform mock FP8xFP8 matmul + if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8: + # Check torchao version without loading the full package + if Version("0.11") < Version(version("torchao")): + raise NotImplementedError( + "Fallback path for FP8 matmul on CPU is not supported " + "on torchao > 0.11." + ) + x_dequant = qx.dequantize() + w_dequant = qweight.dequantize() + out = torch.nn.functional.linear( + x_dequant.to(w_dequant.dtype), + w_dequant, + self.bias if self.has_bias else None, + ) + return out.to(x.dtype) + # Copied from torchao _linear_fp8_act_fp8_weight_impl # (with changes to support fp8 out) scaled_mm_config = Float8MMConfig(use_fast_accum=True) @@ -276,10 +309,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ).reshape(out_shape) # activations not quantized, dequant fp8 weight and do regular matmul + w_dequant = qweight.dequantize() out = torch.nn.functional.linear( - x, qweight.dequantize(), self.bias if self.has_bias else None + x.to(w_dequant.dtype), w_dequant, self.bias if self.has_bias else None ) - return out + return out.to(x.dtype) def __repr__(self) -> str: return ( diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index a382c63..13ee1a9 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -21,6 +21,84 @@ from fms_mo.prep import available_packages import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import +# 
============================================================================ +# Constants +# ============================================================================ + +# FP8 E4M3 maximum value +FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def initialize_fp8_weights( + fp8_linear, + weight_strategy: str, + in_features: int, + out_features: int, +) -> None: + """Initialize FP8Linear weights with proper absmax scaling. + + Args: + fp8_linear: FP8Linear module to initialize + weight_strategy: "tensor" or "channel" for weight quantization + in_features: Input feature dimension + out_features: Output feature dimension + """ + with torch.no_grad(): + # Create random float weights + float_weights = torch.randn(out_features, in_features) + + # Set appropriate scales based on strategy using absmax + if weight_strategy == "tensor": + # Per-tensor: single scale for entire weight matrix + absmax = float_weights.abs().max() + scale = absmax / FP8_E4M3_MAX + # Ensure scale is not zero + scale = torch.clamp(scale, min=1e-12) + fp8_linear.weight_scale.fill_(scale.item()) + else: # channel (per-row for weight matrix) + # Per-channel: one scale per output channel (row) + absmax = float_weights.abs().amax(dim=1) + scale = absmax / FP8_E4M3_MAX + # Ensure scales are not zero + scale = torch.clamp(scale, min=1e-12) + # Reshape to match weight_scale parameter shape (out_features, 1) + fp8_linear.weight_scale.copy_(scale.reshape(-1, 1)) + + # Quantize weights to FP8 + quantized_weights = (float_weights / fp8_linear.weight_scale).clamp( + -FP8_E4M3_MAX, FP8_E4M3_MAX + ) + fp8_linear.weight.copy_(quantized_weights.to(torch.float8_e4m3fn)) + + # Initialize bias if present + if fp8_linear.has_bias: + fp8_linear.bias.copy_(torch.randn(out_features)) + + +# 
============================================================================ +# Pytest Fixtures +# ============================================================================ + + +@pytest.fixture +def fp8_test_dimensions(): + """Common test dimensions for FP8Linear tests.""" + return { + "batch_size": 2, + "seq_len": 4, + "in_features": 8, + "out_features": 16, + } + + +# ============================================================================ +# Tests +# ============================================================================ + def test_fp8_registration() -> None: """ @@ -44,9 +122,10 @@ def test_fp8_registration() -> None: reason="FP8 is only available on GPUs with device level 8.9 or higher", ) def test_fp8_op() -> None: - """Validate output shapes of GPTQ W4A16 tensors. - Note: this AIU-compatible operation only returns a zero tensor of the - expected shape, it does not perform a real W4A16 matmul operation. + """Validate output shapes of FP8 attention operation. + + Tests the FP8 attention compute operation to ensure it produces + outputs with the expected shape. 
""" # Local from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op @@ -57,3 +136,140 @@ def test_fp8_op() -> None: out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None) assert out.size() == query.size() + + +@pytest.mark.skipif( + not available_packages["torchao"] or not available_packages["fms"], + reason="FMS and torchao required to run this test", +) +@pytest.mark.parametrize( + "weight_strategy,activation_strategy", + [ + ("tensor", "tensor"), # Per-tensor W + per-tensor dynamic A + ("tensor", "token"), # Per-tensor W + per-token dynamic A + ("channel", "tensor"), # Per-channel W + per-tensor dynamic A + ("channel", "token"), # Per-channel W + per-token dynamic A + ], +) +def test_fp8_linear_cpu_support( # pylint: disable=redefined-outer-name + weight_strategy: str, + activation_strategy: str, + fp8_test_dimensions: dict, +) -> None: + """Test FP8Linear on CPU with different quantization strategies. + + This test ensures that FP8Linear works correctly on CPU with: + - Per-tensor quantization (native support in PyTorch 2.10+) + - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+) + + Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel + and per-token quantization require a fallback to dequantize + regular matmul. 
+ + Args: + weight_strategy: "tensor" or "channel" weight quantization + activation_strategy: "tensor" or "token" dynamic activation quantization + fp8_test_dimensions: Test dimensions fixture + """ + # Local + from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear + + # Get test dimensions + batch_size = fp8_test_dimensions["batch_size"] + seq_len = fp8_test_dimensions["seq_len"] + in_features = fp8_test_dimensions["in_features"] + out_features = fp8_test_dimensions["out_features"] + + # Create FP8Linear configuration + linear_config = { + "weights": { + "strategy": weight_strategy, + "symmetric": True, + "dynamic": False, + }, + "input_activations": { + "strategy": activation_strategy, + "symmetric": True, + "dynamic": True, + }, + } + + # Create FP8Linear module + fp8_linear = FP8Linear( + in_features=in_features, + out_features=out_features, + bias=True, + linear_config=linear_config, + ) + + # Initialize weights using helper function + initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features) + + # Create input tensor on CPU + x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16) + + # Run forward pass - should not raise an error + output = fp8_linear(x) + + # Validate output shape + assert output.shape == (batch_size, seq_len, out_features) + + # Validate output is not NaN or Inf + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + + # Validate output dtype matches input dtype + assert output.dtype == x.dtype + + +@pytest.mark.skipif( + not available_packages["torchao"] or not available_packages["fms"], + reason="FMS and torchao required to run this test", +) +def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None: # pylint: disable=redefined-outer-name + """Test FP8Linear on CPU with only weight quantization (no activation quantization). + + This tests the code path where activations are not quantized but weights are FP8. 
+ + Args: + fp8_test_dimensions: Test dimensions fixture + """ + # Local + from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear + + # Get test dimensions + batch_size = fp8_test_dimensions["batch_size"] + seq_len = fp8_test_dimensions["seq_len"] + in_features = fp8_test_dimensions["in_features"] + out_features = fp8_test_dimensions["out_features"] + + # Create FP8Linear configuration with no activation quantization + linear_config = { + "weights": { + "strategy": "channel", + "symmetric": True, + "dynamic": False, + }, + "input_activations": None, # No activation quantization + } + + # Create FP8Linear module + fp8_linear = FP8Linear( + in_features=in_features, + out_features=out_features, + bias=True, + linear_config=linear_config, + ) + + # Initialize weights using helper function + initialize_fp8_weights(fp8_linear, "channel", in_features, out_features) + + # Create input tensor on CPU + x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16) + + # Run forward pass + output = fp8_linear(x) + + # Validate output + assert output.shape == (batch_size, seq_len, out_features) + assert not torch.isnan(output).any() + assert not torch.isinf(output).any()