Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions fms_mo/aiu_addons/fp8/fp8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
"""Implement FP8 linear module to be loaded via FMS."""

# Standard
from importlib.metadata import version
from typing import Any, Mapping

# Third Party
from packaging.version import Version
import torch

# Local
Expand All @@ -27,6 +29,9 @@
# torch.nn.functional.linear not recognized as callable
# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482

TORCH_VERSION = Version(torch.__version__.split("+")[0])
SUPPORTS_CPU_PER_CHANNEL_FP8 = Version("2.10") > TORCH_VERSION

# Gated torchao imports for FP8 implementation
if available_packages["fms"] and available_packages["torchao"]:
# Third Party
Expand Down Expand Up @@ -213,7 +218,11 @@ def _construct_qweight_structure(self) -> "AffineQuantizedTensor":

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""If input quantization is active, compute FP8xFP8 addmm leveraging torchao
functionalities. Otherwise compute non-quantized addmm."""
functionalities. Otherwise compute non-quantized addmm.

In PyTorch 2.10, torch._scaled_mm only supports FP8 on CPU when
quantization is per-tensor. In this case, we perform a mock FP8xFP8 matmul.
"""
# fp8 weight tensor for torchao
qweight: AffineQuantizedTensor = self._construct_qweight_structure()

Expand All @@ -234,6 +243,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs)

# Check if we need CPU fallback for per-channel quantization
is_cpu = qx.device.type == "cpu"
is_per_tensor = (
self.linear_config["weights"]["strategy"] == "tensor"
and self.linear_config["input_activations"]["strategy"] == "tensor"
)

# Perform mock FP8xFP8 matmul
if is_cpu and not is_per_tensor and not SUPPORTS_CPU_PER_CHANNEL_FP8:
# Check torchao version without loading the full package
if Version("0.11") < Version(version("torchao")):
raise NotImplementedError(
"Fallback path for FP8 matmul on CPU is not supported "
"on torchao > 0.11."
)
x_dequant = qx.dequantize()
w_dequant = qweight.dequantize()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we expect this to affect the quality significantly ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If anything it'll improve it on cpu

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense. Since we use these numbers to compare against accelerator results, this can cause wider deviation between those results? Unless the diff is quite small.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're downcasting back to fp8 anyways, so it shouldn't be too different.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also expect a very minimal discrepancy in terms of generation compared to the earlier operation.

There may be some runtime overhead, as this new fallback is likely less performant than calling torch._scaled_mm. To clarify: potential overheads on CPU validation only, no impact at all on AIU runtime.

out = torch.nn.functional.linear(
x_dequant.to(w_dequant.dtype),
w_dequant,
self.bias if self.has_bias else None,
)
return out.to(x.dtype)

# Copied from torchao _linear_fp8_act_fp8_weight_impl
# (with changes to support fp8 out)
scaled_mm_config = Float8MMConfig(use_fast_accum=True)
Expand Down Expand Up @@ -276,10 +309,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
).reshape(out_shape)

# activations not quantized, dequant fp8 weight and do regular matmul
w_dequant = qweight.dequantize()
out = torch.nn.functional.linear(
x, qweight.dequantize(), self.bias if self.has_bias else None
x.to(w_dequant.dtype), w_dequant, self.bias if self.has_bias else None
)
return out
return out.to(x.dtype)

