Skip to content

Commit 325181b

Browse files
Fixes per the review
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
1 parent 219e925 commit 325181b

3 files changed

Lines changed: 9 additions & 4 deletions

File tree

tests/cpp/operator/test_cast_mxfp8_grouped.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ void performTest(const ProcessingMethod processing_method,
371371

372372
NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size());
373373

374-
std::vector<size_t> dbias_logical_shape_vec= {num_tensors, cols};
374+
std::vector<size_t> dbias_logical_shape_vec = {num_tensors, cols};
375375
NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(),
376376
dbias_logical_shape_vec.size());
377377

transformer_engine/common/cast/core/common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,14 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
100100
const size_t tensor_id = blockIdx.y;
101101
const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
102102
? (first_logical_dim / num_tensors)
103-
: first_dims_ptr[tensor_id];
103+
: static_cast<size_t>(first_dims_ptr[tensor_id]);
104104

105105
const size_t rows = tensor_rows / chunk_dim_Y;
106106
const size_t cols = last_logical_dim;
107107

108108
const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
109109
? (tensor_id * (tensor_rows / chunk_dim_Y))
110-
: (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
110+
: (static_cast<size_t>(offsets_ptr[tensor_id]) / cols / chunk_dim_Y);
111111

112112
const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
113113

transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ __device__ __forceinline__ size_t get_tensor_cols_num(
142142
case ShapeRepresentation::VARYING_LAST_DIM:
143143
case ShapeRepresentation::VARYING_BOTH_DIMS:
144144
cols_num = static_cast<size_t>(last_dims_ptr[tensor_id]);
145+
if (cols_num % 128 != 0) {
146+
NVTE_DEVICE_ERROR("For non-single tensors, the last dimension of each tensor in a group "
147+
"must be divisible by 128.");
148+
}
145149
break;
146150
}
147151
return cols_num;
@@ -215,7 +219,8 @@ decode_block(const JobDescriptor &job, const bool is_single_tensor,
215219
const int64_t *const __restrict__ offsets_ptr) {
216220
BlockDescriptor block{};
217221
block.tensor_base = is_single_tensor ? 0 : static_cast<size_t>(offsets_ptr[job.tensor_id]);
218-
const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast<size_t>(128));
222+
const size_t CHUNK_DIM_X_ = CHUNK_DIM_X;
223+
const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_);
219224
block.block_id_in_current_tensor =
220225
is_single_tensor ? job.block_id : (job.block_id - block.tensor_base / ELTS_PER_CHUNK);
221226
block.block_id_Y = block.block_id_in_current_tensor / blocks_X_num_in_current_tensor;

0 commit comments

Comments (0)