[Common] Use specialized unfused MXFP8 cast kernels by default #2958

Status: Open — Oleg-Goncharov wants to merge 7 commits into NVIDIA:main from Oleg-Goncharov:pr_fast_default_mxfp8_kernels

Conversation

@Oleg-Goncharov (Collaborator)

Description

This PR enables the fast unfused MXFP8 cast kernels by default.

Previously, these kernels were gated behind an environment variable and therefore were not used unless explicitly enabled. This change makes the specialized cast-only path the default behavior.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Removed the ENABLE_CAST_ONLY environment-variable gate, making the specialized cast-only kernels the default path

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@greptile-apps (Contributor)

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR makes the specialized unfused MXFP8 cast kernels the default code path by removing the ENABLE_CAST_ONLY environment-variable gate and replacing it with an always-true hasSpec(). A new runtime guard, scaling_type_has_specialized_support, restricts the fast path to ROWWISE (with a column-alignment check) and BIDIMENSIONAL scaling types, while also cleaning up the now-unreachable COLWISE fallback branch from the inner switch.

  • specialized/quantize_mxfp8.cuh: Removes is_cast_only_enabled() and makes all four hasSpec template specializations unconditionally return true.
  • quantize_mxfp8.cuh: Introduces scaling_type_has_specialized_support to guard the specialized dispatch; uses cols % 128 == 0 as the ROWWISE alignment check and removes the now-dead ScalingType::COLWISE case from the switch.

Confidence Score: 5/5

Safe to merge — the specialized kernels are now always enabled for supported type combinations, and correctness is preserved by the scaling_type_has_specialized_support guard and the col-alignment check.

The change is narrow and well-contained: env-var removal plus a runtime dispatch guard. The alignment check at cols % 128 == 0 is conservative rather than unsafe — the actual kernel minimum is cols % 32 == 0 — so no correctness regression is introduced.

The alignment constant in quantize_mxfp8.cuh (128 vs the kernel's actual 32-element chunk size) is worth revisiting to avoid excluding valid shapes from the fast path.

Important Files Changed

| Filename | Overview |
| --- | --- |
| `transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh` | Adds the `scaling_type_has_specialized_support` guard before the specialized-kernel dispatch, removes the dead COLWISE fallback branch from the inner switch, and routes ROWWISE shapes only when `cols % 128 == 0` (over-conservative; `cols % 32 == 0` is the true kernel requirement). |
| `transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh` | Removes the `is_cast_only_enabled()` env-var helper and makes all four `hasSpec` specializations unconditionally `return true`, enabling the specialized kernels by default with no other logic changes. |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[quantize called] --> B{hasSpec AND\nnot GEMM-swizzled?}
    B -- No --> G[Generic kernel path]
    B -- Yes --> C{scaling_type_has_\nspecialized_support?}
    C -- "No (COLWISE or partial row)" --> G
    C -- Yes --> D{ScalingType?}
    D -- "ROWWISE, cols%128==0" --> E[specialized rowwise\ncast-only kernel]
    D -- BIDIMENSIONAL --> F[specialized bidimensional\ncast-only kernel with TMA]
    D -- default --> H[NVTE_ERROR]
    E --> I[return]
    F --> I
```


ksivaman previously approved these changes May 5, 2026

@ksivaman (Member) commented:

LGTM

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov (Collaborator, Author) commented:

/te-ci

ksivaman previously approved these changes May 5, 2026
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov (Collaborator, Author) commented:

/te-ci

Oleg-Goncharov and others added 2 commits May 7, 2026 14:25
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov (Collaborator, Author) commented:

/te-ci

