Skip to content

Commit 957ef98

Browse files
Fixes per the review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
1 parent dca9759 commit 957ef98

2 files changed

Lines changed: 10 additions & 12 deletions

File tree

tests/cpp/operator/test_cast_mxfp8_grouped.cu

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -346,16 +346,16 @@ void performTest(const ProcessingMethod processing_method,
346 346
const size_t last_dims_size = num_tensors * sizeof(size_t);
347 347
const size_t offsets_size = (num_tensors + 1) * sizeof(size_t);
348 348

349-
InputType* grad_data_d;
350-
InputType* in_data_d;
351-
InputType* dbias_out_data_d;
352-
OutputType* out_data_rowwise_d;
353-
OutputType* out_data_colwise_d;
354-
fp8e8m0* out_scales_rowwise_d;
355-
fp8e8m0* out_scales_colwise_d;
356-
size_t* first_dims_d;
357-
size_t* last_dims_d;
358-
size_t* offsets_d;
349+
InputType* grad_data_d = nullptr;
350+
InputType* in_data_d = nullptr;
351+
InputType* dbias_out_data_d = nullptr;
352+
OutputType* out_data_rowwise_d = nullptr;
353+
OutputType* out_data_colwise_d = nullptr;
354+
fp8e8m0* out_scales_rowwise_d = nullptr;
355+
fp8e8m0* out_scales_colwise_d = nullptr;
356+
size_t* first_dims_d = nullptr;
357+
size_t* last_dims_d = nullptr;
358+
size_t* offsets_d = nullptr;
359 359

360 360
cudaMalloc((void**)&grad_data_d, in_data_size);
361 361
cudaMalloc((void**)&in_data_d, in_data_size);

transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -808,8 +808,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
808 808
NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS,
809 809
"Number of tensors in a group is larger than "
810 810
"the MAX number of supported descriptors (64).");
811-
// Only full tiles supported
812-
NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported.");
813 811
blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
814 812
}
815 813
const size_t block_size = THREADS_PER_CHUNK;

0 commit comments

Comments
 (0)