
Conversation

@timmoon10 (Collaborator) commented Dec 6, 2025

Description

All of the supported block-scaled tensor formats (MXFP8, NVFP4, DSv3 FP8) have two ways of ordering their scaling factors:

  • "Compact" ordering for quantization, dequantization, and communication
  • "Swizzled" ordering for GEMM

The core infrastructure handles this in an ad hoc way, blindly assuming that the "right" scale ordering is used for each operation. The PyTorch infrastructure only supports MXFP8 and NVFP4 scales in compact order, although DSv3 FP8 does distinguish between "compact" and "GEMM-ready" formats. This situation makes it hard to implement fused kernels that bypass the swizzle kernel.
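To make the distinction concrete, here is a minimal sketch assuming MXFP8 with 32-element scaling blocks (the exact swizzled layout is defined by the hardware/cuBLAS and is only gestured at here):

```python
# Minimal sketch, not the actual kernels. For a row-major [M, K] MXFP8 tensor,
# "compact" ordering stores one E8M0 scale per 32-element block, laid out
# row-major alongside the data -- convenient for quantization, dequantization,
# and communication.
M, K = 1024, 4096
compact_scales_shape = (M, K // 32)  # one scale per 32-element block along K

# The GEMM instead expects the same scales permuted into a tiled ("swizzled")
# ordering. Since this is a pure permutation, it can either be applied by a
# standalone swizzle kernel or fused into the quantize kernel.
```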

This PR adds a with_gemm_swizzled_scales field to the C++ tensor class so that the core infrastructure can distinguish between the different scale orderings. It also adds this field to the PyTorch quantized tensor classes and exposes an optimize_for_gemm option in the quantizer, so that we can create GEMM-ready tensors when they do not need to be communicated or checkpointed. Finally, it rips out all of the DSv3 FP8 infrastructure for the compact format, which is no longer necessary.
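For illustration, a hedged sketch of how the new option might be used from PyTorch; the constructor arguments and the exact attribute spelling below are assumptions, while optimize_for_gemm and with_gemm_swizzled_scales are the names introduced above:

```python
import torch
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

# A quantizer whose outputs are only ever consumed by a GEMM can produce
# scales directly in the swizzled ordering: no standalone swizzle kernel,
# and no compact copy kept around for communication or checkpointing.
quantizer = MXFP8Quantizer()           # constructor arguments omitted/assumed
quantizer.optimize_for_gemm = True     # option added by this PR

x = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
x_mx = quantizer(x)                    # expected: x_mx.with_gemm_swizzled_scales is True
```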

Progress

  • MXFP8
  • DSv3 FP8
  • NVFP4
  • Add option to pre-swizzle weights
  • Pre-swizzle activations
  • Fused MXFP8 quantize + swizzle

Closes #2446.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Support GEMM swizzled scales in C++ tensor class
  • Support GEMM swizzled scales in PyTorch quantized tensor classes
  • Support optimize_for_gemm option in PyTorch quantizer
  • Expose PyTorch function to swizzle scales
  • Support MXFP8 quantization with pre-swizzled scales
  • Enable fused quantize+swizzle kernels in linear module and related
  • Remove DSv3 FP8 compact data format. It was used to avoid all-gather interleaving, which we can now fix with the swap-first-dims kernel. (A sketch of the resulting workflow follows this list.)
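A minimal sketch of the workflow the items above enable; every helper name below is a hypothetical placeholder rather than the actual function exposed by this PR:

```python
def gather_and_prepare_for_gemm(x, quantizer, process_group):
    # 1. Quantize with compact scale ordering, which keeps data and scales
    #    contiguous and easy to all-gather across ranks.
    q_local = quantizer(x)
    # 2. Communicate the quantized data and compact scales.
    q_full = all_gather_quantized(q_local, process_group)   # hypothetical helper
    # 3. Swizzle the gathered scales once, right before the GEMM.
    return swizzle_scales_for_gemm(q_full)                  # hypothetical wrapper
```

When a tensor never needs to be communicated or checkpointed, steps 1 and 3 can instead be collapsed by quantizing directly into the swizzled ordering (the optimize_for_gemm path).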

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 force-pushed the tmoon/pre-swizzled-scales branch from d274220 to 52ce3a4 on December 6, 2025 02:53
timmoon10 added the enhancement and refactor labels on Dec 6, 2025
timmoon10 force-pushed the tmoon/pre-swizzled-scales branch from 4925b63 to 1de4b5e on December 10, 2025 07:19

timmoon10 and others added 2 commits December 12, 2025 08:51

timmoon10 and others added 7 commits December 12, 2025 23:48

timmoon10 added the performance and MoE labels on Dec 15, 2025
@timmoon10 (Collaborator, Author) commented:

/te-ci L1

@greptile-apps greptile-apps bot left a comment

Additional Comments (5)

  1. transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py, line 273-277 (link)

    syntax: return statement should return a tuple, not a dictionary

  2. transformer_engine/pytorch/csrc/common.h, line 389-390 (link)

    style: Float8E8M0 maps to kByte - this might cause confusion since kByte is used for both DType::kByte and DType::kFloat8E8M0. Is there a specific reason Float8E8M0 maps to kByte instead of having its own PyTorch scalar type?

  3. transformer_engine/common/include/transformer_engine/swizzle.h, line 7-8 (link)

    syntax: Header comment incorrectly states this is 'cast.h' and describes casting functions, but this is 'swizzle.h' for swizzle operations

  4. transformer_engine/common/cast/dispatch/quantize.cuh, line 150-157 (link)

    logic: Forward quantization always uses GEMM_READY format regardless of tensor's with_gemm_swizzled_scales field, while backward quantization respects it (lines 294-303). This inconsistency could lead to scale format mismatches. Should forward quantization also check output_tensor->with_gemm_swizzled_scales like the backward path does?

  5. transformer_engine/pytorch/distributed.py, line 1082-1084 (link)

    logic: Bug: quantizer(out) is called when quantizer is None. This will cause a TypeError: 'NoneType' object is not callable at runtime.

65 files reviewed, 5 comments
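Regarding comment 5, a minimal sketch of the suggested guard (the surrounding code in transformer_engine/pytorch/distributed.py is paraphrased, not quoted):

```python
# Only quantize the gathered output when a quantizer was actually provided.
if quantizer is not None:
    out = quantizer(out)
```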

Labels

enhancement, MoE, performance, refactor

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support MXFP8/NVFP4 tensors with pre-swizzled scales

1 participant