From b144632ebd7e8a2080fca1f5aaa7b8a3d90d9729 Mon Sep 17 00:00:00 2001 From: avi singhal Date: Thu, 4 Dec 2025 07:04:06 +0000 Subject: [PATCH 01/18] add MXFP8 all gather support --- torchao/prototype/mx_formats/mx_tensor.py | 77 +++++++++++++ .../prototype/tests/test_mxfp8_allgather.py | 108 ++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 torchao/prototype/tests/test_mxfp8_allgather.py diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 74f37bc2df..b22c58cc61 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -840,3 +840,80 @@ def mx_select(func, types, args, kwargs): old_mx_tensor._is_swizzled_scales, ) return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor) + +@implements([torch.ops._c10d_functional.all_gather_into_tensor.default]) +def mx_all_gather(func, types, args, kwargs): + """ + All-gather for MXTensor + + Args: + func: The operation (all_gather_into_tensor) + types: Tensor types involved + args: (mx_tensor, group_tag, ...) + kwargs: Additional arguments + """ + mx_tensor = args[0] + group_tag = args[1] if len(args) > 1 else "default" + + # Gather both data and scale + gathered_qdata = torch.ops._c10d_functional.all_gather_into_tensor.default( + mx_tensor.qdata, # The quantized data + group_tag, + *args[2:], + **kwargs + ) + + gathered_scale = torch.ops._c10d_functional.all_gather_into_tensor.default( + mx_tensor._scale_e8m0.view(torch.uint8), # The scale factors, Need to cast to uint8 as float8_e8m0fnu is not support for all gather. + group_tag, + *args[2:], + **kwargs + ) + + gathered_scale=gathered_scale.view(torch.float8_e8m0fnu) + + # Return new MXTensor with gathered data + return MXTensor( + gathered_qdata, + gathered_scale, + mx_tensor._elem_dtype, + mx_tensor._block_size, + mx_tensor._orig_dtype, + mx_tensor._gemm_kernel_choice, + mx_tensor._pack_fp6, + mx_tensor.act_quant_kwargs + ) + +@implements([torch.ops._c10d_functional.wait_tensor.default]) +def mx_wait_tensor(func, types, args, kwargs): + """ + Wait for async collective to complete on MXTensor + + This is called after collectives like all_gather to ensure + the operation has completed before using the tensor. 
+ """ + mx_tensor = args[0] + + # Wait on both components + waited_qdata = torch.ops._c10d_functional.wait_tensor.default( + mx_tensor.qdata, + *args[1:], + **kwargs + ) + + waited_scale = torch.ops._c10d_functional.wait_tensor.default( + mx_tensor._scale_e8m0, + *args[1:], + **kwargs + ) + + return MXTensor( + waited_qdata, + waited_scale, + mx_tensor._elem_dtype, + mx_tensor._block_size, + mx_tensor._orig_dtype, + mx_tensor._gemm_kernel_choice, + mx_tensor._pack_fp6, + mx_tensor.act_quant_kwargs + ) diff --git a/torchao/prototype/tests/test_mxfp8_allgather.py b/torchao/prototype/tests/test_mxfp8_allgather.py new file mode 100644 index 0000000000..bf919b2a37 --- /dev/null +++ b/torchao/prototype/tests/test_mxfp8_allgather.py @@ -0,0 +1,108 @@ +import pytest +import torch + +if not torch.cuda.is_available() or torch.cuda.get_device_capability() != (10, 0): + pytest.skip("Test requires CUDA build on SM100", allow_module_level=True) + +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, +) +import torch.distributed as dist +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, +) +from torchao.prototype.mx_formats.mx_tensor import MXTensor + + +@instantiate_parametrized_tests +class MXFP8OnDeviceAllGatherTest(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @property + def world_size(self) -> int: + return 4 + + @property + def device(self) -> torch.device: + return torch.device(f"cuda:{self.rank}") + + def _init_process(self): + torch.cuda.set_device(self.device) + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + torch.manual_seed(42) + + def test_allgather(self): + self._init_process() + try: + torch.manual_seed(42) + golden_qdata = torch.randint(0, 256, (256, 512), dtype=torch.uint8).to(torch.float8_e5m2).to(self.device) + + # Random scale factors (typically float32 or uint8 for e8m0) + golden_scale = torch.randint(0, 256, (256, 16), dtype=torch.uint8).view(torch.float8_e8m0fnu).to(self.device) + + # Create golden MXTensor + golden_mx = MXTensor( + golden_qdata, + golden_scale, + elem_dtype=torch.float8_e5m2, + block_size=32, + orig_dtype=torch.float32, + gemm_kernel_choice=None, + pack_fp6=None, + act_quant_kwargs=None + ) + + world_size = self.world_size + # Each rank gets its shard (split along dim 0) + shard_size = golden_qdata.shape[0] // world_size # 2 rows per rank + start_idx = self.rank * shard_size + end_idx = (self.rank + 1) * shard_size + + # Create local MXTensor from shard + local_mx = MXTensor( + golden_qdata[start_idx:end_idx].clone().to(self.device), + golden_scale[start_idx:end_idx].clone().to(self.device), + elem_dtype=torch.float8_e5m2, + block_size=32, + orig_dtype=torch.float32, + gemm_kernel_choice=None, + pack_fp6=None, + act_quant_kwargs=None + ) + + # Perform all_gather + gathered_mx = torch.ops._c10d_functional.all_gather_into_tensor.default( + local_mx, + world_size, + "0", + ) + gathered_mx = torch.ops._c10d_functional.wait_tensor.default(gathered_mx) + + # ✅ Verify type + assert isinstance(gathered_mx, MXTensor), f"Expected MXTensor, got {type(gathered_mx)}" + + # ✅ Verify shape + assert gathered_mx.shape == golden_mx.shape, \ + f"Shape mismatch: {gathered_mx.shape} vs {golden_mx.shape}" + + # ✅ Verify qdata matches golden exactly + if not torch.equal(gathered_mx.qdata, golden_qdata): + assert False, "qdata mismatch" + + 
# ✅ Verify scale matches golden exactly + if not torch.equal(gathered_mx._scale_e8m0.view(torch.uint8), golden_scale.view(torch.uint8)): + assert False, "scale mismatch" + + assert gathered_mx._block_size == 32 + + finally: + dist.destroy_process_group() \ No newline at end of file From 3c0e6eddc62008a6712108844e6ba3d3131b2cab Mon Sep 17 00:00:00 2001 From: avi singhal Date: Thu, 4 Dec 2025 07:05:07 +0000 Subject: [PATCH 02/18] added TODO for future feature --- torchao/prototype/mx_formats/mx_tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index b22c58cc61..56b55982b5 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -855,6 +855,8 @@ def mx_all_gather(func, types, args, kwargs): mx_tensor = args[0] group_tag = args[1] if len(args) > 1 else "default" + #TODO: Add support for concat CC as a future optimization + # Gather both data and scale gathered_qdata = torch.ops._c10d_functional.all_gather_into_tensor.default( mx_tensor.qdata, # The quantized data From ef6ed8fbbe7b1faa6f0dc8625f6536731c4071c2 Mon Sep 17 00:00:00 2001 From: avi singhal Date: Thu, 4 Dec 2025 16:50:49 +0000 Subject: [PATCH 03/18] remove emoji from comment --- torchao/prototype/tests/test_mxfp8_allgather.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/prototype/tests/test_mxfp8_allgather.py b/torchao/prototype/tests/test_mxfp8_allgather.py index bf919b2a37..6f09d1382f 100644 --- a/torchao/prototype/tests/test_mxfp8_allgather.py +++ b/torchao/prototype/tests/test_mxfp8_allgather.py @@ -87,18 +87,18 @@ def test_allgather(self): ) gathered_mx = torch.ops._c10d_functional.wait_tensor.default(gathered_mx) - # ✅ Verify type + # Verify type assert isinstance(gathered_mx, MXTensor), f"Expected MXTensor, got {type(gathered_mx)}" - # ✅ Verify shape + # Verify shape assert gathered_mx.shape == golden_mx.shape, \ f"Shape mismatch: {gathered_mx.shape} vs {golden_mx.shape}" - # ✅ Verify qdata matches golden exactly + # Verify qdata matches golden exactly if not torch.equal(gathered_mx.qdata, golden_qdata): assert False, "qdata mismatch" - # ✅ Verify scale matches golden exactly + # Verify scale matches golden exactly if not torch.equal(gathered_mx._scale_e8m0.view(torch.uint8), golden_scale.view(torch.uint8)): assert False, "scale mismatch" From f96d168055e7bff1886a297cbd0fa7a8a7993631 Mon Sep 17 00:00:00 2001 From: avi singhal Date: Thu, 4 Dec 2025 20:36:13 +0000 Subject: [PATCH 04/18] fixed ruff formating --- torchao/prototype/tests/test_mxfp8_allgather.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/tests/test_mxfp8_allgather.py b/torchao/prototype/tests/test_mxfp8_allgather.py index 6f09d1382f..2a001cf0fb 100644 --- a/torchao/prototype/tests/test_mxfp8_allgather.py +++ b/torchao/prototype/tests/test_mxfp8_allgather.py @@ -4,14 +4,14 @@ if not torch.cuda.is_available() or torch.cuda.get_device_capability() != (10, 0): pytest.skip("Test requires CUDA build on SM100", allow_module_level=True) +import torch.distributed as dist from torch.testing._internal.common_distributed import ( MultiProcessTestCase, ) -import torch.distributed as dist from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, - run_tests, ) + from torchao.prototype.mx_formats.mx_tensor import MXTensor From 094b01cf0a00ac80bf8321bca55dd3f47258bdf3 Mon Sep 17 00:00:00 2001 From: avi singhal Date: Fri, 5 Dec 2025 
06:08:04 +0000 Subject: [PATCH 05/18] fixed ruff formatting --- torchao/prototype/mx_formats/mx_tensor.py | 44 ++++++++-------- .../prototype/tests/test_mxfp8_allgather.py | 50 ++++++++++++------- 2 files changed, 54 insertions(+), 40 deletions(-) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 56b55982b5..7d5c6b8e83 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -841,11 +841,12 @@ def mx_select(func, types, args, kwargs): ) return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor) + @implements([torch.ops._c10d_functional.all_gather_into_tensor.default]) def mx_all_gather(func, types, args, kwargs): """ All-gather for MXTensor - + Args: func: The operation (all_gather_into_tensor) types: Tensor types involved @@ -854,26 +855,28 @@ def mx_all_gather(func, types, args, kwargs): """ mx_tensor = args[0] group_tag = args[1] if len(args) > 1 else "default" - - #TODO: Add support for concat CC as a future optimization - + + # TODO: Add support for concat CC as a future optimization + # Gather both data and scale gathered_qdata = torch.ops._c10d_functional.all_gather_into_tensor.default( mx_tensor.qdata, # The quantized data group_tag, *args[2:], - **kwargs + **kwargs, ) - + gathered_scale = torch.ops._c10d_functional.all_gather_into_tensor.default( - mx_tensor._scale_e8m0.view(torch.uint8), # The scale factors, Need to cast to uint8 as float8_e8m0fnu is not support for all gather. + mx_tensor._scale_e8m0.view( + torch.uint8 + ), # The scale factors, Need to cast to uint8 as float8_e8m0fnu is not support for all gather. group_tag, *args[2:], - **kwargs + **kwargs, ) - gathered_scale=gathered_scale.view(torch.float8_e8m0fnu) - + gathered_scale = gathered_scale.view(torch.float8_e8m0fnu) + # Return new MXTensor with gathered data return MXTensor( gathered_qdata, @@ -883,32 +886,29 @@ def mx_all_gather(func, types, args, kwargs): mx_tensor._orig_dtype, mx_tensor._gemm_kernel_choice, mx_tensor._pack_fp6, - mx_tensor.act_quant_kwargs + mx_tensor.act_quant_kwargs, ) + @implements([torch.ops._c10d_functional.wait_tensor.default]) def mx_wait_tensor(func, types, args, kwargs): """ Wait for async collective to complete on MXTensor - + This is called after collectives like all_gather to ensure the operation has completed before using the tensor. 
""" mx_tensor = args[0] - + # Wait on both components waited_qdata = torch.ops._c10d_functional.wait_tensor.default( - mx_tensor.qdata, - *args[1:], - **kwargs + mx_tensor.qdata, *args[1:], **kwargs ) - + waited_scale = torch.ops._c10d_functional.wait_tensor.default( - mx_tensor._scale_e8m0, - *args[1:], - **kwargs + mx_tensor._scale_e8m0, *args[1:], **kwargs ) - + return MXTensor( waited_qdata, waited_scale, @@ -917,5 +917,5 @@ def mx_wait_tensor(func, types, args, kwargs): mx_tensor._orig_dtype, mx_tensor._gemm_kernel_choice, mx_tensor._pack_fp6, - mx_tensor.act_quant_kwargs + mx_tensor.act_quant_kwargs, ) diff --git a/torchao/prototype/tests/test_mxfp8_allgather.py b/torchao/prototype/tests/test_mxfp8_allgather.py index 2a001cf0fb..178195a823 100644 --- a/torchao/prototype/tests/test_mxfp8_allgather.py +++ b/torchao/prototype/tests/test_mxfp8_allgather.py @@ -44,11 +44,19 @@ def test_allgather(self): self._init_process() try: torch.manual_seed(42) - golden_qdata = torch.randint(0, 256, (256, 512), dtype=torch.uint8).to(torch.float8_e5m2).to(self.device) - + golden_qdata = ( + torch.randint(0, 256, (256, 512), dtype=torch.uint8) + .to(torch.float8_e5m2) + .to(self.device) + ) + # Random scale factors (typically float32 or uint8 for e8m0) - golden_scale = torch.randint(0, 256, (256, 16), dtype=torch.uint8).view(torch.float8_e8m0fnu).to(self.device) - + golden_scale = ( + torch.randint(0, 256, (256, 16), dtype=torch.uint8) + .view(torch.float8_e8m0fnu) + .to(self.device) + ) + # Create golden MXTensor golden_mx = MXTensor( golden_qdata, @@ -58,15 +66,15 @@ def test_allgather(self): orig_dtype=torch.float32, gemm_kernel_choice=None, pack_fp6=None, - act_quant_kwargs=None + act_quant_kwargs=None, ) - + world_size = self.world_size # Each rank gets its shard (split along dim 0) shard_size = golden_qdata.shape[0] // world_size # 2 rows per rank start_idx = self.rank * shard_size end_idx = (self.rank + 1) * shard_size - + # Create local MXTensor from shard local_mx = MXTensor( golden_qdata[start_idx:end_idx].clone().to(self.device), @@ -76,9 +84,9 @@ def test_allgather(self): orig_dtype=torch.float32, gemm_kernel_choice=None, pack_fp6=None, - act_quant_kwargs=None + act_quant_kwargs=None, ) - + # Perform all_gather gathered_mx = torch.ops._c10d_functional.all_gather_into_tensor.default( local_mx, @@ -86,23 +94,29 @@ def test_allgather(self): "0", ) gathered_mx = torch.ops._c10d_functional.wait_tensor.default(gathered_mx) - + # Verify type - assert isinstance(gathered_mx, MXTensor), f"Expected MXTensor, got {type(gathered_mx)}" - + assert isinstance(gathered_mx, MXTensor), ( + f"Expected MXTensor, got {type(gathered_mx)}" + ) + # Verify shape - assert gathered_mx.shape == golden_mx.shape, \ + assert gathered_mx.shape == golden_mx.shape, ( f"Shape mismatch: {gathered_mx.shape} vs {golden_mx.shape}" - + ) + # Verify qdata matches golden exactly if not torch.equal(gathered_mx.qdata, golden_qdata): assert False, "qdata mismatch" - + # Verify scale matches golden exactly - if not torch.equal(gathered_mx._scale_e8m0.view(torch.uint8), golden_scale.view(torch.uint8)): + if not torch.equal( + gathered_mx._scale_e8m0.view(torch.uint8), + golden_scale.view(torch.uint8), + ): assert False, "scale mismatch" - + assert gathered_mx._block_size == 32 finally: - dist.destroy_process_group() \ No newline at end of file + dist.destroy_process_group() From 243001ba22fdb481c6f96c977afadbcbe03ff8a7 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 4 Dec 2025 06:21:36 -0500 Subject: [PATCH 06/18] add mxfp8 
and nvfp4 to Llama eval scripts (#3394) Update [ghstack-poisoned] --- torchao/_models/llama/eval.py | 38 ++++++++++++++++++++++++-- torchao/prototype/mx_formats/README.md | 23 +++++++++++++++- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index df071fe9d2..002045215c 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -16,6 +16,10 @@ import torchao from torchao._models.llama.model import prepare_inputs_for_model +from torchao.prototype.mx_formats.inference_workflow import ( + MXDynamicActivationMXWeightConfig, + NVFP4DynamicActivationNVFP4WeightConfig, +) from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, @@ -170,6 +174,8 @@ def run_evaluation( quantize_( model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity), + filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) + and fqn != "output", ) if quantization == "float8_a1x128_w128x128": config = Float8DynamicActivationFloat8WeightConfig( @@ -177,8 +183,34 @@ def run_evaluation( activation_value_lb=1e-12, ) # TODO(future): all workflows in this file should be skipping quantization - # of `lm_head` + # of `lm_head`/`output` quantize_(model, config) + if quantization == "mxfp8": + config = MXDynamicActivationMXWeightConfig( + activation_dtype=torch.float8_e4m3fn, + weight_dtype=torch.float8_e4m3fn, + ) + # TODO(future): all workflows in this file should be skipping quantization + # of `lm_head`/`output` + quantize_( + model, + config, + filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) + and fqn != "output", + ) + if quantization == "nvfp4": + config = NVFP4DynamicActivationNVFP4WeightConfig( + use_dynamic_per_tensor_scale=True, + use_triton_kernel=True, + ) + # TODO(future): all workflows in this file should be skipping quantization + # of `lm_head`/`output` + quantize_( + model, + config, + filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) + and fqn != "output", + ) if "autoround" in quantization: from transformers import AutoTokenizer @@ -284,8 +316,8 @@ def run_evaluation( if compile: # TODO(future PR): clean this up - if quantization == "float8_a1x128_w128x128": - # we don't need max-autotune for float8 blockwise quant + if quantization in ("float8_a1x128_w128x128", "mxfp8", "nvfp4"): + # we don't need max-autotune for float8 blockwise or mxfp8 quant model = torch.compile(model) else: model = torch.compile(model, mode="max-autotune", fullgraph=True) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index e644fee8fe..6873c2d81f 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -223,7 +223,28 @@ To reproduce this on supported hardware, you can run the following command: ## inference -Coming soon! +Eval results on LLaMa 3.1 8B on common tasks. `mxfp8` and `nvfp4` recipes quantize all linears except `lm_head`. + +Note: the accuracy results below are WIP and are not optimized yet. 
+ +| recipe | wikitext word_perplexity | winogrande | +| ------ | -------- | ---------- | +| bfloat16 (baseline) | 7.5472105433748435 | 0.7426992896606156 | +| mxfp8 | 7.609070006132819 | 0.7292817679558011 | +| nvfp4 | 8.44478255417328 | 0.7182320441988951 | + +To reproduce: + +```bash +# baseline +python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande + +# mxfp8 +python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande --quantization mxfp8 + +# nvfp4 +python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --print_model --tasks wikitext winogrande --quantization nvfp4 +``` # testing From 88e2bb9e7e4d079faa7e0a57d844865148b7d9a6 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 4 Dec 2025 06:22:52 -0500 Subject: [PATCH 07/18] flip mx inference scaling setting to RCEIL (#3428) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- torchao/prototype/mx_formats/README.md | 2 +- torchao/prototype/mx_formats/inference_workflow.py | 2 ++ torchao/prototype/mx_formats/mx_tensor.py | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 6873c2d81f..a367ae8336 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -230,7 +230,7 @@ Note: the accuracy results below are WIP and are not optimized yet. | recipe | wikitext word_perplexity | winogrande | | ------ | -------- | ---------- | | bfloat16 (baseline) | 7.5472105433748435 | 0.7426992896606156 | -| mxfp8 | 7.609070006132819 | 0.7292817679558011 | +| mxfp8 | 7.605192917647689 | 0.7355958958168903 | | nvfp4 | 8.44478255417328 | 0.7182320441988951 | To reproduce: diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index c771015ccd..5a7b5939c1 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -85,6 +85,7 @@ def _mx_inference_linear_transform( block_size=config.block_size, kernel_preference=config.kernel_preference, is_swizzled_scales=True, + scaling_mode=ScaleCalculationMode.RCEIL, ) # Convert weight to MX Tensor @@ -95,6 +96,7 @@ def _mx_inference_linear_transform( kernel_preference=config.kernel_preference, act_quant_kwargs=act_quant_kwargs, is_swizzled_scales=True, + scaling_mode=ScaleCalculationMode.RCEIL, ) module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 7d5c6b8e83..b7bcb238d9 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -87,6 +87,7 @@ class QuantizeTensorToMXKwargs(QuantizeTensorKwargs): elem_dtype: Union[torch.dtype, str] = torch.float8_e4m3fn block_size: int = 32 + # TODO(future PR): flip the scaling_mode default to RCEIL scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR kernel_preference: KernelPreference = KernelPreference.EMULATED is_swizzled_scales: bool = False @@ -533,6 +534,7 @@ def to_mx( data_hp: torch.Tensor, elem_dtype: Union[torch.dtype, str], block_size: int = BLOCK_SIZE_DEFAULT, + # TODO(future PR): flip the scaling_mode default to RCEIL scaling_mode: ScaleCalculationMode = 
ScaleCalculationMode.FLOOR, # TODO(future PR): switch default gemm to cublas kernel_preference: KernelPreference = KernelPreference.EMULATED, From ba74266b441858c67a0b11a49697a9daf28b5ecd Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 4 Dec 2025 15:05:00 -0500 Subject: [PATCH 08/18] add CLAUDE.local.md to gitignore (#3437) Summary: taking claude code for a more thorough spin, will start with local instructions and will see what makes sense to upstream Test Plan: Reviewers: Subscribers: Tasks: Tags: --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index a68bce77ca..cccb8f7f72 100644 --- a/.gitignore +++ b/.gitignore @@ -378,3 +378,6 @@ checkpoints/ # Experimental torchao/experimental/cmake-out torchao/experimental/deps + +# local claude code files +CLAUDE.local.md From 11b2401bd1b437fb11013a4ec83b2aa7db2c7469 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 4 Dec 2025 13:36:57 -0800 Subject: [PATCH 09/18] bump python version in tutorial ci workflow (#3439) --- .github/workflows/run_tutorials.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run_tutorials.yml b/.github/workflows/run_tutorials.yml index 3a8ee4df6b..be13a5a9d4 100644 --- a/.github/workflows/run_tutorials.yml +++ b/.github/workflows/run_tutorials.yml @@ -19,7 +19,7 @@ jobs: - name: Setup miniconda uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: - python-version: "3.9" + python-version: "3.11" - name: Run tutorials shell: bash From 6081c0c2e32964a50e9edd17c37fdb3edefb3105 Mon Sep 17 00:00:00 2001 From: Xia Weiwen Date: Fri, 5 Dec 2025 07:58:55 +0800 Subject: [PATCH 10/18] [CPU] Reland qconv fp8 fusion passes (#3433) * [Reland][PT2E][X86] Add Inductor fusion passes of float8 qconv for X86Inductor backend * add torch version check for Qconv FP8 UTs * fix format issue * Skip tests for ROCm --------- Co-authored-by: Sun, Jiayi --- .../pt2e/test_x86inductor_fusion.py | 292 ++++++++++++++++-- .../quantization/pt2e/inductor_passes/x86.py | 279 ++++++++++------- 2 files changed, 425 insertions(+), 146 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 520b5fbdfb..2e0a4f7738 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -44,7 +44,6 @@ from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import ( X86InductorQuantizer, ) -from torchao.testing.utils import skip_if_rocm from torchao.utils import torch_version_at_least # The dict value is match_nodes(computation_op+unary_op) @@ -93,6 +92,9 @@ skipIfNoFloat8Support = unittest.skipIf( not torch_version_at_least("2.9.0"), "Float8 requires torch 2.9+" ) +skipIfNoQConvFp8Support = unittest.skipIf( + not torch_version_at_least("2.10.0.dev"), "QConv fp8 requires torch 2.10+" +) def get_default_quantizer(is_qat, is_dynamic): @@ -138,6 +140,61 @@ def forward(self, input): return out +class FP8QDQConv2d(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + super().__init__() + self.qtype = torch.float8_e4m3fn + self.weight = torch.randn( + (out_channels, in_channels // groups, *kernel_size) + ).to(self.qtype) + self.weight_scale = 2.0 + self.scale = 2.0 + self.bias = None + if bias: + self.bias = torch.randn((out_channels,)) + self.stride = stride + self.padding = padding + self.dilation = dilation + 
self.groups = groups + + def forward(self, input): + weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( + tensor=self.weight.data, + scale=torch.tensor([self.weight_scale]), + output_dtype=torch.float, + ) + q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( + tensor=input, + scale=torch.tensor([self.scale]), + float8_dtype=self.qtype, + ) + dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( + tensor=q_input, + scale=torch.tensor([self.scale]), + output_dtype=torch.float, + ) + + return torch.nn.functional.conv2d( + dq_input, + weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def qdq(input, scale): dtype = input.dtype q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( @@ -171,9 +228,7 @@ def create_mod_info_recursion(parent): parent_child_mod_dict = generate_model_info(model) for name, mod in model.named_modules(): mod_type_str = mod.__class__.__name__ - if mod_type_str not in [ - "Linear", - ]: + if mod_type_str not in ["Linear", "Conv2d"]: continue param = mod.weight xmax = torch.max(param) @@ -190,6 +245,20 @@ def create_mod_info_recursion(parent): patched_mod.bias = mod.bias patched_mod.weight_scale = weight_scale.item() patched_mod.weight.data = q_param + elif mod_type_str in ["Conv2d"]: + patched_mod = FP8QDQConv2d( + mod.in_channels, + mod.out_channels, + mod.kernel_size, + mod.stride, + mod.padding, + mod.dilation, + mod.groups, + False, + ) + patched_mod.bias = mod.bias + patched_mod.weight_scale = weight_scale.item() + patched_mod.weight.data = q_param parent = parent_child_mod_dict[mod].parent name = parent_child_mod_dict[mod].name @@ -381,8 +450,9 @@ def _test_code_common( @unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+") +@unittest.skipIf(torch.version.hip is not None, "Not applicable to ROCm") class TestPatternMatcher(TestPatternMatcherBase): - def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): + def _qconv2d_test_helper(self, device="cpu", mixed_bf16=False, is_fp8=False): class M(torch.nn.Module): def __init__( self, @@ -408,14 +478,14 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1 # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution] - # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4), + # mixed_bf16: [dequant_node, optional(convert_element_type_4), # dequantize_per_channel, optional(convert_element_type_3), clone, convolution] self.assertEqual( counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 ) self.assertEqual( counters["inductor"]["qconv_weight_prepack_matcher_nodes"], - 18 if int8_mixed_bf16 else 12, + 18 if mixed_bf16 else 12, ) self.assertEqual( counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3 @@ -426,34 +496,53 @@ def matcher_check_fn(): (v,), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @skipIfNoONEDNN - @skip_if_rocm("Not applicable to ROCm") def test_qconv2d_cpu(self): r""" This testcase will quantize a single Conv2d module. """ self._qconv2d_test_helper("cpu") + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_fp8_cpu(self): + r""" + This testcase will quantize a single Conv2d module. 
+ """ + self._qconv2d_test_helper("cpu", is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN - @skip_if_rocm("Not applicable to ROCm") def test_qconv2d_int8_mixed_bf16(self): r""" This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. """ - self._qconv2d_test_helper(int8_mixed_bf16=True) + self._qconv2d_test_helper(mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_fp8_mixed_bf16(self): + r""" + This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. + """ + self._qconv2d_test_helper(mixed_bf16=True, is_fp8=True) def _qconv2d_unary_test_helper( self, device="cpu", - int8_mixed_bf16=False, + mixed_bf16=False, unary_op=torch.nn.ReLU(), qconv_unary_matcher_nodes=None, + is_fp8=False, ): class M(torch.nn.Module): def __init__( @@ -502,8 +591,9 @@ def matcher_check_fn(): mod, (v,), check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, matcher_check_fn=matcher_check_fn, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -514,6 +604,15 @@ def test_qconv2d_relu_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu") + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_relu_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -521,7 +620,7 @@ def test_qconv2d_relu_int8_mixed_bf16_xpu(self): r""" This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization. """ - self._qconv2d_unary_test_helper(int8_mixed_bf16=True) + self._qconv2d_unary_test_helper(mixed_bf16=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -531,6 +630,17 @@ def test_qconv2d_relu6_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_relu6_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern. + """ + self._qconv2d_unary_test_helper( + device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_hardtanh_cpu(self): @@ -539,6 +649,17 @@ def test_qconv2d_hardtanh_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_hardtanh_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + """ + self._qconv2d_unary_test_helper( + device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -551,8 +672,26 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardtanh(), - int8_mixed_bf16=True, + mixed_bf16=True, + qconv_unary_matcher_nodes=11, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_hardtanh_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. 
+ Match.nodes: + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.Hardtanh(), + mixed_bf16=True, qconv_unary_matcher_nodes=11, + is_fp8=True, ) @skipIfNoDynamoSupport @@ -563,6 +702,17 @@ def test_qconv2d_hardswish_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_hardswish_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. + """ + self._qconv2d_unary_test_helper( + device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -576,8 +726,27 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardswish(), - int8_mixed_bf16=True, + mixed_bf16=True, + qconv_unary_matcher_nodes=17, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_hardswish_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. + Match.nodes: + [qconv2d_pointwise_default, convert_element_type, add, clamp_min, + clamp_max, mul, div, convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.Hardswish(), + mixed_bf16=True, qconv_unary_matcher_nodes=17, + is_fp8=True, ) @skipIfNoDynamoSupport @@ -588,6 +757,17 @@ def test_qconv2d_silu_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_silu_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->SiLU pattern. + """ + self._qconv2d_unary_test_helper( + device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True + ) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -601,12 +781,31 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.SiLU(), - int8_mixed_bf16=True, + mixed_bf16=True, + qconv_unary_matcher_nodes=11, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_silu_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->SiLU pattern. 
+ Match.nodes: + [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, + convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.SiLU(), + mixed_bf16=True, qconv_unary_matcher_nodes=11, + is_fp8=True, ) def _qconv2d_add_test_helper( - self, device="cpu", use_relu=False, int8_mixed_bf16=False + self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False ): r""" This testcase will quantize a Conv2d->Add pattern as: @@ -680,11 +879,12 @@ def matcher_check_fn(): (v,), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, + is_fp8=is_fp8, ) def _qconv2d_add_test_helper2( - self, device="cpu", use_relu=False, int8_mixed_bf16=False + self, device="cpu", use_relu=False, mixed_bf16=False, is_fp8=False ): r""" This testcase will quantize two Conv2d->Add patterns as: @@ -743,9 +943,10 @@ def forward(self, x, x2, x3): res = self.relu2(res) return res - for add_fn, swap_inputs in itertools.product( - quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True] - ): + add_fn_list = quantization_add_fn_list + if not is_fp8: + add_fn_list = add_fn_list + quantization_inplace_add_fn_list + for add_fn, swap_inputs in itertools.product(add_fn_list, [False, True]): mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device) x = torch.randn( (1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device @@ -777,7 +978,8 @@ def matcher_check_fn(): (x, x2, x3), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -786,12 +988,27 @@ def test_qconv2d_add_cpu(self): self._qconv2d_add_test_helper() self._qconv2d_add_test_helper2() + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_add_fp8_cpu(self): + self._qconv2d_add_test_helper(is_fp8=True) + self._qconv2d_add_test_helper2(is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN def test_qconv2d_add_int8_mixed_bf16(self): - self._qconv2d_add_test_helper(int8_mixed_bf16=True) - self._qconv2d_add_test_helper2(int8_mixed_bf16=True) + self._qconv2d_add_test_helper(mixed_bf16=True) + self._qconv2d_add_test_helper2(mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_add_fp8_mixed_bf16(self): + self._qconv2d_add_test_helper(mixed_bf16=True, is_fp8=True) + self._qconv2d_add_test_helper2(mixed_bf16=True, is_fp8=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -799,12 +1016,27 @@ def test_qconv2d_add_relu_cpu(self): self._qconv2d_add_test_helper(use_relu=True) self._qconv2d_add_test_helper2(use_relu=True) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_add_relu_fp8_cpu(self): + self._qconv2d_add_test_helper(use_relu=True, is_fp8=True) + self._qconv2d_add_test_helper2(use_relu=True, is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN def test_qconv2d_add_relu_int8_mixed_bf16(self): - self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True) - self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True) + self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True) + 
self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoQConvFp8Support + def test_qconv2d_add_relu_fp8_mixed_bf16(self): + self._qconv2d_add_test_helper(use_relu=True, mixed_bf16=True, is_fp8=True) + self._qconv2d_add_test_helper2(use_relu=True, mixed_bf16=True, is_fp8=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1035,7 +1267,6 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skip_if_rocm("Not applicable to ROCm") def test_qat_qconv2d(self): r""" This testcase will quantize a single Conv2d module with qat flow. @@ -1178,7 +1409,6 @@ def test_qat_qconv2d_hardswish(self): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skip_if_rocm("Not applicable to ROCm") def test_qat_qconv2d_add(self): r""" This testcase will quantize a Conv2d->Add pattern as: @@ -1244,7 +1474,6 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skip_if_rocm("Not applicable to ROCm") def test_qat_qconv2d_add_relu(self): r""" This testcase will quantize a Conv2d->Add->ReLU pattern as: @@ -1384,7 +1613,6 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN - @skip_if_rocm("Not applicable to ROCm") def test_qconv2d_dequant_promotion_cpu(self): self._test_qconv2d_dequant_promotion_helper() diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index a0aef11541..c5280b9db0 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -167,60 +167,49 @@ def get_dequantize_per_tensor_activation_pattern( KeywordArg("w_dtype"), ) -dequantize_per_channel_to_bf16_weight_pattern = ( - _may_generate_pattern_with_dtype_convert( - dequantize_per_channel_weight_pattern, - KeywordArg("autocast_wgt_dtype"), - ) +dequantize_fp8_weight_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + output_dtype=KeywordArg("w_dtype"), ) -dequantize_per_channel_clone_weight_pattern = CallFunction( - aten.clone.default, - dequantize_per_channel_weight_pattern, - memory_format=KeywordArg("memory_format"), -) -dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( - aten.clone.default, - dequantize_per_channel_to_bf16_weight_pattern, - memory_format=KeywordArg("memory_format"), -) +def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern): + return _may_generate_pattern_with_dtype_convert( + dequant_wgt_pattern, + KeywordArg("autocast_wgt_dtype"), + ) -def get_qconv_pt2e_pattern(users=1): +def get_dequantize_clone_weight_pattern(dequant_wgt_pattern): return CallFunction( - torch.ops.onednn.qconv_pointwise.default, - KeywordArg("x"), - KeywordArg("x_scale"), - KeywordArg("x_zp"), - KeywordArg("packed_weight"), - KeywordArg("w_scale"), - KeywordArg("w_zp"), - KeywordArg("b"), - KeywordArg("stride"), - KeywordArg("padding"), - KeywordArg("dilation"), - KeywordArg("groups"), - KeywordArg("output_scale"), - KeywordArg("output_zero_point"), - KeywordArg("output_dtype"), - KeywordArg("postop_name"), - KeywordArg("postop_args"), - KeywordArg("postop_algorithm"), - _users=users, + aten.clone.default, + dequant_wgt_pattern, + memory_format=KeywordArg("memory_format"), + ) + + +def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern): + return get_dequantize_clone_weight_pattern( + get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern) ) -def get_qconv2d_binary_pt2e_pattern(users=1): +def 
get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1): + qconv_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) return CallFunction( - torch.ops.onednn.qconv2d_pointwise.binary, + qconv_op, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), KeywordArg("packed_weight"), KeywordArg("w_scale"), KeywordArg("w_zp"), - KeywordArg("accum"), KeywordArg("b"), KeywordArg("stride"), KeywordArg("padding"), @@ -229,13 +218,9 @@ def get_qconv2d_binary_pt2e_pattern(users=1): KeywordArg("output_scale"), KeywordArg("output_zero_point"), KeywordArg("output_dtype"), - KeywordArg("accum_scale"), - KeywordArg("accum_zero_point"), - KeywordArg("binary_op_name"), - KeywordArg("alpha"), - KeywordArg("unary_op_name"), - KeywordArg("unary_op_args"), - KeywordArg("unary_op_algorithm"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), _users=users, ) @@ -461,6 +446,7 @@ def fn(match): return False binary_node_inputs = next(iter(compute_node.users)).args assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + is_fp8 = match.kwargs["x"].meta["val"].dtype is torch.float8_e4m3fn if output_dtype in [torch.float32, torch.bfloat16]: extra_input_of_binary_node = None for arg in binary_node_inputs: @@ -469,14 +455,18 @@ def fn(match): break assert extra_input_of_binary_node is not None # Extra input of binary node comes from dequant pattern - if extra_input_from_dequant and ( - (not isinstance(extra_input_of_binary_node, torch.fx.Node)) - or ( - extra_input_of_binary_node.target - not in [ - quantized_decomposed.dequantize_per_tensor.default, - torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, - ] + if ( + not is_fp8 + and extra_input_from_dequant + and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + not in [ + quantized_decomposed.dequantize_per_tensor.default, + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + ] + ) ) ): return False @@ -711,7 +701,9 @@ def _inner(match): return _inner -def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): +def _register_qconv_weight_prepack_pass( + pattern, pass_number, dtype=torch.float32, is_fp8=False +): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_conv_pattern(dtype), @@ -724,7 +716,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): | dequant_per_tensor | - Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + Conv2d <- optional(aten.clone.default) <- dequant <- int8_weight Insert weight prepack node and change the pattern to: int8 activation @@ -747,7 +739,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): ) if dtype == torch.float32: - dequant_per_channel = ( + dequant = ( clone_node.args[0] # type: ignore[union-attr] if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] @@ -758,9 +750,9 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] ) - dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + dequant = weight_to_bf16_node.args[0] # type: ignore[union-attr] - assert dequant_per_channel.target in [ # type: ignore[union-attr] + assert dequant.target in [ # type: ignore[union-attr] quantized_decomposed.dequantize_per_channel.default, torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, ] @@ -768,7 
+760,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): # Activation QParams qx, x_zp, x_scale = ( kwargs["x"], - kwargs["x_zp"], + kwargs["x_zp"] if "x_zp" in kwargs else None, kwargs["x_scale"], ) @@ -776,7 +768,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): qw, w_scale, w_zp = ( kwargs["q_weight"], kwargs["w_scale"], - kwargs["w_zp"], + kwargs["w_zp"] if "w_zp" in kwargs else None, ) # Conv Params @@ -792,14 +784,25 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_free_symbols(x_shape): # For dynamic shape case, we can't get activation shape ahead of runtime. x_shape = None + if is_fp8: + # For float8, we assume the scales are from aten.full.default instead of + # a constant buffer to avoid constant folding of q/dq before fusion passes. + assert ( + w_scale.target is torch.ops.aten.full.default + and x_scale.target is torch.ops.aten.full.default + ) + with torch.utils._python_dispatch._disable_current_modes(): + w_scale_tensor = torch.tensor([w_scale.args[1]]) + match.graph.owning_module.register_buffer("w_scale", w_scale_tensor) + w_scale = match.graph.create_node("get_attr", "w_scale") graph = match.graph with graph.inserting_before(conv_node): # Insert weight prepack node and the QConv node packed_weight_inputs = ( qw, w_scale, - x_scale, - x_zp, + x_scale.args[1] if is_fp8 else x_scale, + 0, stride, padding, dilation, @@ -830,9 +833,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): [], # scalars "", # algorithm ) - new_conv_node = graph.call_function( - torch.ops.onednn.qconv_pointwise.default, args=new_args - ) + Node = torch.fx.node.Node + # fp8 not need zp + if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8): + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.tensor, args=new_args + ) + else: + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.default, args=new_args + ) conv_node.replace_all_uses_with(new_conv_node) new_conv_node.meta.update(conv_node.meta) @@ -847,7 +857,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): graph.erase_node(clone_node) # type: ignore[arg-type] if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] - graph.erase_node(dequant_per_channel) # type: ignore[arg-type] + graph.erase_node(dequant) # type: ignore[arg-type] counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( match.nodes @@ -855,17 +865,17 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): def _generate_dequant_convolution_node_pattern( - _dequant_per_channel_pattern, dtype=torch.float32 + _dequant_pattern, dtype=torch.float32, is_fp8=False ): assert dtype in [torch.float32, torch.bfloat16] dequant_convolution_node_pattern = CallFunction( aten.convolution.default, _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(), + get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), - _dequant_per_channel_pattern, + _dequant_pattern, KeywordArg("b"), KeywordArg("stride"), KeywordArg("padding"), @@ -877,24 +887,30 @@ def _generate_dequant_convolution_node_pattern( return dequant_convolution_node_pattern -def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False): assert dtype in [torch.float32, torch.bfloat16] + if is_fp8: + dequant_wgt_pattern = 
dequantize_fp8_weight_pattern + else: + dequant_wgt_pattern = dequantize_per_channel_weight_pattern return ( _generate_dequant_convolution_node_pattern( - dequantize_per_channel_weight_pattern + dequant_wgt_pattern if dtype == torch.float32 - else dequantize_per_channel_to_bf16_weight_pattern, + else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern), dtype, + is_fp8=is_fp8, ), # There is another pattern due to the pass of convert_conv_weights_to_channels_last # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. # Depend on some heuristics, it may or may not insert to(channel_last) node - # between convolution and dequant_per_channel node + # between convolution and dequant node _generate_dequant_convolution_node_pattern( - dequantize_per_channel_clone_weight_pattern + get_dequantize_clone_weight_pattern(dequant_wgt_pattern) if dtype == torch.float32 - else dequantize_per_channel_to_bf16_clone_weight_pattern, + else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern), dtype, + is_fp8=is_fp8, ), ) @@ -1302,12 +1318,7 @@ def _generate_qlinear_weight_prepack_patterns( is_fp8=False, ): if is_fp8: - dequant_wgt_pattern = CallFunction( - torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, - KeywordArg("q_weight"), - KeywordArg("w_scale"), - output_dtype=KeywordArg("w_dtype"), - ) + dequant_wgt_pattern = dequantize_fp8_weight_pattern else: dequant_wgt_pattern = dequantize_per_channel_weight_pattern if input_dim_exceeds_two and not input_contiguous: @@ -1449,12 +1460,16 @@ def _register_dequant_promotion(): def _register_qconv_weight_prepack(): - for dtype in [torch.float32, torch.bfloat16]: - weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for dtype, is_fp8 in itertools.product( + [torch.float32, torch.bfloat16], [True, False] + ): + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns( + dtype, is_fp8=is_fp8 + ) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. _register_qconv_weight_prepack_pass( - weight_prepack_pattern, pass_number=1, dtype=dtype + weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8 ) @@ -2053,13 +2068,25 @@ def qconv(match: Match, *args, **kwargs): kwargs["groups"], ) output_dtype = _get_pattern_output_dtype(match) - assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + assert output_dtype in [ + torch.int8, + torch.uint8, + torch.float8_e4m3fn, + torch.float32, + torch.bfloat16, + ] # Output QParams - o_inv_scale = ( - kwargs["o_inv_scale"] - if (output_dtype == torch.uint8 or output_dtype == torch.int8) - else 1.0 - ) + if output_dtype == torch.float8_e4m3fn: + # For float8, we assume the scale is from aten.full.default instead of + # a constant buffer to avoid constant folding of q/dq before fusion passes. 
+ assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default + o_inv_scale = kwargs["o_inv_scale"].args[1] + else: + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 1.0 + ) o_zero_point = ( kwargs["o_zp"] if (output_dtype == torch.uint8 or output_dtype == torch.int8) @@ -2165,56 +2192,69 @@ def _register_qconv_unary_fusion(): _silu_fusion, ) - for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + combinations = itertools.product( + [torch.float32, torch.bfloat16], [False, True], [False, True] + ) + for original_pattern_output_dtype, x_scale_zp_are_tensors, is_fp8 in combinations: # Priority 1 to match: QConv2d Unary pattern with int8 output # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant is_bf16 = original_pattern_output_dtype == torch.bfloat16 + computation_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) conv_unary_replace_patterns = { PostOpAttr( "none", None, "none", [], "" ): generate_pattern_with_output_quant( - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + is_fp8=is_fp8, ), PostOpAttr( "none", None, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( - get_qconv_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default ), + is_fp8=is_fp8, ), PostOpAttr( "none", None, "hardtanh", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), 1, is_bf16, ), with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), PostOpAttr( "none", None, "hardswish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), PostOpAttr( "none", None, "swish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), } @@ -2223,21 +2263,21 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv_pointwise.default, # computation_op + computation_op, # computation_op unary_attr, # unary_attr ) # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output conv_unary_replace_float_out_patterns = { PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - get_qconv_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default ), PostOpAttr( "none", None, "hardtanh", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), 1, is_bf16, ), @@ -2249,7 +2289,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), @@ -2261,7 +2301,7 @@ def 
_register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), @@ -2275,17 +2315,26 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - torch.ops.onednn.qconv_pointwise.default, # computation_op + computation_op, # computation_op unary_attr, # unary_attr ) def _register_qconv_binary_fusion(): - for int8_mixed_bf16_with_inplace_add in [False, True]: + for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product( + [False, True], [False, True] + ): + qconv_binary_op = ( + torch.ops.onednn.qconv2d_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv2d_pointwise.binary + ) # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output swap_binary_inputs_list = [False, True] binary_replace_patterns = {} - for swap_inputs in swap_binary_inputs_list: + for swap_inputs, is_fp8 in itertools.product( + swap_binary_inputs_list, [False, True] + ): binary_replace_patterns.update( { PostOpAttr( @@ -2293,11 +2342,12 @@ def _register_qconv_binary_fusion(): ): generate_pattern_with_output_quant( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), + is_fp8=is_fp8, ), PostOpAttr( "sum", 1.0, "relu", [], "" @@ -2305,13 +2355,14 @@ def _register_qconv_binary_fusion(): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), aten.relu.default, ), + is_fp8=is_fp8, ), } ) @@ -2320,7 +2371,7 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr ) @@ -2332,7 +2383,7 @@ def _register_qconv_binary_fusion(): PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2350,14 +2401,14 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr ) else: _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr ) @@ -2370,7 +2421,7 @@ def _register_qconv_binary_fusion(): "sum", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2385,7 +2436,7 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + 
qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr ) @@ -2427,8 +2478,8 @@ def qlinear_post_op_fusion(match: Match, *args, **kwargs): # Output QParams if output_dtype == torch.float8_e4m3fn: - # For float8, torchao.quantize_affine_float8 requires tensor as scale - # Support scale node is full firstly + # For float8, we assume the scale is from aten.full.default instead of + # a constant buffer to avoid constant folding of q/dq before fusion passes. assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default o_inv_scale = kwargs["o_inv_scale"].args[1] else: From 74b84e208f8b7456308d782bf67fb2faeeab4120 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 4 Dec 2025 16:01:37 -0800 Subject: [PATCH 11/18] Int8Tensor migration cleanup (#3407) * Int8Tensor migration Summary: This PR creates a new Int8Tensor and updates the configs to use the new Int8Tensor flow Test Plan: To ensure BC: ``` pytest test/quantization/test_quant_api.py ``` To test new Int8Tensor: ``` pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py ``` Reviewers: Subscribers: Tasks: Tags: * ruff fixes * add init * fix ruff again * update * wip * undo update tests * fix ruff * fix varname * fix typing * add tests * fix dtype * fix ci * address granularity cr * update _choose_quant_func_and_quantize_tensor * make block size required attribute * made dtype required as well * address nits * skip per tensor weight only test for now --- .../workflows/int8/test_int8_tensor.py | 94 ++++---- torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 82 +++---- torchao/quantization/quant_primitives.py | 7 +- .../common/quantize_tensor_kwargs.py | 9 + .../quantize_/workflows/__init__.py | 2 + .../quantize_/workflows/int8/int8_tensor.py | 212 +++++++----------- 7 files changed, 194 insertions(+), 214 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 2acdff2b84..2819903e69 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -18,11 +18,23 @@ quantize_, ) from torchao.quantization.granularity import PerRow, PerTensor +from torchao.quantization.quant_primitives import MappingType from torchao.quantization.utils import compute_error, get_block_size from torchao.testing.model_architectures import ToyTwoLinearModel from torchao.testing.utils import TorchAOIntegrationTestCase from torchao.utils import torch_version_at_least +INT8_TEST_CONFIGS = [ + Int8WeightOnlyConfig(version=2, granularity=PerTensor()), + Int8WeightOnlyConfig(version=2, granularity=PerRow()), + Int8DynamicActivationInt8WeightConfig( + version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC + ), + Int8DynamicActivationInt8WeightConfig( + version=2, granularity=PerRow(), act_mapping_type=MappingType.SYMMETRIC + ), +] + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.instantiate_parametrized_tests @@ -36,13 +48,7 @@ def setUp(self): torch.manual_seed(42) - @common_utils.parametrize( - "config", - [ - Int8DynamicActivationInt8WeightConfig(version=2), - Int8WeightOnlyConfig(version=2), - ], - ) + @common_utils.parametrize("config", INT8_TEST_CONFIGS) def test_creation_and_attributes(self, config): """Test tensor creation, dtypes, and ranges""" linear = torch.nn.Linear( @@ -60,15 +66,17 @@ def test_creation_and_attributes(self, config): self.assertEqual(w.qdata.dtype, 
torch.int8) self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127)) + if isinstance(config.granularity, PerRow): + self.assertEqual(w.scale.shape, (w.shape[0], 1)) + elif isinstance(config.granularity, PerTensor): + self.assertEqual(w.scale.shape, (1, 1)) + + if hasattr(config, "act_mapping_type"): + self.assertEqual(w.act_quant_kwargs.mapping_type, config.act_mapping_type) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("compile", [True, False]) - @common_utils.parametrize( - "config", - [ - Int8DynamicActivationInt8WeightConfig(version=2), - Int8WeightOnlyConfig(version=2), - ], - ) + @common_utils.parametrize("config", INT8_TEST_CONFIGS) @common_utils.parametrize( "sizes", [ @@ -84,6 +92,8 @@ def test_int8_linear_variants( sizes: tuple, ): """Test linear operation supports including shape and compile""" + torch.compiler.reset() + M, N, K = sizes input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval() @@ -91,10 +101,19 @@ def test_int8_linear_variants( quantize_(model_q, config) - self.assertEqual(model_q.linear2.weight.scale.shape, (K,)) - self.assertEqual(model_q.linear2.weight.scale.ndim, 1) + if isinstance(config.granularity, PerRow): + self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1)) + elif isinstance(config.granularity, PerTensor): + self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1)) + + self.assertEqual(model_q.linear2.weight.scale.ndim, 2) if compile: + if isinstance(config, Int8WeightOnlyConfig) and isinstance( + config.granularity, PerTensor + ): + # currently the inductor lowering for weight only quant in core does not support per-tensor gpu, so this errors. Skipping for now, but will address this in core + return model_q = torch.compile(model_q, fullgraph=True) output_fp = model(input_tensor) @@ -104,13 +123,7 @@ def test_int8_linear_variants( f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" ) - @common_utils.parametrize( - "config", - [ - Int8DynamicActivationInt8WeightConfig(version=2), - Int8WeightOnlyConfig(version=2), - ], - ) + @common_utils.parametrize("config", INT8_TEST_CONFIGS) @common_utils.parametrize("device", ["cpu", "cuda"]) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_slice(self, config, device, dtype): @@ -128,27 +141,24 @@ def test_slice(self, config, device, dtype): self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0])) self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1])) - self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])) + + if isinstance(config.granularity, PerRow): + self.assertEqual( + weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]) + ) + self.assertEqual(weight2.scale, dummy.weight.scale) with self.assertRaises(NotImplementedError): _ = dummy.weight[::2] - @common_utils.parametrize( - "config", - [ - Int8DynamicActivationInt8WeightConfig, - Int8WeightOnlyConfig, - ], - ) - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) - def test_index_select(self, config, granularity): + @common_utils.parametrize("config", INT8_TEST_CONFIGS) + def test_index_select(self, config): """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor.""" N, K = 256, 512 x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") 
linear.weight.data = x - config = config(version=2, granularity=granularity) quantize_(linear, config) x_int8 = linear.weight @@ -160,22 +170,16 @@ def test_index_select(self, config, granularity): ) # Test block_size granularity - if isinstance(granularity, PerRow): + if isinstance(config.granularity, PerRow): self.assertEqual( - list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K] + list(get_block_size(x_int8.shape, config.granularity)), [1, K] ) - elif isinstance(granularity, PerTensor): + elif isinstance(config.granularity, PerTensor): self.assertEqual( - list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K] + list(get_block_size(x_int8.shape, config.granularity)), [N, K] ) - @common_utils.parametrize( - "config", - [ - Int8DynamicActivationInt8WeightConfig(version=2), - Int8WeightOnlyConfig(version=2), - ], - ) + @common_utils.parametrize("config", INT8_TEST_CONFIGS) def test_dequantization_accuracy(self, config): """Test dequantization accuracy separately""" linear = torch.nn.Linear( diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index a1ca6b0b94..80e11dda5b 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -98,6 +98,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8Tensor, IntxOpaqueTensor, IntxUnpackedToInt8Tensor, ) @@ -164,6 +165,7 @@ "FqnToConfig", "ModuleFqnToConfig", # tensor subclasses + "Int8Tensor", "Int4Tensor", "Int4PlainInt32Tensor", "Int4PreshuffledTensor", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9a1bfeb0a5..24d6b6676c 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1341,6 +1341,10 @@ class Int8WeightOnlyConfig(AOBaseConfig): def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") + if self.version == 2: + assert self.group_size is None, ( + f"Only support version 2 with group_size=None, got {self.group_size}" + ) # for BC @@ -1522,9 +1526,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): layout: Optional[Layout] = PlainLayout() act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC weight_only_decode: bool = False - # TODO: Revisit for supported granularitys - # https://github.com/pytorch/ao/pull/3241#discussion_r2551497849 - granularity: Optional[Granularity] = PerRow() + granularity: Granularity = PerRow() set_inductor_config: bool = True version: int = 1 @@ -1541,37 +1543,30 @@ def __post_init__(self): def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): - layout = config.layout - act_mapping_type = config.act_mapping_type - weight_only_decode = config.weight_only_decode - - in_features = weight.shape[-1] - # int8 dynamic quantization only has benefit when in_feature > 16 - if in_features <= 16: - logger.info( - f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}" - f" because `in_feature` is <= 16: {in_features}" - ) - return weight + if config.version == 1: + layout = config.layout + act_mapping_type = config.act_mapping_type + weight_only_decode = config.weight_only_decode + + in_features = weight.shape[-1] + # int8 dynamic quantization only has benefit when in_feature > 16 + if in_features <= 16: + logger.info( + f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}" + f" because `in_feature` is <= 16: {in_features}" + ) + return weight - # weight settings - mapping_type = MappingType.SYMMETRIC - 
weight_zero_point_domain = ZeroPointDomain.NONE + # weight settings + mapping_type = MappingType.SYMMETRIC + weight_zero_point_domain = ZeroPointDomain.NONE - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 + def get_weight_block_size(x): + return tuple([1 for _ in range(x.dim() - 1)] + [x.shape[-1]]) - if config.version == 1: - warnings.warn( - "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" - ) - if isinstance(config.granularity, PerTensor): - block_size = weight.shape - else: - block_size = tuple( - [1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]] - ) + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 if weight_only_decode: input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode @@ -1582,7 +1577,8 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): else: input_quant_func = _int8_asymm_per_token_quant - quantized_weight = to_affine_quantized_intx( + block_size = get_weight_block_size(weight) + new_weight = to_affine_quantized_intx( weight, mapping_type, block_size, @@ -1592,24 +1588,32 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): _layout=layout, zero_point_domain=weight_zero_point_domain, ) - quantized_weight = to_linear_activation_quantized( - quantized_weight, input_quant_func - ) + quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func) else: from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( QuantizeTensorToInt8Kwargs, ) + assert config.granularity in {PerRow(), PerTensor()}, ( + "Only PerRow and PerTensor are supported" + ) + weight_granularity = config.granularity + act_granularity = config.granularity + + assert config.act_mapping_type == MappingType.SYMMETRIC, ( + "asymmetric dynamic quant not supported currently" + ) assert config.version == 2, f"Unexpected version: {config.version}" # TODO: Symmentric/Asymmetric choice for weight quantization # https://github.com/pytorch/ao/pull/3241#discussion_r2551515539 - # TODO: Add block_size args to return in from_hp - # https://github.com/pytorch/ao/pull/3241#discussion_r2552016429 quantized_weight = Int8Tensor.from_hp( weight, - granularity=config.granularity, - act_quant_kwargs=QuantizeTensorToInt8Kwargs(granularity=config.granularity), + granularity=weight_granularity, + act_quant_kwargs=QuantizeTensorToInt8Kwargs( + granularity=act_granularity, + mapping_type=config.act_mapping_type, + ), ) return quantized_weight diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index ee1da11c50..9bdb3871a2 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1217,6 +1217,7 @@ def choose_qparams_affine( eps: Optional[float] = None, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = torch.int32, + keepdim: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -1247,6 +1248,7 @@ def choose_qparams_affine( eps, scale_dtype, zero_point_dtype, + keepdim, ) @@ -1521,6 +1523,7 @@ def _choose_qparams_affine( eps: Optional[float] = None, scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, + keepdim: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """op definition that has 
compatible signatures with custom op library @@ -1550,8 +1553,8 @@ def _choose_qparams_affine( ) input = input.view(shape_for_reduction) - min_val = torch.amin(input, dim=reduction_dims, keepdim=False) - max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + min_val = torch.amin(input, dim=reduction_dims, keepdim=keepdim) + max_val = torch.amax(input, dim=reduction_dims, keepdim=keepdim) min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py index 0adc8c786d..e4544a2f0c 100644 --- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor( """ from torchao.quantization.quantize_.workflows import ( Float8Tensor, + Int8Tensor, QuantizeTensorToFloat8Kwargs, + QuantizeTensorToInt8Kwargs, ) if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): @@ -53,4 +55,11 @@ def _choose_quant_func_and_quantize_tensor( quant_kwargs.kernel_preference, ) + if isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs): + return Int8Tensor.from_hp( + tensor, + quant_kwargs.granularity, + mapping_type=quant_kwargs.mapping_type, + ) + raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 962f95157f..17cb15d4f7 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -42,6 +42,8 @@ "QuantizeTensorToInt8Kwargs", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", + "Int8Tensor", + "QuantizeTensorToInt8Kwargs", "Int4ChooseQParamsAlgorithm", "Int4PackingFormat", "IntxChooseQParamsAlgorithm", diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 6ca31326cd..dd422b90f6 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -5,20 +5,24 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import torch from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.float8.inference import _slice_scale_for_dimension from torchao.kernel import int_scaled_matmul -from torchao.quantization.granularity import Granularity, PerRow +from torchao.quantization.granularity import Granularity from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, dequantize_affine, quantize_affine, ) -from torchao.quantization.quantize_.common import QuantizeTensorKwargs +from torchao.quantization.quantize_.common import ( + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) from torchao.quantization.utils import get_block_size from torchao.utils import TorchAOBaseTensor, fill_defaults @@ -29,18 +33,22 @@ @dataclass class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): - """Tensor kwargs for creating int8 tensor (either activation or weight) + """Tensor kwargs for creating int8 tensor for activation. 
Args: granularity: the granularity for the Tensor, currently either PerRow() or PerTensor() + mapping_type: whether to use symmetric or asymmetric quant, only symmetric is supported currently """ - granularity: Granularity = PerRow() + granularity: Granularity + mapping_type: MappingType = MappingType.SYMMETRIC class Int8Tensor(TorchAOBaseTensor): """ - int8 quantized tensor with plain layout + int8 quantized tensor with plain layout. + + Currently only Symmetric quantization is supported. Tensor Attributes: qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D) @@ -54,21 +62,22 @@ class Int8Tensor(TorchAOBaseTensor): # TODO: Static quantization support using `static_scale` tensor_data_names = ["qdata", "scale"] - tensor_attribute_names = ["granularity"] - optional_tensor_attribute_names = ["act_quant_kwargs", "block_size", "dtype"] + tensor_attribute_names = ["block_size", "dtype"] + optional_tensor_attribute_names = [ + "act_quant_kwargs", + ] def __new__( cls: type, qdata: torch.Tensor, scale: torch.Tensor, - granularity: Optional[Granularity] = None, - block_size: Optional[torch.Size] = None, + block_size: List[int], + dtype: torch.dtype, act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, - dtype: Optional[torch.dtype] = None, ): kwargs = { "device": qdata.device, - "dtype": dtype or scale.dtype, + "dtype": dtype, "requires_grad": False, } return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs) @@ -77,16 +86,15 @@ def __init__( self, qdata: torch.Tensor, scale: torch.Tensor, - granularity: Granularity, - block_size: Optional[torch.Size] = None, + block_size: List[int], + dtype: torch.dtype, act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, - dtype: Optional[torch.dtype] = None, ): super().__init__() self.qdata = qdata self.scale = scale - self.granularity = granularity - self.block_size = block_size or get_block_size(qdata.shape, granularity) + self.block_size = block_size + # don't set dtype because this gets done in __new__ self.act_quant_kwargs = act_quant_kwargs def __repr__(self): @@ -95,7 +103,6 @@ def __repr__(self): f"act_quant_kwargs={self.act_quant_kwargs}, " f"qdata={self.qdata}, " f"scale={self.scale}, " - f"granularity={self.granularity}, " f"block_size={self.block_size}, " f"shape={self.shape}, " f"device={self.device}, " @@ -105,32 +112,29 @@ def __repr__(self): @classmethod def from_hp( cls, - w_hp: torch.Tensor, - granularity: Granularity = PerRow(), + hp_tensor: torch.Tensor, + granularity: Granularity, act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + mapping_type=MappingType.SYMMETRIC, ): """Create Int8Tensor from high-precision tensor""" - block_size = get_block_size(w_hp.shape, granularity) - - if w_hp.dim() not in [2, 3] or len(block_size) != w_hp.dim(): - raise ValueError( - f"Expected 2D or 3D tensor with matching block_size dimensions, " - f"got tensor dim={w_hp.dim()}, block_size length={len(block_size)}" - ) + block_size = get_block_size(hp_tensor.shape, granularity) + block_size = list(block_size) scale, zero_point = choose_qparams_affine( - input=w_hp, - mapping_type=MappingType.SYMMETRIC, + input=hp_tensor, + mapping_type=mapping_type, block_size=block_size, target_dtype=torch.int8, quant_min=-128, quant_max=127, - scale_dtype=w_hp.dtype, + scale_dtype=hp_tensor.dtype, zero_point_dtype=torch.int8, + keepdim=True, ) int_data = quantize_affine( - w_hp, + hp_tensor, block_size=block_size, scale=scale, zero_point=zero_point, @@ -140,28 +144,22 @@ def from_hp( return cls( int_data, scale, - 
granularity, - block_size=block_size, + block_size, + hp_tensor.dtype, act_quant_kwargs=act_quant_kwargs, - dtype=w_hp.dtype, ) def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize int8 tensor to floating point""" - if output_dtype is None: - output_dtype = self.dtype - - block_size = get_block_size(self.qdata.shape, self.granularity) - return dequantize_affine( input=self.qdata, - block_size=block_size, + block_size=self.block_size, scale=self.scale, zero_point=None, input_dtype=torch.int8, quant_min=-128, quant_max=127, - output_dtype=output_dtype, + output_dtype=output_dtype if output_dtype is not None else self.dtype, ) @@ -169,64 +167,6 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor implements_torch_function = Int8Tensor.implements_torch_function -def _slice_scale( - scale: torch.Tensor, - data_shape: list[int], - dim: int, - start: int, - end: int, - step: int, -) -> torch.Tensor: - """ - Slice the scale tensor appropriately based on the data tensor slicing. - This function calculates how the scale should be sliced when the data tensor - is sliced along a given dimension, taking into account the block structure. - - Example: - If data_shape is [256, 128] and scale shape is [1] (indicating per-tensor scaling), - slicing along any dimension should return the same scale tensor. - - If data_shape is [256, 128] and scale shape is [256] (indicating per-row scaling), - and we slice data along dim=0 from 64 to 192, the corresponding scale - """ - aten = torch.ops.aten - - # Case 1: Per-tensor quantization (scalar scale) - if scale.numel() <= 1: - return scale - - # Case 2: Per-row quantization (1D scale) - # Scale is per-element along this dimension - if scale.ndim == 1: - if dim == 0: - return aten.slice.Tensor(scale, 0, start, end, step) - else: - return scale - - # Case 3: Per-block quantization (2D scale) - block_sizes = tuple( - data_shape[i] // scale.shape[i] for i in range(len(scale.shape)) - ) - - block_size_for_dim = block_sizes[dim] - - if step > 1: - raise NotImplementedError( - "Slicing with step > 1 is not implemented for scale tensors." 
- ) - - # There is blocking in this dimension - # Calculate which scale elements correspond to the sliced data - scale_start = start // block_size_for_dim if start is not None else None - scale_end = ( - (end + block_size_for_dim - 1) // block_size_for_dim - if end is not None - else None - ) - - return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) - - @implements(aten.linear.default) @implements_torch_function(torch.nn.functional.linear) def _(func, types, args, kwargs): @@ -237,14 +177,15 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) - if not isinstance(weight_tensor, Int8Tensor): - raise TypeError(f"Expected weight to be Int8Tensor, got {type(weight_tensor)}") + assert isinstance(weight_tensor, Int8Tensor), ( + f"Expected weight to be Int8Tensor, got {type(weight_tensor)}" + ) output_dtype = activation_tensor.dtype if weight_tensor.act_quant_kwargs is not None: - activation_tensor = Int8Tensor.from_hp( - activation_tensor, weight_tensor.act_quant_kwargs.granularity + activation_tensor = _choose_quant_func_and_quantize_tensor( + activation_tensor, weight_tensor.act_quant_kwargs ) # Dynamic activation quantization path @@ -270,7 +211,7 @@ def _(func, types, args, kwargs): y_dot_scaled = int_scaled_matmul( tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) ).to(output_dtype) - y = (y_dot_scaled * w_scales).reshape( + y = (y_dot_scaled * w_scales.flatten()).reshape( *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] ) @@ -281,7 +222,7 @@ def _(func, types, args, kwargs): activation_tensor.reshape(-1, activation_tensor.shape[-1]), w_vals_int8_t.to(output_dtype), ) - y = m * weight_tensor.scale.to(m.dtype) + y = m * weight_tensor.scale.to(m.dtype).flatten() y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0]) if bias is not None: @@ -310,7 +251,19 @@ def _(func, types, args, kwargs): end = self.shape[dim] sliced_qdata = aten.slice.Tensor(self.qdata, dim, start, end, step) - sliced_scale = _slice_scale(self.scale, self.qdata.shape, dim, start, end, step) + if self.scale.numel() == 1: + # Per-tensor quantization - scale doesn't change + sliced_scale = self.scale + else: + # Block-wise quantization - need to slice the scale appropriately + sliced_scale = _slice_scale_for_dimension( + self.scale, self.qdata.shape, dim, start, end, step + ) + + # adjust block_size since the shape has changed, block_size[i] should not be greater than shape[i] + block_size = self.block_size.copy() + for i in range(len(self.block_size)): + block_size[i] = min(block_size[i], sliced_qdata.shape[i]) return return_and_correct_aliasing( func, @@ -319,39 +272,42 @@ def _(func, types, args, kwargs): Int8Tensor( sliced_qdata, sliced_scale, - self.granularity, - block_size=get_block_size(sliced_qdata.shape, self.granularity), + block_size, + self.dtype, act_quant_kwargs=self.act_quant_kwargs, - dtype=self.dtype, ), ) -@implements(aten.select.int) +@implements(aten.index.Tensor) def _(func, types, args, kwargs): - """Select operation for Int8Tensor""" - self, dim, index = args - if dim != 0: - raise NotImplementedError(f"Only dim=0 supported, got dim={dim}") - - selected_qdata = self.qdata[index] - selected_scale = _slice_scale( - self.scale, self.qdata.shape, dim, index, index + 1, step=1 - ).squeeze(0) - return return_and_correct_aliasing( func, args, kwargs, - Int8Tensor( - selected_qdata, - selected_scale, - self.granularity, - block_size=get_block_size(selected_qdata.shape, self.granularity), - act_quant_kwargs=self.act_quant_kwargs, - 
dtype=self.dtype, - ), + args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)), + ) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + """Select operation for Int8Tensor""" + old_int8_tensor, dim, index = args + assert dim == 0, f"Int8Tensor aten.select.int with {dim=} is not yet supported" + assert len(old_int8_tensor.qdata.shape) == len(old_int8_tensor.scale.shape), ( + "unsupported" + ) + assert len(old_int8_tensor.qdata.shape) == len(old_int8_tensor.block_size), ( + "unsupported" + ) + new_int8_tensor = Int8Tensor( + old_int8_tensor.qdata[index], + old_int8_tensor.scale[index], + old_int8_tensor.block_size[1:], + old_int8_tensor.dtype, + old_int8_tensor.act_quant_kwargs, ) + return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor) Int8Tensor.__module__ = "torchao.quantization" From 3adc2860793587e2964b71a3d020870690d8aec9 Mon Sep 17 00:00:00 2001 From: xiangdong <40376367+zxd1997066@users.noreply.github.com> Date: Fri, 5 Dec 2025 09:05:25 +0800 Subject: [PATCH 12/18] [xpu][test] Port 2 test/dtypes_{floatx, bitpacking} UT files to intel XPU (#3368) * enable test/dtypes/test_bitpacking.py on intel xpu * enable test/dtypes/test_floatx.py * enable test/dtypes/test_floatx.py * fix format issue * fix format issue * update _DEVICES --- test/dtypes/test_bitpacking.py | 24 ++++++++++++++++-------- test/dtypes/test_floatx.py | 13 +++++++------ 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/test/dtypes/test_bitpacking.py b/test/dtypes/test_bitpacking.py index 0ed4462d5d..9c54631b45 100644 --- a/test/dtypes/test_bitpacking.py +++ b/test/dtypes/test_bitpacking.py @@ -8,9 +8,11 @@ from torch.utils._triton import has_triton from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu +from torchao.utils import get_current_accelerator_device bit_widths = (1, 2, 3, 4, 5, 6, 7) dimensions = (0, -1, 1) +_DEVICE = get_current_accelerator_device() @pytest.fixture(autouse=True) @@ -30,17 +32,19 @@ def test_CPU(bit_width, dim): assert unpacked.allclose(test_tensor) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("bit_width", bit_widths) @pytest.mark.parametrize("dim", dimensions) def test_GPU(bit_width, dim): - test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda() + test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to( + _DEVICE + ) packed = pack(test_tensor, bit_width, dim=dim) unpacked = unpack(packed, bit_width, dim=dim) assert unpacked.allclose(test_tensor) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.parametrize("bit_width", bit_widths) @pytest.mark.parametrize("dim", dimensions) @@ -48,22 +52,26 @@ def test_compile(bit_width, dim): torch._dynamo.config.specialize_int = True torch.compile(pack, fullgraph=True) torch.compile(unpack, fullgraph=True) - test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda() + test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to( + _DEVICE + ) packed = pack(test_tensor, bit_width, dim=dim) unpacked = unpack(packed, bit_width, dim=dim) assert unpacked.allclose(test_tensor) # these test cases are for the example pack walk through in 
the bitpacking.py file -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") def test_pack_example(): test_tensor = torch.tensor( [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8 - ).cuda() + ).to(_DEVICE) shard_4, shard_2 = pack(test_tensor, 6) print(shard_4, shard_2) - assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4) - assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2) + assert ( + torch.tensor([0, 105, 151, 37], dtype=torch.uint8).to(_DEVICE).allclose(shard_4) + ) + assert torch.tensor([39, 146], dtype=torch.uint8).to(_DEVICE).allclose(shard_2) unpacked = unpack([shard_4, shard_2], 6) assert unpacked.allclose(test_tensor) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index a3dd4d19e3..19a7ca4c56 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -33,10 +33,11 @@ quantize_, ) from torchao.testing.utils import skip_if_rocm -from torchao.utils import is_fbcode +from torchao.utils import get_current_accelerator_device, is_fbcode -_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) _Floatx_DTYPES = [(3, 2), (2, 2)] +_DEVICE = get_current_accelerator_device() +_DEVICES = ["cpu"] + ([_DEVICE] if torch.accelerator.is_available() else []) class TestFloatxTensorCoreAQTTensorImpl(TestCase): @@ -87,7 +88,7 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device): ) torch.testing.assert_close(actual, expected) - @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), reason="GPU not available") @parametrize("ebits,mbits", _Floatx_DTYPES) def test_to_copy_device(self, ebits, mbits): from torchao.quantization.quant_primitives import ( @@ -101,8 +102,8 @@ def test_to_copy_device(self, ebits, mbits): _layout = FloatxTensorCoreLayout(ebits, mbits) floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain( x, scale, None, _layout - ).cuda() - assert floatx_tensor_impl.device.type == "cuda" + ).to(_DEVICE) + assert floatx_tensor_impl.device.type == _DEVICE.type floatx_tensor_impl = floatx_tensor_impl.cpu() assert floatx_tensor_impl.device.type == "cpu" @@ -114,7 +115,7 @@ def test_to_copy_device(self, ebits, mbits): @skip_if_rocm("ROCm enablement in progress") def test_fpx_weight_only(self, ebits, mbits, bias, dtype): N, OC, IC = 4, 256, 64 - device = "cuda" + device = _DEVICE linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype) fpx_linear = copy.deepcopy(linear) From 24059b07642a6a83d93e5d4afee6a55fd826086f Mon Sep 17 00:00:00 2001 From: xiangdong <40376367+zxd1997066@users.noreply.github.com> Date: Fri, 5 Dec 2025 13:13:29 +0800 Subject: [PATCH 13/18] [xpu][test] Port 2 test/quantization/pt2e/test_{quantize_pt2e, quantize_pt2e_qat} UT files to intel XPU (#3405) * add test/quantization/pt2e/test_quantize_pt2e.py * add test/quantization/pt2e/test_quantize_pt2e.py * test/quantization/pt2e/test_quantize_pt2e_qat.py * test/quantization/pt2e/test_quantize_pt2e_qat.py * fix format issue * update format * increase timeout for xpu --- .github/workflows/xpu_test.yml | 4 ++-- test/quantization/pt2e/test_quantize_pt2e.py | 16 +++++++++------- .../pt2e/test_quantize_pt2e_qat.py | 18 ++++++++++-------- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/.github/workflows/xpu_test.yml b/.github/workflows/xpu_test.yml index 3f7d1c7171..32420951e4 
100644 --- a/.github/workflows/xpu_test.yml +++ b/.github/workflows/xpu_test.yml @@ -21,7 +21,7 @@ jobs: test: # Don't run on forked repos or empty test matrix # if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]' - timeout-minutes: 60 + timeout-minutes: 120 runs-on: linux.idc.xpu env: DOCKER_IMAGE: ci-image:pytorch-linux-noble-xpu-n-py3 @@ -166,7 +166,7 @@ jobs: GITHUB_RUN_NUMBER: ${{ github.run_number }} GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - timeout-minutes: 60 + timeout-minutes: 120 run: | set -x diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 5cd5e9d459..072de11328 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -32,6 +32,7 @@ ) from torch.testing._internal.common_utils import ( TEST_CUDA, + TEST_XPU, TemporaryFileName, instantiate_parametrized_tests, parametrize, @@ -70,9 +71,10 @@ QuantizationConfig, ) from torchao.testing.pt2e.utils import PT2EQuantizationTestCase -from torchao.utils import torch_version_at_least +from torchao.utils import get_current_accelerator_device, torch_version_at_least -DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) +DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["xpu"] if TEST_XPU else []) +_DEVICE = get_current_accelerator_device() if torch_version_at_least("2.7.0"): from torch.testing._internal.common_utils import ( @@ -2154,7 +2156,7 @@ def __init__(self) -> None: def forward(self, x): return self.bn(x) - if TEST_CUDA or TEST_HPU: + if TEST_CUDA or TEST_HPU or TEST_XPU: m = M().train().to(device) example_inputs = (torch.randn((1, 3, 3, 3), device=device),) @@ -2229,9 +2231,9 @@ def forward(self, x): x = self.dropout(x) return x - if TEST_CUDA: - m = M().train().cuda() - example_inputs = (torch.randn(1, 3, 3, 3).cuda(),) + if TEST_CUDA or TEST_XPU: + m = M().train().to(_DEVICE) + example_inputs = (torch.randn(1, 3, 3, 3).to(_DEVICE),) else: m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) @@ -2243,7 +2245,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool): bn_op = bn_train_op if train else bn_eval_op bn_node = self._get_node(m, bn_op) self.assertTrue(bn_node is not None) - if TEST_CUDA: + if TEST_CUDA or TEST_XPU: self.assertEqual(bn_node.args[5], train) dropout_node = self._get_node(m, torch.ops.aten.dropout.default) self.assertEqual(dropout_node.args[2], train) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index 293d243f6a..2004c9e04a 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -28,7 +28,7 @@ skipIfNoQNNPACK, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TEST_XPU, run_tests from torchao.quantization.pt2e import ( FusedMovingAvgObsFakeQuantize, @@ -52,7 +52,9 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import torch_version_at_least +from torchao.utils import get_current_accelerator_device, torch_version_at_least + +_DEVICE = get_current_accelerator_device() class PT2EQATTestCase(QuantizationTestCase): @@ -453,10 +455,10 @@ def test_qat_conv_bn_fusion(self): self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False) 
self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable") def test_qat_conv_bn_fusion_cuda(self): - m = self._get_conv_bn_model().cuda() - example_inputs = (self.example_inputs[0].cuda(),) + m = self._get_conv_bn_model().to(_DEVICE) + example_inputs = (self.example_inputs[0].to(_DEVICE),) self._verify_symmetric_xnnpack_qat_graph( m, example_inputs, @@ -540,10 +542,10 @@ def test_qat_conv_bn_relu_fusion(self): self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable") def test_qat_conv_bn_relu_fusion_cuda(self): - m = self._get_conv_bn_model(has_relu=True).cuda() - example_inputs = (self.example_inputs[0].cuda(),) + m = self._get_conv_bn_model(has_relu=True).to(_DEVICE) + example_inputs = (self.example_inputs[0].to(_DEVICE),) self._verify_symmetric_xnnpack_qat_graph( m, example_inputs, From 565e813ce03e705d6ceba535464f7a2996b679fa Mon Sep 17 00:00:00 2001 From: Artur Lesniak Date: Fri, 5 Dec 2025 09:23:11 +0100 Subject: [PATCH 14/18] [Intel GPU] Enable optim SR test (#3055) --- test/test_low_bit_optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_low_bit_optim.py b/test/test_low_bit_optim.py index b0edfc7fc5..0a356bf73d 100644 --- a/test/test_low_bit_optim.py +++ b/test/test_low_bit_optim.py @@ -419,8 +419,8 @@ def test_optim_cpu_offload_save_load(self): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1) - def test_optim_bf16_stochastic_round_correctness(self): - device = "cuda" if torch.cuda.is_available() else "cpu" + @parametrize("device", _DEVICES) + def test_optim_bf16_stochastic_round_correctness(self, device): torch.manual_seed(2024) model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) model1.to(device) From a99255a40198873f3cc47e8be18a18ae30d545f2 Mon Sep 17 00:00:00 2001 From: avi singhal Date: Fri, 5 Dec 2025 17:10:31 +0000 Subject: [PATCH 15/18] updated test with rebase changes --- .../mx_formats}/test_mxfp8_allgather.py | 19 +++++++------------ torchao/prototype/mx_formats/mx_tensor.py | 16 ++++++++-------- 2 files changed, 15 insertions(+), 20 deletions(-) rename {torchao/prototype/tests => test/prototype/mx_formats}/test_mxfp8_allgather.py (88%) diff --git a/torchao/prototype/tests/test_mxfp8_allgather.py b/test/prototype/mx_formats/test_mxfp8_allgather.py similarity index 88% rename from torchao/prototype/tests/test_mxfp8_allgather.py rename to test/prototype/mx_formats/test_mxfp8_allgather.py index 178195a823..43c31389c6 100644 --- a/torchao/prototype/tests/test_mxfp8_allgather.py +++ b/test/prototype/mx_formats/test_mxfp8_allgather.py @@ -1,9 +1,4 @@ -import pytest import torch - -if not torch.cuda.is_available() or torch.cuda.get_device_capability() != (10, 0): - pytest.skip("Test requires CUDA build on SM100", allow_module_level=True) - import torch.distributed as dist from torch.testing._internal.common_distributed import ( MultiProcessTestCase, @@ -23,7 +18,7 @@ def setUp(self) -> None: @property def world_size(self) -> int: - return 4 + return 2 @property def device(self) -> torch.device: @@ -64,9 +59,9 @@ def test_allgather(self): elem_dtype=torch.float8_e5m2, block_size=32, orig_dtype=torch.float32, - 
gemm_kernel_choice=None, - pack_fp6=None, + kernel_preference=None, act_quant_kwargs=None, + is_swizzled_scales=None, ) world_size = self.world_size @@ -82,9 +77,9 @@ def test_allgather(self): elem_dtype=torch.float8_e5m2, block_size=32, orig_dtype=torch.float32, - gemm_kernel_choice=None, - pack_fp6=None, + kernel_preference=None, act_quant_kwargs=None, + is_swizzled_scales=None, ) # Perform all_gather @@ -111,12 +106,12 @@ def test_allgather(self): # Verify scale matches golden exactly if not torch.equal( - gathered_mx._scale_e8m0.view(torch.uint8), + gathered_mx.scale.view(torch.uint8), golden_scale.view(torch.uint8), ): assert False, "scale mismatch" - assert gathered_mx._block_size == 32 + assert gathered_mx.block_size == 32 finally: dist.destroy_process_group() diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index b7bcb238d9..9bd1897074 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -869,7 +869,7 @@ def mx_all_gather(func, types, args, kwargs): ) gathered_scale = torch.ops._c10d_functional.all_gather_into_tensor.default( - mx_tensor._scale_e8m0.view( + mx_tensor.scale.view( torch.uint8 ), # The scale factors, Need to cast to uint8 as float8_e8m0fnu is not support for all gather. group_tag, @@ -884,11 +884,11 @@ def mx_all_gather(func, types, args, kwargs): gathered_qdata, gathered_scale, mx_tensor._elem_dtype, - mx_tensor._block_size, + mx_tensor.block_size, mx_tensor._orig_dtype, - mx_tensor._gemm_kernel_choice, - mx_tensor._pack_fp6, + mx_tensor.kernel_preference, mx_tensor.act_quant_kwargs, + mx_tensor._is_swizzled_scales, ) @@ -908,16 +908,16 @@ def mx_wait_tensor(func, types, args, kwargs): ) waited_scale = torch.ops._c10d_functional.wait_tensor.default( - mx_tensor._scale_e8m0, *args[1:], **kwargs + mx_tensor.scale, *args[1:], **kwargs ) return MXTensor( waited_qdata, waited_scale, mx_tensor._elem_dtype, - mx_tensor._block_size, + mx_tensor.block_size, mx_tensor._orig_dtype, - mx_tensor._gemm_kernel_choice, - mx_tensor._pack_fp6, + mx_tensor.kernel_preference, mx_tensor.act_quant_kwargs, + mx_tensor._is_swizzled_scales, ) From 248a403401a2110f37dff26fcbf3b67f6adc22ab Mon Sep 17 00:00:00 2001 From: avi singhal Date: Sat, 6 Dec 2025 02:14:42 +0000 Subject: [PATCH 16/18] added checks to run only on CUDA with compatibility >=9 --- test/prototype/mx_formats/test_mxfp8_allgather.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/prototype/mx_formats/test_mxfp8_allgather.py b/test/prototype/mx_formats/test_mxfp8_allgather.py index 43c31389c6..850e790820 100644 --- a/test/prototype/mx_formats/test_mxfp8_allgather.py +++ b/test/prototype/mx_formats/test_mxfp8_allgather.py @@ -1,3 +1,4 @@ +import pytest import torch import torch.distributed as dist from torch.testing._internal.common_distributed import ( @@ -9,6 +10,11 @@ from torchao.prototype.mx_formats.mx_tensor import MXTensor +if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0): + pytest.skip( + "Test Requires CUDA and compute capability >= 9.0", allow_module_level=True + ) + @instantiate_parametrized_tests class MXFP8OnDeviceAllGatherTest(MultiProcessTestCase): From 08a03bac2c3b0281e6928ff216e6c7ddfe6be2bf Mon Sep 17 00:00:00 2001 From: avi singhal Date: Tue, 9 Dec 2025 04:09:18 +0000 Subject: [PATCH 17/18] updated test for H100 --- .../mx_formats/test_mxfp8_allgather.py | 214 ++++++++---------- .../mx_formats/test_mxfp8_allgather.sh | 12 + 2 files changed, 112 insertions(+), 114 
deletions(-) create mode 100644 test/prototype/mx_formats/test_mxfp8_allgather.sh diff --git a/test/prototype/mx_formats/test_mxfp8_allgather.py b/test/prototype/mx_formats/test_mxfp8_allgather.py index 850e790820..d68d2e7f43 100644 --- a/test/prototype/mx_formats/test_mxfp8_allgather.py +++ b/test/prototype/mx_formats/test_mxfp8_allgather.py @@ -1,123 +1,109 @@ import pytest import torch import torch.distributed as dist -from torch.testing._internal.common_distributed import ( - MultiProcessTestCase, -) -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, -) from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.utils import is_sm_at_least_90, torch_version_at_least -if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0): - pytest.skip( - "Test Requires CUDA and compute capability >= 9.0", allow_module_level=True +if not torch_version_at_least("2.7.0"): + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +def setup_distributed(): + dist.init_process_group("nccl") + # seed must be the same in all processes + torch.manual_seed(42) + local_rank = torch.distributed.get_rank() + torch.cuda.set_device(local_rank) + return local_rank + + +def _test_allgather(local_rank): + golden_qdata = ( + torch.randint(0, 256, (256, 512), dtype=torch.uint8) + .to(torch.float8_e5m2) + .to(local_rank) + ) + + # Random scale factors (typically float32 or uint8 for e8m0) + golden_scale = ( + torch.randint(0, 256, (256, 16), dtype=torch.uint8) + .view(torch.float8_e8m0fnu) + .to(local_rank) + ) + + # Create golden MXTensor + golden_mx = MXTensor( + golden_qdata, + golden_scale, + elem_dtype=torch.float8_e5m2, + block_size=32, + orig_dtype=torch.float32, + kernel_preference=None, + act_quant_kwargs=None, + is_swizzled_scales=None, + ) + + local_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + # Each rank gets its shard (split along dim 0) + shard_size = golden_qdata.shape[0] // world_size # 2 rows per rank + start_idx = local_rank * shard_size + end_idx = (local_rank + 1) * shard_size + + # Create local MXTensor from shard + local_mx = MXTensor( + golden_qdata[start_idx:end_idx].clone().to(local_rank), + golden_scale[start_idx:end_idx].clone().to(local_rank), + elem_dtype=torch.float8_e5m2, + block_size=32, + orig_dtype=torch.float32, + kernel_preference=None, + act_quant_kwargs=None, + is_swizzled_scales=None, + ) + + # Perform all_gather + gathered_mx = torch.ops._c10d_functional.all_gather_into_tensor.default( + local_mx, + world_size, + "0", + ) + gathered_mx = torch.ops._c10d_functional.wait_tensor.default(gathered_mx) + + # Verify type + assert isinstance(gathered_mx, MXTensor), ( + f"Expected MXTensor, got {type(gathered_mx)}" ) + # Verify shape + assert gathered_mx.shape == golden_mx.shape, ( + f"Shape mismatch: {gathered_mx.shape} vs {golden_mx.shape}" + ) + + # Verify qdata matches golden exactly + if not torch.equal(gathered_mx.qdata, golden_qdata): + assert False, "qdata mismatch" + + # Verify scale matches golden exactly + if not torch.equal( + gathered_mx.scale.view(torch.uint8), + golden_scale.view(torch.uint8), + ): + assert False, "scale mismatch" + + assert gathered_mx.block_size == 32 + + +if __name__ == "__main__": + local_rank = setup_distributed() + + assert is_sm_at_least_90() == True, "SM must be > 9.0" + + try: + _test_allgather(local_rank) + except Exception as e: + raise e -@instantiate_parametrized_tests -class 
MXFP8OnDeviceAllGatherTest(MultiProcessTestCase): - def setUp(self) -> None: - super().setUp() - self._spawn_processes() - - @property - def world_size(self) -> int: - return 2 - - @property - def device(self) -> torch.device: - return torch.device(f"cuda:{self.rank}") - - def _init_process(self): - torch.cuda.set_device(self.device) - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="nccl", - world_size=self.world_size, - rank=self.rank, - store=store, - ) - torch.manual_seed(42) - - def test_allgather(self): - self._init_process() - try: - torch.manual_seed(42) - golden_qdata = ( - torch.randint(0, 256, (256, 512), dtype=torch.uint8) - .to(torch.float8_e5m2) - .to(self.device) - ) - - # Random scale factors (typically float32 or uint8 for e8m0) - golden_scale = ( - torch.randint(0, 256, (256, 16), dtype=torch.uint8) - .view(torch.float8_e8m0fnu) - .to(self.device) - ) - - # Create golden MXTensor - golden_mx = MXTensor( - golden_qdata, - golden_scale, - elem_dtype=torch.float8_e5m2, - block_size=32, - orig_dtype=torch.float32, - kernel_preference=None, - act_quant_kwargs=None, - is_swizzled_scales=None, - ) - - world_size = self.world_size - # Each rank gets its shard (split along dim 0) - shard_size = golden_qdata.shape[0] // world_size # 2 rows per rank - start_idx = self.rank * shard_size - end_idx = (self.rank + 1) * shard_size - - # Create local MXTensor from shard - local_mx = MXTensor( - golden_qdata[start_idx:end_idx].clone().to(self.device), - golden_scale[start_idx:end_idx].clone().to(self.device), - elem_dtype=torch.float8_e5m2, - block_size=32, - orig_dtype=torch.float32, - kernel_preference=None, - act_quant_kwargs=None, - is_swizzled_scales=None, - ) - - # Perform all_gather - gathered_mx = torch.ops._c10d_functional.all_gather_into_tensor.default( - local_mx, - world_size, - "0", - ) - gathered_mx = torch.ops._c10d_functional.wait_tensor.default(gathered_mx) - - # Verify type - assert isinstance(gathered_mx, MXTensor), ( - f"Expected MXTensor, got {type(gathered_mx)}" - ) - - # Verify shape - assert gathered_mx.shape == golden_mx.shape, ( - f"Shape mismatch: {gathered_mx.shape} vs {golden_mx.shape}" - ) - - # Verify qdata matches golden exactly - if not torch.equal(gathered_mx.qdata, golden_qdata): - assert False, "qdata mismatch" - - # Verify scale matches golden exactly - if not torch.equal( - gathered_mx.scale.view(torch.uint8), - golden_scale.view(torch.uint8), - ): - assert False, "scale mismatch" - - assert gathered_mx.block_size == 32 - - finally: - dist.destroy_process_group() + torch.distributed.destroy_process_group() diff --git a/test/prototype/mx_formats/test_mxfp8_allgather.sh b/test/prototype/mx_formats/test_mxfp8_allgather.sh new file mode 100644 index 0000000000..180375af40 --- /dev/null +++ b/test/prototype/mx_formats/test_mxfp8_allgather.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +# terminate script on first error +set -e + +if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then + echo "Skipping test_dtensor.sh because no CUDA devices are available." 
+ exit +fi + +# multi-GPU integration test for MXFP8 MXTensor all-gather +NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mxfp8_allgather.py \ No newline at end of file From eab336bf8c6bea23ef6ed31c3cdc8553b816ce2f Mon Sep 17 00:00:00 2001 From: avi singhal Date: Tue, 9 Dec 2025 04:13:20 +0000 Subject: [PATCH 18/18] added test to workflow --- .github/workflows/4xH100_tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/4xH100_tests.yml b/.github/workflows/4xH100_tests.yml index 72faeebebb..4ab2b98744 100644 --- a/.github/workflows/4xH100_tests.yml +++ b/.github/workflows/4xH100_tests.yml @@ -47,3 +47,4 @@ jobs: pip install . --no-build-isolation ./test/float8/test_everything_multi_gpu.sh ./test/prototype/mx_formats/test_mx_dtensor.sh + ./test/prototype/mx_formats/test_mxfp8_allgather.sh
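The test added in PATCH 17/18 drives the MXTensor all-gather override end-to-end under torchrun. Below is an illustrative sketch only, not part of the patches above: the file name `demo_mxfp8_allgather.py` and the helper `run_demo` are hypothetical, while the MXTensor constructor arguments and the `_c10d_functional` ops mirror `_test_allgather` from PATCH 17.

```python
# demo_mxfp8_allgather.py -- hypothetical standalone sketch, not part of this series.
# Launch (assumption): torchrun --nproc_per_node 2 demo_mxfp8_allgather.py
import torch
import torch.distributed as dist

from torchao.prototype.mx_formats.mx_tensor import MXTensor


def run_demo():
    dist.init_process_group("nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)

    # Each rank holds a (128, 512) fp8 shard; one e8m0 scale per 32-element block
    # along the last dim gives a (128, 16) scale tensor.
    qdata = (
        torch.randint(0, 256, (128, 512), dtype=torch.uint8)
        .to(torch.float8_e5m2)
        .cuda()
    )
    scale = (
        torch.randint(0, 256, (128, 16), dtype=torch.uint8)
        .view(torch.float8_e8m0fnu)
        .cuda()
    )
    shard = MXTensor(
        qdata,
        scale,
        elem_dtype=torch.float8_e5m2,
        block_size=32,
        orig_dtype=torch.float32,
        kernel_preference=None,
        act_quant_kwargs=None,
        is_swizzled_scales=None,
    )

    # The override gathers qdata and scale separately (the scale is viewed as
    # uint8 because all-gather does not handle float8_e8m0fnu directly), then
    # wait_tensor completes both collectives and rebuilds an MXTensor.
    gathered = torch.ops._c10d_functional.all_gather_into_tensor.default(
        shard, world_size, "0"
    )
    gathered = torch.ops._c10d_functional.wait_tensor.default(gathered)

    assert isinstance(gathered, MXTensor)
    assert gathered.shape[0] == world_size * qdata.shape[0]
    assert gathered.block_size == 32

    dist.destroy_process_group()


if __name__ == "__main__":
    run_demo()
```

The uint8 round trip for the scale is also presumably why the PATCH 17 test compares `gathered.scale` against the golden scale through a `uint8` view rather than comparing float8_e8m0fnu values directly.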