
[Common] Persistent Grouped MXFP8 quantization kernel #2738

Open
Oleg-Goncharov wants to merge 27 commits into NVIDIA:main from
Oleg-Goncharov:pr_persistent_grouped_mxfp8_kernel

Conversation


@Oleg-Goncharov (Collaborator) commented Mar 5, 2026

Description

This PR adds a persistent grouped MXFP8 quantization kernel with static scheduling.
It is built on top of PR #2674 ([Common] MOE Split dBias).

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added persistent kernel
  • Added TunableConfig structure to tune performance

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Oleg-Goncharov and others added 22 commits February 27, 2026 15:53
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
for more information, see https://pre-commit.ci

@Oleg-Goncharov added the enhancement (New feature or request) and MoE labels on Mar 5, 2026
@Oleg-Goncharov requested a review from ptrendx on March 5, 2026 at 16:18
@Oleg-Goncharov (Collaborator, Author) commented:

/te-ci


greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR adds a persistent grouped MXFP8 quantization kernel with static scheduling and upgrades the dbias output from a single flat tensor to a per-group NVTEGroupedTensor. The persistent kernel introduces a TunableConfig struct, a grid-stride work scheduler that maps a compact physical CTA grid over a virtual work grid, and a ping-pong double-buffer pipeline for TMA loads and stores.

Key changes:

  • TunableConfig centralises tunable constants (CHUNK_DIM_Y/X, THREADS_PER_CHUNK, PREFETCH_STAGES, PERSISTENT, STATIC_PERSISTENT_BLOCKS_PER_SM); the existing CHUNK_DIM_X/Y constants are now aliases.
  • The main kernel gains work_blocks_X/Y parameters and a while (!job_finished) outer loop; job decoding is split into decode_job / decode_block / is_job_valid device helpers.
  • Colwise and rowwise processing are extracted into process_colwise_stage / process_rowwise_stage device functions, reducing the kernel body significantly.
  • grouped_reduce_dbias (new host function + group_reduce_dbias_kernel) replaces reduce_dbias, writing one output row per tensor into a [num_tensors, cols] layout.
  • ShapeRepresentation enum is moved to common.cuh so both the main kernel and the new reduction kernel share the same type.
  • All nvte_group_quantize_dbias* public API signatures change NVTETensor dbias → NVTEGroupedTensor dbias.
  • Tests are updated to exercise the new grouped dbias layout, run the reference over all tensors uniformly, and skip non-16-byte-aligned last dimensions.

Two issues identified:

  1. The is_job_valid function contains a mathematically impossible condition (block_offset_X_in_tensor >= job.cols): because of the modulo arithmetic that produces the offset, the condition can never be true and should be removed for clarity.
  2. The group_reduce_dbias_kernel assumes all tensors in a group share the same last dimension (last_logical_dim), but does not enforce this precondition. The kernel would produce incorrect results if called with VARYING_LAST_DIMS or VARYING_BOTH_DIMS shapes. Tests currently skip dbias validation for these cases, preventing the bug from manifesting, but adding a precondition check or documentation is recommended.

Confidence Score: 4/5

  • The PR is functionally correct for the common paths (single-tensor and varying-first-dim cases) but has one low-severity issue in the persistent kernel and one latent bug in the dbias reduction path.
  • The redundant condition in is_job_valid is a clarity issue that doesn't affect correctness. The uniform-cols assumption in group_reduce_dbias_kernel is a real bug that could corrupt dbias output for varying-last-dim groups, but tests currently skip these cases, so no failures manifest. Both issues are addressable with minimal code changes (remove dead condition, add precondition check). The overall refactoring is well-structured and the job decoding, barrier management, and grouped dbias reduction logic is correct.
  • Files needing changes: transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh (remove the dead condition) and transformer_engine/common/cast/core/common.cuh (add a precondition check for varying-last-dim support).

Last reviewed commit: 5815335