def __repr__(self) -> str:
return (
Expand Down
222 changes: 219 additions & 3 deletions tests/aiu_addons/test_fp8_addon.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,84 @@
from fms_mo.prep import available_packages
import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import

# ============================================================================
# Constants
# ============================================================================

# Maximum representable value of torch.float8_e4m3fn; used as the
# symmetric quantization bound when computing absmax scales below.
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# ============================================================================
# Helper Functions
# ============================================================================


def initialize_fp8_weights(
    fp8_linear,
    weight_strategy: str,
    in_features: int,
    out_features: int,
) -> None:
    """Populate an FP8Linear module with random weights quantized via absmax scaling.

    Draws full-precision reference weights, derives the quantization scale(s)
    from their absolute maxima, writes the scale(s) into ``weight_scale``, and
    stores the rescaled weights as ``torch.float8_e4m3fn`` in ``weight``.

    Args:
        fp8_linear: FP8Linear module to initialize
        weight_strategy: "tensor" or "channel" for weight quantization
        in_features: Input feature dimension
        out_features: Output feature dimension
    """
    with torch.no_grad():
        # Reference weights in full precision
        ref_weights = torch.randn(out_features, in_features)

        if weight_strategy == "tensor":
            # Per-tensor: one scale for the whole weight matrix
            raw_scale = ref_weights.abs().max() / FP8_E4M3_MAX
            # Guard against an all-zero weight matrix producing scale == 0
            safe_scale = torch.clamp(raw_scale, min=1e-12)
            fp8_linear.weight_scale.fill_(safe_scale.item())
        else:  # "channel": one scale per output row
            raw_scale = ref_weights.abs().amax(dim=1) / FP8_E4M3_MAX
            safe_scale = torch.clamp(raw_scale, min=1e-12)
            # weight_scale parameter is laid out as (out_features, 1)
            fp8_linear.weight_scale.copy_(safe_scale.reshape(-1, 1))

        # Rescale into the FP8 E4M3 representable range and downcast
        scaled = (ref_weights / fp8_linear.weight_scale).clamp(
            -FP8_E4M3_MAX, FP8_E4M3_MAX
        )
        fp8_linear.weight.copy_(scaled.to(torch.float8_e4m3fn))

        # Bias stays full precision when the module carries one
        if fp8_linear.has_bias:
            fp8_linear.bias.copy_(torch.randn(out_features))


# ============================================================================
# Pytest Fixtures
# ============================================================================


@pytest.fixture
def fp8_test_dimensions():
    """Shared tensor dimensions used by the FP8Linear CPU tests."""
    dims = dict(batch_size=2, seq_len=4, in_features=8, out_features=16)
    return dims


# ============================================================================
# Tests
# ============================================================================


def test_fp8_registration() -> None:
"""
Expand All @@ -44,9 +122,10 @@ def test_fp8_registration() -> None:
reason="FP8 is only available on GPUs with device level 8.9 or higher",
)
def test_fp8_op() -> None:
"""Validate output shapes of GPTQ W4A16 tensors.
Note: this AIU-compatible operation only returns a zero tensor of the
expected shape, it does not perform a real W4A16 matmul operation.
"""Validate output shapes of FP8 attention operation.

Tests the FP8 attention compute operation to ensure it produces
outputs with the expected shape.
"""
# Local
from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op
Expand All @@ -57,3 +136,140 @@ def test_fp8_op() -> None:

out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
assert out.size() == query.size()


@pytest.mark.skipif(
    not available_packages["torchao"] or not available_packages["fms"],
    reason="FMS and torchao required to run this test",
)
@pytest.mark.parametrize(
    "weight_strategy,activation_strategy",
    [
        ("tensor", "tensor"),  # per-tensor W, per-tensor dynamic A
        ("tensor", "token"),  # per-tensor W, per-token dynamic A
        ("channel", "tensor"),  # per-channel W, per-tensor dynamic A
        ("channel", "token"),  # per-channel W, per-token dynamic A
    ],
)
def test_fp8_linear_cpu_support(  # pylint: disable=redefined-outer-name
    weight_strategy: str,
    activation_strategy: str,
    fp8_test_dimensions: dict,
) -> None:
    """Exercise FP8Linear on CPU across quantization strategy combinations.

    Covers both the natively supported per-tensor path (PyTorch 2.10+) and
    the per-channel/per-token combinations, which on PyTorch 2.10+ take the
    dequantize-then-matmul fallback because CPU FP8 matmul is per-tensor only.

    Args:
        weight_strategy: "tensor" or "channel" weight quantization
        activation_strategy: "tensor" or "token" dynamic activation quantization
        fp8_test_dimensions: Test dimensions fixture
    """
    # Local
    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear

    # Unpack the fixture dimensions
    dims = fp8_test_dimensions
    batch_size, seq_len = dims["batch_size"], dims["seq_len"]
    in_features, out_features = dims["in_features"], dims["out_features"]

    # Quantization configuration under test
    linear_config = {
        "weights": {
            "strategy": weight_strategy,
            "symmetric": True,
            "dynamic": False,
        },
        "input_activations": {
            "strategy": activation_strategy,
            "symmetric": True,
            "dynamic": True,
        },
    }

    # Build the module and load absmax-scaled random FP8 weights
    fp8_linear = FP8Linear(
        in_features=in_features,
        out_features=out_features,
        bias=True,
        linear_config=linear_config,
    )
    initialize_fp8_weights(fp8_linear, weight_strategy, in_features, out_features)

    # CPU input in bf16; forward must complete without raising
    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
    output = fp8_linear(x)

    # Shape is preserved up to the output feature dimension
    assert output.shape == (batch_size, seq_len, out_features)

    # All values must be finite (neither NaN nor Inf)
    assert not torch.isnan(output).any()
    assert not torch.isinf(output).any()

    # Output dtype must round-trip back to the input dtype
    assert output.dtype == x.dtype


@pytest.mark.skipif(
    not available_packages["torchao"] or not available_packages["fms"],
    reason="FMS and torchao required to run this test",
)
def test_fp8_linear_cpu_no_activation_quantization(fp8_test_dimensions: dict) -> None:  # pylint: disable=redefined-outer-name
    """Exercise FP8Linear on CPU with FP8 weights but unquantized activations.

    Covers the branch where ``input_activations`` is None: the FP8 weight is
    dequantized and a regular matmul is performed against the float input.

    Args:
        fp8_test_dimensions: Test dimensions fixture
    """
    # Local
    from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear

    # Unpack the fixture dimensions
    dims = fp8_test_dimensions
    batch_size, seq_len = dims["batch_size"], dims["seq_len"]
    in_features, out_features = dims["in_features"], dims["out_features"]

    # Weight-only quantization: activations are left in float
    linear_config = {
        "weights": {
            "strategy": "channel",
            "symmetric": True,
            "dynamic": False,
        },
        "input_activations": None,  # No activation quantization
    }

    # Build the module and load absmax-scaled random FP8 weights
    fp8_linear = FP8Linear(
        in_features=in_features,
        out_features=out_features,
        bias=True,
        linear_config=linear_config,
    )
    initialize_fp8_weights(fp8_linear, "channel", in_features, out_features)

    # CPU input in bf16; forward must complete without raising
    x = torch.randn(batch_size, seq_len, in_features, dtype=torch.bfloat16)
    output = fp8_linear(x)

    # Shape preserved and all values finite
    assert output.shape == (batch_size, seq_len, out_features)
    assert not torch.isnan(output).any()
    assert not torch.isinf(output).any()
Loading