@@ -89,30 +89,25 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
8989
9090template <int nvec, typename OType>
9191__global__ void __launch_bounds__ (THREADS_PER_BLOCK)
92- group_reduce_dbias_kernel(const ShapeRepresentation shape_rep,
93- const size_t num_tensors,
94- const size_t first_logical_dim,
95- const size_t last_logical_dim,
96- const int64_t *const offsets_ptr,
97- const int64_t *const first_dims_ptr,
98- const int64_t *const last_dims_ptr,
99- OType *const dbias_output,
100- const float *dbias_partial,
101- const size_t chunk_dim_Y) {
92+ group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, const size_t num_tensors,
93+ const size_t first_logical_dim, const size_t last_logical_dim,
94+ const int64_t *const offsets_ptr, const int64_t *const first_dims_ptr,
95+ const int64_t *const last_dims_ptr, OType *const dbias_output,
96+ const float *dbias_partial, const size_t chunk_dim_Y) {
10297 using ComputeVec = Vec<float , nvec>;
10398 using OutputVec = Vec<OType, nvec>;
10499
105100 const size_t tensor_id = blockIdx .y ;
106101 const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
107- ? (first_logical_dim / num_tensors)
108- : first_dims_ptr[tensor_id];
109-
102+ ? (first_logical_dim / num_tensors)
103+ : first_dims_ptr[tensor_id];
104+
110105 const size_t rows = tensor_rows / chunk_dim_Y;
111106 const size_t cols = last_logical_dim;
112107
113108 const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
114- ? (tensor_id * (tensor_rows / chunk_dim_Y))
115- : (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
109+ ? (tensor_id * (tensor_rows / chunk_dim_Y))
110+ : (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
116111
117112 const size_t thread_id = blockIdx .x * blockDim .x + threadIdx .x ;
118113
@@ -160,16 +155,12 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
160155}
161156
162157template <typename IType>
163- void grouped_reduce_dbias (const ShapeRepresentation shape_rep,
164- const size_t num_tensors,
165- const size_t first_logical_dim,
166- const size_t last_logical_dim,
158+ void grouped_reduce_dbias (const ShapeRepresentation shape_rep, const size_t num_tensors,
159+ const size_t first_logical_dim, const size_t last_logical_dim,
167160 const int64_t *const data_tensor_offsets_ptr,
168161 const int64_t *const data_tensor_first_dims_ptr,
169- const int64_t *const data_tensor_last_dims_ptr,
170- GroupedTensor *dbias,
171- const float *workspace_ptr,
172- const size_t chunk_dim_Y,
162+ const int64_t *const data_tensor_last_dims_ptr, GroupedTensor *dbias,
163+ const float *workspace_ptr, const size_t chunk_dim_Y,
173164 cudaStream_t stream) {
174165 using namespace kernel ;
175166 constexpr size_t reduce_dbias_store_bytes = 8 ; // stg.64
@@ -181,11 +172,10 @@ void grouped_reduce_dbias(const ShapeRepresentation shape_rep,
181172 const size_t blocks_Y = num_tensors;
182173 const dim3 grid (blocks_X, blocks_Y);
183174
184- group_reduce_dbias_kernel<reduce_dbias_nvec, IType>
185- <<<grid, THREADS_PER_BLOCK, 0 , stream>>> (
186- shape_rep, num_tensors, first_logical_dim, last_logical_dim,
187- data_tensor_offsets_ptr, data_tensor_first_dims_ptr, data_tensor_last_dims_ptr,
188- reinterpret_cast <IType *>(dbias->data .dptr ), workspace_ptr, chunk_dim_Y);
175+ group_reduce_dbias_kernel<reduce_dbias_nvec, IType><<<grid, THREADS_PER_BLOCK, 0 , stream>>> (
176+ shape_rep, num_tensors, first_logical_dim, last_logical_dim, data_tensor_offsets_ptr,
177+ data_tensor_first_dims_ptr, data_tensor_last_dims_ptr,
178+ reinterpret_cast <IType *>(dbias->data .dptr ), workspace_ptr, chunk_dim_Y);
189179
190180 NVTE_CHECK_CUDA (cudaGetLastError ());
191181}
0 commit comments