Commits (27)
c7c1a76  Implemented the kernel with split dbias (Oleg-Goncharov, Feb 11, 2026)
7abbc7b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 11, 2026)
f820b21  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 12, 2026)
0c05632  Relaxed constraints on the last dimension (Oleg-Goncharov, Feb 13, 2026)
4a85dea  Added notes on group tensor restrictions into documentation (Oleg-Goncharov, Feb 13, 2026)
aedd53d  Fixes per the review (Oleg-Goncharov, Feb 27, 2026)
38288b1  Fixed pointer (Oleg-Goncharov, Feb 27, 2026)
ce3a137  More fixes (Oleg-Goncharov, Feb 27, 2026)
bddd804  Fixed kernel grid size (Oleg-Goncharov, Mar 2, 2026)
a894d1a  Merge branch 'main' into pr_split_dbias (Oleg-Goncharov, Mar 2, 2026)
87352bd  Enabled persistency with WorkID Query feature (Oleg-Goncharov, Mar 4, 2026)
e23f553  Added a struct with tunable parameters (Oleg-Goncharov, Mar 4, 2026)
d185299  Added persistency with static scheduling (Oleg-Goncharov, Mar 4, 2026)
5e15f57  Fixed test cases (Oleg-Goncharov, Mar 4, 2026)
98e9558  Ready for benchmarking (Oleg-Goncharov, Mar 4, 2026)
ab816cb  Fixed out-of-boundary error (Oleg-Goncharov, Mar 4, 2026)
8a429ad  Tuned kernel parameters (Oleg-Goncharov, Mar 4, 2026)
ab3f911  Refactoring (Oleg-Goncharov, Mar 4, 2026)
92720ac  Refactoring 2 (Oleg-Goncharov, Mar 4, 2026)
46d9811  Refactoring 3 (Oleg-Goncharov, Mar 4, 2026)
7172400  Removed the dynamic (WorkID Query) persistency (Oleg-Goncharov, Mar 5, 2026)
4344627  Ready for PR (Oleg-Goncharov, Mar 5, 2026)
ede33b4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 5, 2026)
219e925  Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel (Oleg-Goncharov, Mar 5, 2026)
325181b  Fixes per the review (Oleg-Goncharov, Mar 6, 2026)
04609b1  Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel (Oleg-Goncharov, Mar 6, 2026)
5815335  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 6, 2026)
162 changes: 88 additions & 74 deletions tests/cpp/operator/test_cast_mxfp8_grouped.cu

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions transformer_engine/common/activation/gelu.cu
@@ -32,7 +32,7 @@ void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu
                       NVTEGroupedTensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_dgelu);
   using namespace transformer_engine;
-  NVTETensor dbias = nullptr;
+  NVTEGroupedTensor dbias = nullptr;
   NVTETensor workspace = nullptr;

   constexpr bool IS_DBIAS = false;
@@ -57,7 +57,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati

 void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
-                                     NVTEGroupedTensor output, NVTETensor dbias,
+                                     NVTEGroupedTensor output, NVTEGroupedTensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
   using namespace transformer_engine;
@@ -110,7 +110,7 @@ void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp
                        NVTEGroupedTensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_dqgelu);
   using namespace transformer_engine;
-  NVTETensor dbias = nullptr;
+  NVTEGroupedTensor dbias = nullptr;
   NVTETensor workspace = nullptr;

   constexpr bool IS_DBIAS = false;
@@ -135,7 +135,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat

 void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
                                       const NVTEGroupedTensor activation_input,
-                                      NVTEGroupedTensor output, NVTETensor dbias,
+                                      NVTEGroupedTensor output, NVTEGroupedTensor dbias,
                                       NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
   using namespace transformer_engine;
8 changes: 4 additions & 4 deletions transformer_engine/common/activation/relu.cu
@@ -32,7 +32,7 @@ void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu
                       NVTEGroupedTensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_drelu);
   using namespace transformer_engine;
-  NVTETensor dbias = nullptr;
+  NVTEGroupedTensor dbias = nullptr;
   NVTETensor workspace = nullptr;

   constexpr bool IS_DBIAS = false;
@@ -57,7 +57,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati

 void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
-                                     NVTEGroupedTensor output, NVTETensor dbias,
+                                     NVTEGroupedTensor output, NVTEGroupedTensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_quantize_dbias_drelu);
   using namespace transformer_engine;
@@ -110,7 +110,7 @@ void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp
                        NVTEGroupedTensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_dsrelu);
   using namespace transformer_engine;
-  NVTETensor dbias = nullptr;
+  NVTEGroupedTensor dbias = nullptr;
   NVTETensor workspace = nullptr;

   constexpr bool IS_DBIAS = false;
@@ -135,7 +135,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat

 void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
                                       const NVTEGroupedTensor activation_input,
-                                      NVTEGroupedTensor output, NVTETensor dbias,
+                                      NVTEGroupedTensor output, NVTEGroupedTensor dbias,
                                       NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu);
   using namespace transformer_engine;
4 changes: 2 additions & 2 deletions transformer_engine/common/activation/swiglu.cu
@@ -32,7 +32,7 @@ void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu
                       NVTEGroupedTensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_dsilu);
   using namespace transformer_engine;
-  NVTETensor dbias = nullptr;
+  NVTEGroupedTensor dbias = nullptr;
   NVTETensor workspace = nullptr;

   constexpr bool IS_DBIAS = false;
@@ -57,7 +57,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati

 void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
-                                     NVTEGroupedTensor output, NVTETensor dbias,
+                                     NVTEGroupedTensor output, NVTEGroupedTensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
   using namespace transformer_engine;
2 changes: 1 addition & 1 deletion transformer_engine/common/cast/cast.cu
@@ -70,7 +70,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
 }

 void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
-                               NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
+                               NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_quantize_dbias);
   using namespace transformer_engine;
85 changes: 85 additions & 0 deletions transformer_engine/common/cast/core/common.cuh
@@ -22,6 +22,14 @@
 namespace transformer_engine {
 namespace dispatch {
 namespace common {
+
+enum ShapeRepresentation {
+  SAME_BOTH_DIMS = 0,
+  VARYING_FIRST_DIM = 1,
+  VARYING_LAST_DIM = 2,
+  VARYING_BOTH_DIMS = 3
+};
+
 inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
   const size_t N = product(t->data.shape);
   const bool isFullTile = (N % elems_per_block == 0);
@@ -78,6 +86,57 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
   }
   stg_vec.store_to(thread_out_base);
 }
+
+template <int nvec, typename OType>
+__global__ void __launch_bounds__(THREADS_PER_BLOCK)
+    group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, const size_t num_tensors,
+                              const size_t first_logical_dim, const size_t last_logical_dim,
+                              const int64_t *const offsets_ptr, const int64_t *const first_dims_ptr,
+                              const int64_t *const last_dims_ptr, OType *const dbias_output,
+                              const float *dbias_partial, const size_t chunk_dim_Y) {
+  using ComputeVec = Vec<float, nvec>;
+  using OutputVec = Vec<OType, nvec>;
+
+  const size_t tensor_id = blockIdx.y;
+  const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
+                                 ? (first_logical_dim / num_tensors)
+                                 : static_cast<size_t>(first_dims_ptr[tensor_id]);
+
+  const size_t rows = tensor_rows / chunk_dim_Y;
+  const size_t cols = last_logical_dim;
+
+  const size_t dbias_in_offset_Y =
+      (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
+          ? (tensor_id * (tensor_rows / chunk_dim_Y))
+          : (static_cast<size_t>(offsets_ptr[tensor_id]) / cols / chunk_dim_Y);
+
+  const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (thread_id * nvec >= cols) {
+    return;
+  }
+
+  const float *const thread_in_base = dbias_partial + dbias_in_offset_Y * cols + thread_id * nvec;
+  OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;
Review comment (Contributor): Output stride assumes uniform cols across all tensors

The output write offset is computed as

    OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;

where cols is last_logical_dim, a single value shared across all tensors in the group. This is correct for SAME_BOTH_DIMS and VARYING_FIRST_DIM (where all tensors share the same last dimension), but the kernel receives shape_rep as a parameter and does not enforce that restriction.

For VARYING_LAST_DIM or VARYING_BOTH_DIMS, where per-tensor cols differ, the fixed tensor_id * cols stride would compute wrong output offsets. Currently, tests skip dbias validation for these cases, but the kernel would produce incorrect results if actually called with varying-last-dim tensors.

Consider adding a device-side assertion to enforce the precondition. Suggested change:

     OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;
+    if (shape_rep != ShapeRepresentation::SAME_BOTH_DIMS &&
+        shape_rep != ShapeRepresentation::VARYING_FIRST_DIM) {
+      NVTE_DEVICE_ERROR("group_reduce_dbias_kernel requires uniform last dimensions across tensors");
+    }
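A complementary host-side check is possible as well. The sketch below is hypothetical and not part of the PR; it assumes placement at the top of grouped_reduce_dbias, before the kernel launch, and reuses the NVTE_CHECK macro that function already uses for its nvec divisibility check:

    // Hypothetical host-side guard (assumed to sit before the launch in
    // grouped_reduce_dbias): reject the layouts for which the fixed
    // tensor_id * cols output stride would compute wrong offsets.
    NVTE_CHECK(shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
                   shape_rep == ShapeRepresentation::VARYING_FIRST_DIM,
               "grouped_reduce_dbias requires a uniform last dimension across all tensors.");

Failing on the host keeps the error synchronous instead of trapping inside the kernel.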
+
+  ComputeVec ldg_vec;
+  ComputeVec acc_vec;
+  acc_vec.clear();
+  for (int i = 0; i < rows; ++i) {
+    ldg_vec.load_from(thread_in_base + i * cols);
+#pragma unroll
+    for (int e = 0; e < nvec; ++e) {
+      acc_vec.data.elt[e] += ldg_vec.data.elt[e];
+    }
+  }
+
+  OutputVec stg_vec;
+#pragma unroll
+  for (int e = 0; e < nvec; ++e) {
+    stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
+  }
+  stg_vec.store_to(thread_out_base);
+}
 }  // namespace kernel

 template <typename IType>
@@ -96,6 +155,32 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
   NVTE_CHECK_CUDA(cudaGetLastError());
 }

+template <typename IType>
+void grouped_reduce_dbias(const ShapeRepresentation shape_rep, const size_t num_tensors,
+                          const size_t first_logical_dim, const size_t last_logical_dim,
+                          const int64_t *const data_tensor_offsets_ptr,
+                          const int64_t *const data_tensor_first_dims_ptr,
+                          const int64_t *const data_tensor_last_dims_ptr, GroupedTensor *dbias,
+                          const float *workspace_ptr, const size_t chunk_dim_Y,
+                          cudaStream_t stream) {
+  using namespace kernel;
+  constexpr size_t reduce_dbias_store_bytes = 8;  // stg.64
+  constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
+
+  NVTE_CHECK(last_logical_dim % reduce_dbias_nvec == 0, "Unsupported shape.");
+
+  const size_t blocks_X = DIVUP(last_logical_dim, THREADS_PER_BLOCK * reduce_dbias_nvec);
+  const size_t blocks_Y = num_tensors;
+  const dim3 grid(blocks_X, blocks_Y);
+
+  group_reduce_dbias_kernel<reduce_dbias_nvec, IType><<<grid, THREADS_PER_BLOCK, 0, stream>>>(
+      shape_rep, num_tensors, first_logical_dim, last_logical_dim, data_tensor_offsets_ptr,
+      data_tensor_first_dims_ptr, data_tensor_last_dims_ptr,
+      reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, chunk_dim_Y);
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}

 }  // namespace common
 }  // namespace dispatch
 }  // namespace transformer_engine
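
For readers tracing the new kernel's indexing, the reduction can be restated on the host. The following plain-C++ sketch mirrors group_reduce_dbias_kernel for the SAME_BOTH_DIMS case only; it is illustrative rather than part of the PR, and the function name and float-typed output are choices made here:

    #include <cstddef>
    #include <vector>

    // Reference for the grouped dbias reduction (SAME_BOTH_DIMS case):
    // dbias_partial holds, for each of the num_tensors groups,
    // (tensor_rows / chunk_dim_Y) partial row-sums of width cols; they are
    // summed column-wise into one dbias row of width cols per tensor.
    std::vector<float> grouped_reduce_dbias_reference(
        const std::vector<float> &dbias_partial, std::size_t num_tensors,
        std::size_t first_logical_dim, std::size_t cols, std::size_t chunk_dim_Y) {
      const std::size_t tensor_rows = first_logical_dim / num_tensors;
      const std::size_t rows = tensor_rows / chunk_dim_Y;  // partial rows per tensor
      std::vector<float> dbias(num_tensors * cols, 0.0f);
      for (std::size_t t = 0; t < num_tensors; ++t) {
        const std::size_t in_offset_Y = t * rows;  // matches dbias_in_offset_Y
        for (std::size_t r = 0; r < rows; ++r) {
          for (std::size_t c = 0; c < cols; ++c) {
            dbias[t * cols + c] += dbias_partial[(in_offset_Y + r) * cols + c];
          }
        }
      }
      return dbias;
    }

In the CUDA kernel, blockIdx.y plays the role of the outer loop over t, and each thread handles nvec consecutive values of c.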
11 changes: 6 additions & 5 deletions transformer_engine/common/cast/dispatch/quantize.cuh
@@ -382,13 +382,13 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor
   NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);

   const NVTEGroupedTensor activation = nullptr;
-  NVTETensor dbias = nullptr;
+  NVTEGroupedTensor dbias = nullptr;
   NVTETensor workspace = nullptr;

   const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
   GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
   const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation);
-  Tensor *dbias_tensor = convertNVTETensor(dbias);
+  GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias);
   Tensor *workspace_tensor = convertNVTETensor(workspace);

   // Quantization config
@@ -419,16 +419,17 @@

 template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
 void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
-                               NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
-                               const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
+                               NVTEGroupedTensor output, NVTEGroupedTensor dbias,
+                               NVTETensor workspace, const NVTEQuantizationConfig quant_config,
+                               cudaStream_t stream) {
   using namespace detail;

   NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);

   const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad);
   const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input);
   GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
-  Tensor *dbias_tensor = convertNVTETensor(dbias);
+  GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias);
   Tensor *workspace_tensor = convertNVTETensor(workspace);

   // Quantization config