Comment on lines 1091 to +1097

  } else {
-   NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS,
+   NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS,
               "Number of tensors in a group is larger than "
               "the MAX number of supported descriptors (64).");
-   // Only full tiles supported
-   NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
-              "Last dimension of a grouped tensor should be divisible by 128.");
-   blocks_Y = 1;
-   blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
+   work_blocks_Y = 1;
+   work_blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
  }

Missing column-alignment check for non-single-tensor grouped tensors

The original code included an NVTE_CHECK that enforced last_logical_dim % CHUNK_DIM_X == 0 for the non-single-tensor path (VARYING_LAST_DIM, VARYING_BOTH_DIMS). This check was removed in this PR, but the kernel still requires this alignment for correctness in the non-single-tensor path.

The unit tests themselves still skip when this condition is not met:

if (!is_single_tensor && (last_dims[t] % CHUNK_DIM_X != 0)) {
    GTEST_SKIP();
}

Without the runtime check, callers can pass non-128-aligned last dimensions for non-single-tensor groups and receive silently wrong results. For example, with cols = 160:

  • blocks_X_num_in_current_tensor = DIVUP(160, 128) = 2
  • decode_block maps block_id = 1 to block_offset_X = 128, issuing a TMA load at column offset 128 in a 160-wide tensor
  • Meanwhile is_job_valid computes the flat-space element offset as 16384, which maps to (row=102, col=64) — these coordinates conflict with what decode_block produces, leading to incorrect quantization

The check should be restored:

  } else {
    NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS,
               "Number of tensors in a group is larger than "
               "the MAX number of supported descriptors (64).");
    NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
               "Last dimension of a grouped tensor must be divisible by 128.");
    work_blocks_Y = 1;
    work_blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
  }

@Oleg-Goncharov force-pushed the pr_persistent_grouped_mxfp8_kernel branch from 924ff91 to 325181b on March 6, 2026 at 10:39
Comment on lines +211 to +212
if (block_offset_Y_in_tensor >= job.rows || block_offset_X_in_tensor >= job.cols) {
return false;

Redundant block_offset_X_in_tensor >= job.cols condition is always false

block_offset_X_in_tensor is computed as tensor_offset_from_start % job.cols (line 210). By definition of the modulo operator, this result always lies in [0, job.cols - 1], so the condition block_offset_X_in_tensor >= job.cols can never be true.

The actual guard that matters is block_offset_Y_in_tensor >= job.rows. The dead half of the condition silently provides no protection against out-of-bounds blocks.

Suggested change

- if (block_offset_Y_in_tensor >= job.rows || block_offset_X_in_tensor >= job.cols) {
-   return false;
+ if (block_offset_Y_in_tensor >= job.rows) {
+   return false;
  }

Quoted lines (group_reduce_dbias_kernel):

const float *const thread_in_base = dbias_partial + dbias_in_offset_Y * cols + thread_id * nvec;
OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Output stride assumes uniform cols across all tensors

The output write offset is computed as:

OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;

where cols is last_logical_dim — a single value shared across all tensors in the group. This is correct for SAME_BOTH_DIMS and VARYING_FIRST_DIM (where all tensors share the same last dimension), but the kernel receives shape_rep as a parameter and does not enforce that restriction.

For VARYING_LAST_DIM or VARYING_BOTH_DIMS where per-tensor cols differ, the fixed tensor_id * cols stride would compute wrong output offsets. Currently, tests skip dbias validation for these cases, but the kernel would produce incorrect results if actually called with varying-last-dim tensors.

Consider adding a device-side assertion to enforce the precondition:

Suggested change

  OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;
+ if (shape_rep != ShapeRepresentation::SAME_BOTH_DIMS && shape_rep != ShapeRepresentation::VARYING_FIRST_DIM) {
+   NVTE_DEVICE_ERROR("group_reduce_dbias_kernel requires uniform last dimensions across tensors");
+ }
