Skip to content

Commit 833da5b

Browse files
committed
rm redundant codegen
1 parent f7d39c7 commit 833da5b

2 files changed

Lines changed: 19 additions & 15 deletions

File tree

fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ DEVICE_INLINE void store_grad_sum(
6363
*/
6464
#}
6565

66-
{%- if not nobag and not weighted and vbe %}
66+
{%- if not nobag and not weighted and vbe and not ssd %}
6767
template <
6868
typename grad_t,
6969
typename cache_t,

fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu

Lines changed: 18 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
153153
const float weight_decay_base = 1 - learning_rate * weight_decay;
154154
{%- endif %}
155155

156-
{%- if not nobag and vbe and not weighted and not ssd %}
156+
{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
157157
const auto run_sum = sorted_linear_indices_run.size(0) < sorted_linear_indices_num_runs[0]
158158
? sorted_linear_indices_run.size(0)
159159
: sorted_linear_indices_num_runs[0];
@@ -178,7 +178,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
178178
? smem.getPointer() + threadIdx.y * grad_sum_stride
179179
: nullptr;
180180

181-
{%- if vbe and not weighted and not ssd and not nobag and optimizer == "rowwise_adagrad" and not is_gwd_kernel %}
181+
{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
182182
int32_t segment_start = segment_start_pre;
183183
int32_t segment_end = segment_end_pre;
184184
int64_t linear_index = linear_index_pre;
@@ -196,6 +196,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
196196

197197
for (uint32_t run_id = start_run_id; run_id < run_sum; run_id += gridDim.x * blockDim.y) {
198198
{%- else %}
199+
199200
for (uint32_t run_id = start_run_id;
200201
run_id < sorted_linear_indices_run.size(0) && run_id < sorted_linear_indices_num_runs[0];
201202
run_id += gridDim.x * blockDim.y) {
@@ -208,17 +209,19 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
208209
const int32_t SL = segment_end - segment_start;
209210
{%- endif %}
210211

212+
211213
if (SL >= max_segment_length_per_warp) {
212214
continue;
213215
}
214216

215-
{%- if vbe and not weighted and not ssd and not nobag and optimizer == "rowwise_adagrad" and not is_gwd_kernel %}
217+
{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
216218
if (run_id + gridDim.x * blockDim.y < run_sum) {
217219
linear_index_pre = sorted_linear_indices_run[run_id + gridDim.x * blockDim.y];
218220
segment_start_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim.x * blockDim.y];
219221
segment_end_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim.x * blockDim.y + 1];
220222
}
221223
{%- else %}
224+
222225
// now, each segment corresponds to exactly one table `t` and row in
223226
// that table (`idx`). Thus, we can hoist out some of the book-keeping.
224227
{%- if not nobag %}
@@ -230,8 +233,6 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
230233
{%- endif %}
231234
{%- endif %}
232235

233-
// now, each segment corresponds to exactly one table `t` and row in
234-
// that table (`idx`). Thus, we can hoist out some of the book-keeping.
235236
int64_t hash_size = hash_size_cumsum[t_0];
236237
{%- if not nobag or is_index_select %}
237238
const auto D_start_t0 = D_offsets[t_0];
@@ -245,7 +246,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
245246
const auto grad_stride = permute_output_dim_0_1 ? D_offsets[T] : D;
246247
{%- endif %}
247248
{%- endif %}
248-
int64_t idx = linear_index - hash_size; // the id value or emb index
249+
int64_t idx = linear_index - hash_size;
249250

250251
{{ compute_global_weight_decay(is_gwd_kernel) }}
251252

@@ -256,7 +257,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
256257
constexpr int32_t kGroupVecWidth = kThreadGroupSize * VEC_WIDTH;
257258
const int32_t num_vecs = (D + kGroupVecWidth - 1) / kGroupVecWidth;
258259

259-
{%- if not nobag and not weighted and vbe and optimizer == "rowwise_adagrad" and not is_gwd_kernel and not ssd %}
260+
{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
260261
compute_grad_sum_unweighted_vbe_rowwise_adagrad<
261262
grad_t,
262263
cache_t,
@@ -295,6 +296,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
295296
? reinterpret_cast<const uint32_t*>(&sorted_infos[0])[segment_start + threadIdx.x]
296297
: 0;
297298
{%- else %}
299+
298300
compute_grad_sum_{{ kdesc }}<
299301
grad_t,
300302
cache_t,
@@ -331,7 +333,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
331333
sl_end,
332334
shfl_sync_mask,
333335
num_vecs
334-
);
336+
);
335337
{%- endif %}
336338

337339
// Copy value to max_vecs to make max_vecs_per_thread known at compile time
@@ -340,7 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
340342
kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread;
341343

342344
{%- if not dense and optimizer != "none" %}
343-
{%- if not nobag and not weighted and vbe and optimizer == "rowwise_adagrad" and not is_gwd_kernel and not ssd %}
345+
{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
344346
vbe_unweighted_split_rowwise_adagrad_table_update_kernel<
345347
emb_t,
346348
cache_t,
@@ -360,7 +362,9 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
360362
stochastic_rounding,
361363
stochastic_rounding_philox_args,
362364
run_id,
363-
segment_start,
365+
use_uniq_cache_locations
366+
? (run_id - table_unique_indices_offsets[t_0])
367+
: segment_start,
364368
D,
365369
t_0,
366370
idx,
@@ -369,7 +373,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
369373
max_vecs,
370374
weights_placement,
371375
{{ args.split_kernel_arg_names | join(", ") }}
372-
); // if not dense and optimizer != "none"
376+
); // if not dense and optimizer != "none"
373377

374378
t_0 = info_0 >> info_B_num_bits;
375379
auto weights_placement = static_cast<PlacementType>(weights_placements[t_0]);
@@ -416,7 +420,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
416420
enable_optimizer_offloading,
417421
{%- endif %}
418422
{{ args.split_kernel_arg_names | join(", ") }}
419-
);
423+
); // if not dense and optimizer != "none"
420424
{%- endif %}
421425
{%- else %}
422426
// Write deduplicated gradient to grad_dev_weights gradient is sparse
@@ -442,8 +446,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
442446
weights_offset,
443447
idx,
444448
max_vecs
445-
);
446-
{%- endif %}
449+
); // if not dense and optimizer != "none"
450+
{%- endif %}
447451
}
448452
}
449453

0 commit comments

Comments
 (0)