From 96e328464214a10cb229fb1fca52be6d35fd5f1f Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 13:50:41 -0400 Subject: [PATCH 1/9] add fallback to mock fp8 matmul on cpu Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index 866d6aa..33ca41a 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -17,6 +17,7 @@ from typing import Any, Mapping # Third Party +from packaging.version import Version import torch # Local @@ -27,6 +28,9 @@ # torch.nn.functional.linear not recognized as callable # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 +TORCH_VERSION = Version(torch.__version__.split("+")[0]) +SUPPORTS_CPU_PER_CHANNEL_FP8 = TORCH_VERSION < Version("2.10") + # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: # Third Party @@ -213,7 +217,11 @@ def _construct_qweight_structure(self) -> "AffineQuantizedTensor": def forward(self, x: torch.Tensor) -> torch.Tensor: """If input quantization is active, compute FP8xFP8 addmm leveraging torchao - functionalities. Otherwise compute non-quantized addmm.""" + functionalities. Otherwise compute non-quantized addmm. + + In PyTorch 2.10, torch._scaled_mm only supports FP8 on CPU when + quantization is per-tensor. Otherwise, we perform a mock FP8xFP8 matmul.
+ """ # fp8 weight tensor for torchao qweight: AffineQuantizedTensor = self._construct_qweight_structure() @@ -234,6 +242,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) + # Check if we need CPU fallback for per-channel quantization + is_cpu = qx.device.type == "cpu" + is_per_tensor = ( + self.linear_config["weights"]["strategy"] == "tensor" and + self.linear_config["input_activations"]["strategy"] == "tensor" + ) + + # Perform mock FP8xFP8 matmul + if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8: + x_dequant = qx.dequantize() + w_dequant = qweight.dequantize() + out = torch.nn.functional.linear( + x_dequant.to(w_dequant.dtype), w_dequant, self.bias if self.has_bias else None + ) + return out.to(x.dtype) + # Copied from torchao _linear_fp8_act_fp8_weight_impl # (with changes to support fp8 out) scaled_mm_config = Float8MMConfig(use_fast_accum=True) From 84d0ac04514cdb14eed691193e98c36260fe9cc3 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 14:04:49 -0400 Subject: [PATCH 2/9] formatting Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index 33ca41a..e7f7f46 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -29,7 +29,7 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 TORCH_VERSION = Version(torch.__version__.split("+")[0]) -SUPPORTS_CPU_PER_CHANNEL_FP8 = TORCH_VERSION < Version("2.10") +SUPPORTS_CPU_PER_CHANNEL_FP8 = Version("2.10") > TORCH_VERSION # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: @@ -245,8 +245,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Check if we need CPU fallback for per-channel quantization is_cpu = 
qx.device.type == "cpu" is_per_tensor = ( - self.linear_config["weights"]["strategy"] == "tensor" and - self.linear_config["input_activations"]["strategy"] == "tensor" + self.linear_config["weights"]["strategy"] == "tensor" + and self.linear_config["input_activations"]["strategy"] == "tensor" ) # Perform mock FP8xFP8 matmul @@ -254,7 +254,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_dequant = qx.dequantize() w_dequant = qweight.dequantize() out = torch.nn.functional.linear( - x_dequant.to(w_dequant.dtype), w_dequant, self.bias if self.has_bias else None + x_dequant.to(w_dequant.dtype), + w_dequant, + self.bias if self.has_bias else None, ) return out.to(x.dtype) From 23b10a3deea9f4bfb043aa9bf1cd8fa1d7b0848f Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 19:51:37 -0400 Subject: [PATCH 3/9] fix dtype of fp8 matmul with non-quantized activations Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index e7f7f46..c89b17e 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -302,10 +302,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ).reshape(out_shape) # activations not quantized, dequant fp8 weight and do regular matmul + w_dequant = qweight.dequantize() out = torch.nn.functional.linear( - x, qweight.dequantize(), self.bias if self.has_bias else None + x.to(w_dequant.dtype), w_dequant, self.bias if self.has_bias else None ) - return out + return out.to(x.dtype) def __repr__(self) -> str: return ( From 8a046a92fca806a4640fa68559559e1cf8aa9d58 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 20:17:39 -0400 Subject: [PATCH 4/9] add torchao version check in fallback Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git 
a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index c89b17e..f4f6ac4 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -14,6 +14,7 @@ """Implement FP8 linear module to be loaded via FMS.""" # Standard +from importlib.metadata import version from typing import Any, Mapping # Third Party @@ -33,6 +34,8 @@ # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: + TORCHAO_VERSION = Version(version("torchao")) + # Third Party from fms.modules.linear import ( LinearModuleShardingInfo, @@ -251,6 +254,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Perform mock FP8xFP8 matmul if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8: + if Version("0.11") < TORCHAO_VERSION: + raise NotImplementedError( + "Fallback path for FP8 matmul on CPU is not supported " + "on torchao > 0.11." + ) x_dequant = qx.dequantize() w_dequant = qweight.dequantize() out = torch.nn.functional.linear( From c0a617f4c1344ec2a8d44c1defd49942384aa38b Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 20:57:02 -0400 Subject: [PATCH 5/9] add unit tests for FP8 matmul on CPU Signed-off-by: Andrea Fasoli --- tests/aiu_addons/test_fp8_addon.py | 263 ++++++++++++++++++++++++++++- 1 file changed, 260 insertions(+), 3 deletions(-) diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index a382c63..875b6f1 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -21,6 +21,117 @@ from fms_mo.prep import available_packages import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def initialize_fp8_weights( + fp8_linear, + weight_strategy: str, + in_features: int, + out_features: 
int, +) -> None: + """Initialize FP8Linear weights with proper absmax scaling. + + Args: + fp8_linear: FP8Linear module to initialize + weight_strategy: "tensor" or "channel" for weight quantization + in_features: Input feature dimension + out_features: Output feature dimension + """ + with torch.no_grad(): + # Create random float weights + float_weights = torch.randn(out_features, in_features) + + # Calculate FP8 E4M3 max value (448.0) + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + # Set appropriate scales based on strategy using absmax + if weight_strategy == "tensor": + # Per-tensor: single scale for entire weight matrix + absmax = float_weights.abs().max() + scale = absmax / fp8_max + # Ensure scale is not zero + scale = torch.clamp(scale, min=1e-12) + fp8_linear.weight_scale.fill_(scale.item()) + else: # channel (per-row for weight matrix) + # Per-channel: one scale per output channel (row) + absmax = float_weights.abs().amax(dim=1) + scale = absmax / fp8_max + # Ensure scales are not zero + scale = torch.clamp(scale, min=1e-12) + # Reshape to match weight_scale parameter shape (out_features, 1) + fp8_linear.weight_scale.copy_(scale.reshape(-1, 1)) + + # Quantize weights to FP8 + quantized_weights = (float_weights / fp8_linear.weight_scale).clamp( + -fp8_max, fp8_max + ) + fp8_linear.weight.copy_(quantized_weights.to(torch.float8_e4m3fn)) + + # Initialize bias if present + if fp8_linear.has_bias: + fp8_linear.bias.copy_(torch.randn(out_features)) + + +def initialize_fp8_input_scale( + fp8_linear, + activation_strategy: str, + batch_size: int, + seq_len: int, + in_features: int, +) -> None: + """Initialize static input scale for FP8Linear. 
+ + Args: + fp8_linear: FP8Linear module to initialize + activation_strategy: "tensor" or "token" for activation quantization + batch_size: Batch size for sample input + seq_len: Sequence length for sample input + in_features: Input feature dimension + """ + with torch.no_grad(): + # For static quantization, use a representative input to calculate scales + sample_input = torch.randn(batch_size, seq_len, in_features) + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + if activation_strategy == "tensor": + # Per-tensor: single scale for entire activation + absmax = sample_input.abs().max() + scale = absmax / fp8_max + scale = torch.clamp(scale, min=1e-12) + fp8_linear.input_scale.fill_(scale.item()) + else: # token + # For per-token static quantization, use a calibrated scale + # based on representative input statistics + absmax = sample_input.abs().max() + scale = absmax / fp8_max + scale = torch.clamp(scale, min=1e-12) + # Fill all scales with the same representative value + fp8_linear.input_scale.fill_(scale.item()) + + +# ============================================================================ +# Pytest Fixtures +# ============================================================================ + + +@pytest.fixture +def fp8_test_dimensions(): + """Common test dimensions for FP8Linear tests.""" + return { + "batch_size": 2, + "seq_len": 4, + "in_features": 8, + "out_features": 16, + } + + +# ============================================================================ +# Tests +# ============================================================================ + def test_fp8_registration() -> None: """ @@ -44,9 +155,10 @@ def test_fp8_registration() -> None: reason="FP8 is only available on GPUs with device level 8.9 or higher", ) def test_fp8_op() -> None: - """Validate output shapes of GPTQ W4A16 tensors. - Note: this AIU-compatible operation only returns a zero tensor of the - expected shape, it does not perform a real W4A16 matmul operation. 
+ """Validate output shapes of FP8 attention operation. + + Tests the FP8 attention compute operation to ensure it produces + outputs with the expected shape. """ # Local from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op @@ -57,3 +169,148 @@ def test_fp8_op() -> None: out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None) assert out.size() == query.size() + + +@pytest.mark.skipif( + not available_packages["torchao"] or not available_packages["fms"], + reason="FMS and torchao required to run this test", +) +@pytest.mark.parametrize( + "weight_strategy,activation_strategy,dynamic_activation", + [ + ("tensor", "tensor", True), # Per-tensor weights + per-tensor activations + ("tensor", "token", True), # Per-tensor weights + per-token activations + ("channel", "tensor", True), # Per-channel weights + per-tensor activations + ("channel", "token", True), # Per-channel weights + per-token activations + ], +) +def test_fp8_linear_cpu_support( + weight_strategy: str, + activation_strategy: str, + dynamic_activation: bool, + fp8_test_dimensions: dict, +) -> None: + """Test FP8Linear on CPU with different quantization strategies. + + This test ensures that FP8Linear works correctly on CPU, including: + - Per-tensor quantization (native support in PyTorch 2.10+) + - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+) + + Note: PyTorch 2.10+ only supports per-tensor FP8 matmul on CPU. Per-channel + and per-token quantization require a fallback to dequantize + regular matmul. 
+ + Args: + weight_strategy: "tensor" or "channel" for weight quantization + activation_strategy: "tensor" or "token" for activation quantization + dynamic_activation: Whether to use dynamic activation quantization + fp8_test_dimensions: Test dimensions fixture + """ + # Local + from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear + + # Get test dimensions + batch_size = fp8_test_dimensions["batch_size"] + seq_len = fp8_test_dimensions["seq_len"] + in_features = fp8_test_dimensions["in_features"] + out_features = fp8_test_dimensions["out_features"] + + # Create FP8Linear configuration + linear_config = { + "weights": { + "strategy": weight_strategy, + "symmetric": True, + "dynamic": False, + }, + "input_activations": { + "strategy": activation_strategy, + "symmetric": True, + "dynamic": dynamic_activation, + }, + } + + # Create FP8Linear module + fp8_linear = FP8Linear( + in_features=in_features, + out_features=out_features, + bias=True, + linear_config=linear_config, + ) + + # Initialize weights using helper function + initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features) + + # Initialize input scale if static quantization + if not dynamic_activation: + initialize_fp8_input_scale( + fp8_linear, activation_strategy, batch_size, seq_len, in_features + ) + + # Create input tensor on CPU + x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16) + + # Run forward pass - should not raise an error + output = fp8_linear(x) + + # Validate output shape + assert output.shape == (batch_size, seq_len, out_features) + + # Validate output is not NaN or Inf + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + + # Validate output dtype matches input dtype + assert output.dtype == x.dtype + + +@pytest.mark.skipif( + not available_packages["torchao"] or not available_packages["fms"], + reason="FMS and torchao required to run this test", +) +def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: 
dict) -> None: + """Test FP8Linear on CPU with only weight quantization (no activation quantization). + + This tests the code path where activations are not quantized but weights are FP8. + + Args: + fp8_test_dimensions: Test dimensions fixture + """ + # Local + from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear + + # Get test dimensions + batch_size = fp8_test_dimensions["batch_size"] + seq_len = fp8_test_dimensions["seq_len"] + in_features = fp8_test_dimensions["in_features"] + out_features = fp8_test_dimensions["out_features"] + + # Create FP8Linear configuration with no activation quantization + linear_config = { + "weights": { + "strategy": "channel", + "symmetric": True, + "dynamic": False, + }, + "input_activations": None, # No activation quantization + } + + # Create FP8Linear module + fp8_linear = FP8Linear( + in_features=in_features, + out_features=out_features, + bias=True, + linear_config=linear_config, + ) + + # Initialize weights using helper function + initialize_fp8_weights(fp8_linear, "channel", in_features, out_features) + + # Create input tensor on CPU + x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16) + + # Run forward pass + output = fp8_linear(x) + + # Validate output + assert output.shape == (batch_size, seq_len, out_features) + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() From c28329182374a1e789448217417efc1631e028a5 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 21:05:18 -0400 Subject: [PATCH 6/9] minor updates Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 4 ++-- tests/aiu_addons/test_fp8_addon.py | 21 ++++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index f4f6ac4..a9b3543 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -34,7 +34,6 @@ # Gated torchao imports for FP8 implementation if 
available_packages["fms"] and available_packages["torchao"]: - TORCHAO_VERSION = Version(version("torchao")) # Third Party from fms.modules.linear import ( @@ -254,7 +253,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Perform mock FP8xFP8 matmul if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8: - if Version("0.11") < TORCHAO_VERSION: + # Check torchao version without loading the full package + if Version("0.11") < Version(version("torchao")): raise NotImplementedError( "Fallback path for FP8 matmul on CPU is not supported " "on torchao > 0.11." diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index 875b6f1..f989e1d 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -21,6 +21,13 @@ from fms_mo.prep import available_packages import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import +# ============================================================================ +# Constants +# ============================================================================ + +# FP8 E4M3 maximum value +FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + # ============================================================================ # Helper Functions # ============================================================================ @@ -44,21 +51,18 @@ def initialize_fp8_weights( # Create random float weights float_weights = torch.randn(out_features, in_features) - # Calculate FP8 E4M3 max value (448.0) - fp8_max = torch.finfo(torch.float8_e4m3fn).max - # Set appropriate scales based on strategy using absmax if weight_strategy == "tensor": # Per-tensor: single scale for entire weight matrix absmax = float_weights.abs().max() - scale = absmax / fp8_max + scale = absmax / FP8_E4M3_MAX # Ensure scale is not zero scale = torch.clamp(scale, min=1e-12) fp8_linear.weight_scale.fill_(scale.item()) else: # channel (per-row for weight matrix) # Per-channel: one scale per output 
channel (row) absmax = float_weights.abs().amax(dim=1) - scale = absmax / fp8_max + scale = absmax / FP8_E4M3_MAX # Ensure scales are not zero scale = torch.clamp(scale, min=1e-12) # Reshape to match weight_scale parameter shape (out_features, 1) @@ -66,7 +70,7 @@ def initialize_fp8_weights( # Quantize weights to FP8 quantized_weights = (float_weights / fp8_linear.weight_scale).clamp( - -fp8_max, fp8_max + -FP8_E4M3_MAX, FP8_E4M3_MAX ) fp8_linear.weight.copy_(quantized_weights.to(torch.float8_e4m3fn)) @@ -94,19 +98,18 @@ def initialize_fp8_input_scale( with torch.no_grad(): # For static quantization, use a representative input to calculate scales sample_input = torch.randn(batch_size, seq_len, in_features) - fp8_max = torch.finfo(torch.float8_e4m3fn).max if activation_strategy == "tensor": # Per-tensor: single scale for entire activation absmax = sample_input.abs().max() - scale = absmax / fp8_max + scale = absmax / FP8_E4M3_MAX scale = torch.clamp(scale, min=1e-12) fp8_linear.input_scale.fill_(scale.item()) else: # token # For per-token static quantization, use a calibrated scale # based on representative input statistics absmax = sample_input.abs().max() - scale = absmax / fp8_max + scale = absmax / FP8_E4M3_MAX scale = torch.clamp(scale, min=1e-12) # Fill all scales with the same representative value fp8_linear.input_scale.fill_(scale.item()) From e158bd26042942332fd72401165763e260ac6861 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 21:05:46 -0400 Subject: [PATCH 7/9] minor updates Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index a9b3543..f4c10aa 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -34,7 +34,6 @@ # Gated torchao imports for FP8 implementation if available_packages["fms"] and available_packages["torchao"]: - # Third Party from 
fms.modules.linear import ( LinearModuleShardingInfo, From ef7357615304ccdec39434e66e6cc62c0d1f3a89 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 21:12:58 -0400 Subject: [PATCH 8/9] remove static activation test Signed-off-by: Andrea Fasoli --- tests/aiu_addons/test_fp8_addon.py | 62 +++++------------------------- 1 file changed, 9 insertions(+), 53 deletions(-) diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index f989e1d..714bd0a 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -79,42 +79,6 @@ def initialize_fp8_weights( fp8_linear.bias.copy_(torch.randn(out_features)) -def initialize_fp8_input_scale( - fp8_linear, - activation_strategy: str, - batch_size: int, - seq_len: int, - in_features: int, -) -> None: - """Initialize static input scale for FP8Linear. - - Args: - fp8_linear: FP8Linear module to initialize - activation_strategy: "tensor" or "token" for activation quantization - batch_size: Batch size for sample input - seq_len: Sequence length for sample input - in_features: Input feature dimension - """ - with torch.no_grad(): - # For static quantization, use a representative input to calculate scales - sample_input = torch.randn(batch_size, seq_len, in_features) - - if activation_strategy == "tensor": - # Per-tensor: single scale for entire activation - absmax = sample_input.abs().max() - scale = absmax / FP8_E4M3_MAX - scale = torch.clamp(scale, min=1e-12) - fp8_linear.input_scale.fill_(scale.item()) - else: # token - # For per-token static quantization, use a calibrated scale - # based on representative input statistics - absmax = sample_input.abs().max() - scale = absmax / FP8_E4M3_MAX - scale = torch.clamp(scale, min=1e-12) - # Fill all scales with the same representative value - fp8_linear.input_scale.fill_(scale.item()) - - # ============================================================================ # Pytest Fixtures # 
============================================================================ @@ -179,23 +143,22 @@ def test_fp8_op() -> None: reason="FMS and torchao required to run this test", ) @pytest.mark.parametrize( - "weight_strategy,activation_strategy,dynamic_activation", + "weight_strategy,activation_strategy", [ - ("tensor", "tensor", True), # Per-tensor weights + per-tensor activations - ("tensor", "token", True), # Per-tensor weights + per-token activations - ("channel", "tensor", True), # Per-channel weights + per-tensor activations - ("channel", "token", True), # Per-channel weights + per-token activations + ("tensor", "tensor"), # Per-tensor W + per-tensor dynamic A + ("tensor", "token"), # Per-tensor W + per-token dynamic A + ("channel", "tensor"), # Per-channel W + per-tensor dynamic A + ("channel", "token"), # Per-channel W + per-token dynamic A ], ) def test_fp8_linear_cpu_support( weight_strategy: str, activation_strategy: str, - dynamic_activation: bool, fp8_test_dimensions: dict, ) -> None: """Test FP8Linear on CPU with different quantization strategies. - This test ensures that FP8Linear works correctly on CPU, including: + This test ensures that FP8Linear works correctly on CPU with: - Per-tensor quantization (native support in PyTorch 2.10+) - Per-channel/per-token quantization (uses fallback path in PyTorch 2.10+) @@ -203,9 +166,8 @@ def test_fp8_linear_cpu_support( and per-token quantization require a fallback to dequantize + regular matmul. 
Args: - weight_strategy: "tensor" or "channel" for weight quantization - activation_strategy: "tensor" or "token" for activation quantization - dynamic_activation: Whether to use dynamic activation quantization + weight_strategy: "tensor" or "channel" weight quantization + activation_strategy: "tensor" or "token" dynamic activation quantization fp8_test_dimensions: Test dimensions fixture """ # Local @@ -227,7 +189,7 @@ def test_fp8_linear_cpu_support( "input_activations": { "strategy": activation_strategy, "symmetric": True, - "dynamic": dynamic_activation, + "dynamic": True, }, } @@ -242,12 +204,6 @@ def test_fp8_linear_cpu_support( # Initialize weights using helper function initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features) - # Initialize input scale if static quantization - if not dynamic_activation: - initialize_fp8_input_scale( - fp8_linear, activation_strategy, batch_size, seq_len, in_features - ) - # Create input tensor on CPU x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16) From bc5de3ee238f8d54c9a83a1b3212ae5f8ed5410f Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Thu, 19 Mar 2026 21:21:58 -0400 Subject: [PATCH 9/9] fix pylint false positive Signed-off-by: Andrea Fasoli --- tests/aiu_addons/test_fp8_addon.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index 714bd0a..13ee1a9 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -151,7 +151,7 @@ def test_fp8_op() -> None: ("channel", "token"), # Per-channel W + per-token dynamic A ], ) -def test_fp8_linear_cpu_support( +def test_fp8_linear_cpu_support( # pylint: disable=redefined-outer-name weight_strategy: str, activation_strategy: str, fp8_test_dimensions: dict, @@ -225,7 +225,7 @@ def test_fp8_linear_cpu_support( not available_packages["torchao"] or not available_packages["fms"], reason="FMS and torchao required to 
run this test", ) -def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None: +def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None: # pylint: disable=redefined-outer-name """Test FP8Linear on CPU with only weight quantization (no activation quantization). This tests the code path where activations are not quantized but weights are FP8.