@@ -153,7 +153,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
153153 const float weight_decay_base = 1 - learning_rate * weight_decay;
154154 {%- endif %}
155155
156- {%- if not nobag and vbe and not weighted and not ssd %}
156+ {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == " rowwise_adagrad " and mdesc == " split " %}
157157 const auto run_sum = sorted_linear_indices_run.size (0 ) < sorted_linear_indices_num_runs[0 ]
158158 ? sorted_linear_indices_run.size (0 )
159159 : sorted_linear_indices_num_runs[0 ];
@@ -178,7 +178,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
178178 ? smem.getPointer () + threadIdx .y * grad_sum_stride
179179 : nullptr ;
180180
181- {%- if vbe and not weighted and not ssd and not nobag and optimizer == " rowwise_adagrad" and not is_gwd_kernel %}
181+ {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == " rowwise_adagrad" and mdesc == " split " %}
182182 int32_t segment_start = segment_start_pre;
183183 int32_t segment_end = segment_end_pre;
184184 int64_t linear_index = linear_index_pre;
@@ -196,6 +196,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
196196
197197 for (uint32_t run_id = start_run_id; run_id < run_sum; run_id += gridDim .x * blockDim .y ) {
198198 {%- else %}
199+
199200 for (uint32_t run_id = start_run_id;
200201 run_id < sorted_linear_indices_run.size (0 ) && run_id < sorted_linear_indices_num_runs[0 ];
201202 run_id += gridDim .x * blockDim .y ) {
@@ -208,17 +209,19 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
208209 const int32_t SL = segment_end - segment_start;
209210 {%- endif %}
210211
212+
211213 if (SL >= max_segment_length_per_warp) {
212214 continue ;
213215 }
214216
215- {%- if vbe and not weighted and not ssd and not nobag and optimizer == " rowwise_adagrad" and not is_gwd_kernel %}
217+ {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == " rowwise_adagrad" and mdesc == " split " %}
216218 if (run_id + gridDim .x * blockDim .y < run_sum) {
217219 linear_index_pre = sorted_linear_indices_run[run_id + gridDim .x * blockDim .y ];
218220 segment_start_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim .x * blockDim .y ];
219221 segment_end_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim .x * blockDim .y + 1 ];
220222 }
221223 {%- else %}
224+
222225 // now, each segment corresponds to exactly one table `t` and row in
223226 // that table (`idx`). Thus, we can hoist out some of the book-keeping.
224227 {%- if not nobag %}
@@ -230,8 +233,6 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
230233 {%- endif %}
231234 {%- endif %}
232235
233- // now, each segment corresponds to exactly one table `t` and row in
234- // that table (`idx`). Thus, we can hoist out some of the book-keeping.
235236 int64_t hash_size = hash_size_cumsum[t_0];
236237 {%- if not nobag or is_index_select %}
237238 const auto D_start_t0 = D_offsets[t_0];
@@ -245,7 +246,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
245246 const auto grad_stride = permute_output_dim_0_1 ? D_offsets[T] : D;
246247 {%- endif %}
247248 {%- endif %}
248- int64_t idx = linear_index - hash_size; // the id value or emb index
249+ int64_t idx = linear_index - hash_size;
249250
250251 {{ compute_global_weight_decay (is_gwd_kernel) }}
251252
@@ -256,7 +257,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
256257 constexpr int32_t kGroupVecWidth = kThreadGroupSize * VEC_WIDTH;
257258 const int32_t num_vecs = (D + kGroupVecWidth - 1 ) / kGroupVecWidth ;
258259
259- {%- if not nobag and not weighted and vbe and optimizer == " rowwise_adagrad" and not is_gwd_kernel and not ssd %}
260+ {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == " rowwise_adagrad" and mdesc == " split " %}
260261 compute_grad_sum_unweighted_vbe_rowwise_adagrad<
261262 grad_t ,
262263 cache_t ,
@@ -295,6 +296,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
295296 ? reinterpret_cast <const uint32_t *>(&sorted_infos[0 ])[segment_start + threadIdx .x ]
296297 : 0 ;
297298 {%- else %}
299+
298300 compute_grad_sum_{{ kdesc }}<
299301 grad_t ,
300302 cache_t ,
@@ -331,7 +333,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
331333 sl_end,
332334 shfl_sync_mask,
333335 num_vecs
334- );
336+ );
335337 {%- endif %}
336338
337339 // Copy value to max_vecs to make max_vecs_per_thread known at compile time
@@ -340,7 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
340342 kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread ;
341343
342344 {%- if not dense and optimizer != " none" %}
343- {%- if not nobag and not weighted and vbe and optimizer == " rowwise_adagrad" and not is_gwd_kernel and not ssd %}
345+ {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == " rowwise_adagrad" and mdesc == " split " %}
344346 vbe_unweighted_split_rowwise_adagrad_table_update_kernel<
345347 emb_t ,
346348 cache_t ,
@@ -360,7 +362,9 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
360362 stochastic_rounding,
361363 stochastic_rounding_philox_args,
362364 run_id,
363- segment_start,
365+ use_uniq_cache_locations
366+ ? (run_id - table_unique_indices_offsets[t_0])
367+ : segment_start,
364368 D,
365369 t_0,
366370 idx,
@@ -369,7 +373,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
369373 max_vecs,
370374 weights_placement,
371375 {{ args.split_kernel_arg_names | join (" , " ) }}
372- ); // if not dense and optimizer != "none"
376+ ); // if not dense and optimizer != "none"
373377
374378 t_0 = info_0 >> info_B_num_bits;
375379 auto weights_placement = static_cast <PlacementType>(weights_placements[t_0]);
@@ -416,7 +420,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
416420 enable_optimizer_offloading,
417421 {%- endif %}
418422 {{ args.split_kernel_arg_names | join (" , " ) }}
419- );
423+ ); // if not dense and optimizer != "none"
420424 {%- endif %}
421425 {%- else %}
422426 // Write deduplicated gradient to grad_dev_weights gradient is sparse
@@ -442,8 +446,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
442446 weights_offset,
443447 idx,
444448 max_vecs
445- );
446- {%- endif %}
449+ ); // if not dense and optimizer != "none"
450+ {%- endif %}
447451 }
448452}
449453
0 commit comments