Skip to content

Commit 4881d1b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent aff53ff commit 4881d1b

4 files changed

Lines changed: 34 additions & 39 deletions

File tree

tests/cpp/operator/test_cast_mxfp8_grouped.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ void performTest(const ProcessingMethod processing_method,
288288
rowwise_sfs_num += rowwise_sfs;
289289
colwise_sfs_num += colwise_sfs;
290290
sum_of_last_dims += K;
291-
291+
292292
rowwise_scales_offset[t+1] = rowwise_sfs_num;
293293
colwise_scales_offset[t+1] = colwise_sfs_num;
294294
dbias_offsets[t+1] = sum_of_last_dims;
@@ -370,7 +370,7 @@ void performTest(const ProcessingMethod processing_method,
370370
cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice);
371371

372372
NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size());
373-
373+
374374
std::vector<size_t> dbias_logical_shape_vec= {num_tensors, cols};
375375
NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(),
376376
dbias_logical_shape_vec.size());

transformer_engine/common/cast/core/common.cuh

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -89,30 +89,25 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
8989

9090
template <int nvec, typename OType>
9191
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
92-
group_reduce_dbias_kernel(const ShapeRepresentation shape_rep,
93-
const size_t num_tensors,
94-
const size_t first_logical_dim,
95-
const size_t last_logical_dim,
96-
const int64_t *const offsets_ptr,
97-
const int64_t *const first_dims_ptr,
98-
const int64_t *const last_dims_ptr,
99-
OType *const dbias_output,
100-
const float *dbias_partial,
101-
const size_t chunk_dim_Y) {
92+
group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, const size_t num_tensors,
93+
const size_t first_logical_dim, const size_t last_logical_dim,
94+
const int64_t *const offsets_ptr, const int64_t *const first_dims_ptr,
95+
const int64_t *const last_dims_ptr, OType *const dbias_output,
96+
const float *dbias_partial, const size_t chunk_dim_Y) {
10297
using ComputeVec = Vec<float, nvec>;
10398
using OutputVec = Vec<OType, nvec>;
10499

105100
const size_t tensor_id = blockIdx.y;
106101
const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
107-
? (first_logical_dim / num_tensors)
108-
: first_dims_ptr[tensor_id];
109-
102+
? (first_logical_dim / num_tensors)
103+
: first_dims_ptr[tensor_id];
104+
110105
const size_t rows = tensor_rows / chunk_dim_Y;
111106
const size_t cols = last_logical_dim;
112107

113108
const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
114-
? (tensor_id * (tensor_rows / chunk_dim_Y))
115-
: (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
109+
? (tensor_id * (tensor_rows / chunk_dim_Y))
110+
: (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
116111

117112
const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
118113

@@ -160,16 +155,12 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
160155
}
161156

162157
template <typename IType>
163-
void grouped_reduce_dbias(const ShapeRepresentation shape_rep,
164-
const size_t num_tensors,
165-
const size_t first_logical_dim,
166-
const size_t last_logical_dim,
158+
void grouped_reduce_dbias(const ShapeRepresentation shape_rep, const size_t num_tensors,
159+
const size_t first_logical_dim, const size_t last_logical_dim,
167160
const int64_t *const data_tensor_offsets_ptr,
168161
const int64_t *const data_tensor_first_dims_ptr,
169-
const int64_t *const data_tensor_last_dims_ptr,
170-
GroupedTensor *dbias,
171-
const float *workspace_ptr,
172-
const size_t chunk_dim_Y,
162+
const int64_t *const data_tensor_last_dims_ptr, GroupedTensor *dbias,
163+
const float *workspace_ptr, const size_t chunk_dim_Y,
173164
cudaStream_t stream) {
174165
using namespace kernel;
175166
constexpr size_t reduce_dbias_store_bytes = 8; // stg.64
@@ -181,11 +172,10 @@ void grouped_reduce_dbias(const ShapeRepresentation shape_rep,
181172
const size_t blocks_Y = num_tensors;
182173
const dim3 grid(blocks_X, blocks_Y);
183174

184-
group_reduce_dbias_kernel<reduce_dbias_nvec, IType>
185-
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
186-
shape_rep, num_tensors, first_logical_dim, last_logical_dim,
187-
data_tensor_offsets_ptr, data_tensor_first_dims_ptr, data_tensor_last_dims_ptr,
188-
reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, chunk_dim_Y);
175+
group_reduce_dbias_kernel<reduce_dbias_nvec, IType><<<grid, THREADS_PER_BLOCK, 0, stream>>>(
176+
shape_rep, num_tensors, first_logical_dim, last_logical_dim, data_tensor_offsets_ptr,
177+
data_tensor_first_dims_ptr, data_tensor_last_dims_ptr,
178+
reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, chunk_dim_Y);
189179

190180
NVTE_CHECK_CUDA(cudaGetLastError());
191181
}

transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te
144144
NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned");
145145
}
146146
if (global_dim_X % CHUNK_DIM_X != 0) {
147-
NVTE_DEVICE_ERROR("The grouped tensor must be divisible by 128x128 tiles without a tail tile.");
147+
NVTE_DEVICE_ERROR(
148+
"The grouped tensor must be divisible by 128x128 tiles without a tail tile.");
148149
}
149150

150151
asm volatile(
@@ -941,9 +942,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
941942

942943
if constexpr (IS_DBIAS) {
943944
common::grouped_reduce_dbias<IType>(
944-
shape_rep, num_tensors, first_logical_dim, last_logical_dim,
945-
offsets_ptr, first_dims_ptr, last_dims_ptr,
946-
dbias, workspace_ptr, CHUNK_DIM_Y, stream);
945+
shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr,
946+
first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream);
947947
}
948948

949949
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)

transformer_engine/common/include/transformer_engine/cast.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu
207207
*/
208208
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
209209
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
210-
NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream);
210+
NVTEGroupedTensor dbias, NVTETensor workspace,
211+
cudaStream_t stream);
211212

212213
/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8.
213214
* Additionally, reduces the result of the SiLU backward along columns.
@@ -253,7 +254,8 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu
253254
*/
254255
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
255256
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
256-
NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream);
257+
NVTEGroupedTensor dbias, NVTETensor workspace,
258+
cudaStream_t stream);
257259

258260
/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8.
259261
* Additionally, reduces the result of the ReLU backward along columns.
@@ -299,7 +301,8 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu
299301
*/
300302
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
301303
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
302-
NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream);
304+
NVTEGroupedTensor dbias, NVTETensor workspace,
305+
cudaStream_t stream);
303306

304307
/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8.
305308
* Additionally, reduces the result of the Quick GeLU backward along columns.
@@ -345,7 +348,8 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp
345348
*/
346349
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
347350
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
348-
NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream);
351+
NVTEGroupedTensor dbias, NVTETensor workspace,
352+
cudaStream_t stream);
349353

350354
/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8.
351355
* Additionally, reduces the result of the Squared ReLU backward along columns.
@@ -391,7 +395,8 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp
391395
*/
392396
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
393397
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
394-
NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream);
398+
NVTEGroupedTensor dbias, NVTETensor workspace,
399+
cudaStream_t stream);
395400

396401
/*! \brief Casts input tensor from reduced to higher precision.
397402
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING,

0 commit comments

Comments (0)