Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,111 @@ DEVICE_INLINE void store_grad_sum(
*/
#}

{%- if not nobag and not weighted and vbe and not ssd %}
// Accumulate gradient sums for one sorted segment in the specialized
// split / unweighted / VBE (variable batch size embedding) rowwise-adagrad
// backward path.
//
// For each chunk of kFixedMaxVecsPerThread vectors, the thread group walks
// the segment entries sorted_infos[segment_start + sl_start ..
// segment_start + sl_end), one entry per lane per pass.  Each lane decodes
// its packed info word into a (table t, batch b) pair, resolves the VBE
// gradient row via row_output_offsets[B_offsets[t] + b], and accumulates
// VEC_WIDTH-wide gradient vectors into grad_sum.  When smem_grad_sum is
// non-null, each chunk's partial sums are spilled to shared memory
// (vec-blocking mode).
//
// The inner loop is manually unrolled by 8: eight lanes' row offsets are
// broadcast with SHFL_SYNC (which uses shfl_sync_mask) and eight gradient
// vectors are loaded back-to-back, presumably to overlap global-memory
// latency across the unrolled loads.
//
// b_t_pre / boff_pre are software-pipelined values computed by the caller
// so that the first (sl == sl_start && vec_start == 0) iteration does not
// re-read sorted_infos / B_offsets on the critical path.
//
// NOTE(review): D_offsets and T are never referenced in this body —
// presumably kept for signature parity with the generic compute_grad_sum;
// confirm before removing.
// NOTE(review): the unrolled adds all accumulate into grad_sum[0] and use
// d = threadIdx.x * VEC_WIDTH only, which looks like it assumes the row
// fits in one vector per thread (kFixedMaxVecsPerThread == 1 and
// D <= kThreadGroupSize * VEC_WIDTH) — TODO confirm for larger D.
template <
    typename grad_t,
    typename cache_t,
    int32_t kFixedMaxVecsPerThread,
    int32_t kThreadGroupSize = kWarpSize,
    int32_t VEC_WIDTH,
    bool kUseVecBlocking
>
DEVICE_INLINE void compute_grad_sum_unweighted_vbe_rowwise_adagrad(
    Vec4TAcc<cache_t>* grad_sum,       // [out] per-thread partial gradient sums
    Vec4TAcc<cache_t>* smem_grad_sum,  // [out] shared-memory spill area; may be null
    const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>& grad_output,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>& D_offsets,
    const int32_t D,                   // embedding dim of the current table
    const int32_t T,                   // number of tables
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>& sorted_infos,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>& B_offsets,
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>& row_output_offsets,
    const int32_t info_B_num_bits,     // bit position splitting (t, b) in a packed info word
    const uint32_t info_B_mask,        // mask extracting the batch field b
    const int32_t segment_start,       // segment's first index into sorted_infos
    const int32_t sl_start,            // first segment-local entry to process
    const int32_t sl_end,              // one past the last segment-local entry
    const unsigned int shfl_sync_mask, // lane participation mask for SHFL_SYNC
    const int32_t num_vecs,            // vectors per row when kUseVecBlocking
    const int32_t b_t_pre,             // caller-prefetched info word for the first pass
    const int32_t boff_pre             // caller-prefetched B_offsets[t] for the first pass
) {
    // Copy value to vecs to make num_vecs known at compile time when
    // kUseVecBlocking == false
    const int32_t vecs = kUseVecBlocking ? num_vecs : kFixedMaxVecsPerThread;
    for (int32_t vec_start = 0;
         vec_start < vecs;
         vec_start += kFixedMaxVecsPerThread) {

        // Reset grad_sum vectors
        #pragma unroll kFixedMaxVecsPerThread
        for (int32_t vec = 0; vec < kFixedMaxVecsPerThread; vec++) {
            grad_sum[vec].acc.x = 0;
            grad_sum[vec].acc.y = 0;
            grad_sum[vec].acc.z = 0;
            grad_sum[vec].acc.w = 0;
        }

        for (int32_t sl = sl_start; sl < sl_end; sl += kThreadGroupSize) {
            auto sl_j = sl + threadIdx.x;
            // Each lane loads the packed info word for its entry; the very
            // first iteration reuses the caller-prefetched b_t_pre instead
            // of re-reading sorted_infos (software pipelining).
            const auto b_t = (sl==sl_start && vec_start==0) ? b_t_pre : (sl_j < sl_end
                ? reinterpret_cast<const uint32_t*>(
                    &sorted_infos[0])[segment_start + sl_j]
                : 0);
            const auto b = b_t & info_B_mask;       // batch index within table t
            const auto t = b_t >> info_B_num_bits;  // table index
            // B_offsets[t] is likewise prefetched (boff_pre) for the first pass.
            const auto boff = (sl == sl_start && vec_start == 0) ? boff_pre: B_offsets[t];
            const auto grad_offset = row_output_offsets[boff + b]; // if vbe // if not nobag
            // This lane's starting column within the gradient row.
            const int32_t d = threadIdx.x * VEC_WIDTH;

            // 8-way unrolled: broadcast 8 lanes' row offsets, then issue the
            // 8 dependent gradient loads together before accumulating.
            for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; j += 8) {
                const auto grad_offset_j0 = SHFL_SYNC(grad_offset, j);
                const auto grad_offset_j1 = SHFL_SYNC(grad_offset, j + 1);
                const auto grad_offset_j2 = SHFL_SYNC(grad_offset, j + 2);
                const auto grad_offset_j3 = SHFL_SYNC(grad_offset, j + 3);
                const auto grad_offset_j4 = SHFL_SYNC(grad_offset, j + 4);
                const auto grad_offset_j5 = SHFL_SYNC(grad_offset, j + 5);
                const auto grad_offset_j6 = SHFL_SYNC(grad_offset, j + 6);
                const auto grad_offset_j7 = SHFL_SYNC(grad_offset, j + 7);
                if (threadIdx.x * VEC_WIDTH < D) {
                    // Entries past sl_end contribute a default-constructed
                    // (zero) vector so the adds below stay branch-free.
                    Vec4TAcc<grad_t> grad_out_vec0 = Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j0 + d]);
                    Vec4TAcc<grad_t> grad_out_vec1 = sl + j + 1 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j1 + d]) : Vec4TAcc<grad_t>();
                    Vec4TAcc<grad_t> grad_out_vec2 = sl + j + 2 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j2 + d]) : Vec4TAcc<grad_t>();
                    Vec4TAcc<grad_t> grad_out_vec3 = sl + j + 3 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j3 + d]) : Vec4TAcc<grad_t>();
                    Vec4TAcc<grad_t> grad_out_vec4 = sl + j + 4 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j4 + d]) : Vec4TAcc<grad_t>();
                    Vec4TAcc<grad_t> grad_out_vec5 = sl + j + 5 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j5 + d]) : Vec4TAcc<grad_t>();
                    Vec4TAcc<grad_t> grad_out_vec6 = sl + j + 6 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j6 + d]) : Vec4TAcc<grad_t>();
                    Vec4TAcc<grad_t> grad_out_vec7 = sl + j + 7 < sl_end ? Vec4TAcc<grad_t>(&grad_output[0][grad_offset_j7 + d]) : Vec4TAcc<grad_t>();
                    // NOTE(review): all eight partials go into grad_sum[0];
                    // see the header note about the single-vector assumption.
                    grad_sum[0].add_(grad_out_vec0);
                    grad_sum[0].add_(grad_out_vec1);
                    grad_sum[0].add_(grad_out_vec2);
                    grad_sum[0].add_(grad_out_vec3);
                    grad_sum[0].add_(grad_out_vec4);
                    grad_sum[0].add_(grad_out_vec5);
                    grad_sum[0].add_(grad_out_vec6);
                    grad_sum[0].add_(grad_out_vec7);

                }
            }
        }

        {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %}

        if (smem_grad_sum) {
            // Store grad_sum in smem_grad_sum
            #pragma unroll kFixedMaxVecsPerThread
            for (int32_t vec = 0;
                 (vec < kFixedMaxVecsPerThread) && {{ d_vec }} * VEC_WIDTH < D;
                 ++vec) {
                const int32_t d_vec = {{ d_vec }};
                smem_grad_sum[d_vec] = grad_sum[vec];
            }
        }
    }
}

{%- endif %}

template <
typename grad_t,
typename cache_t,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
const float weight_decay_base = 1 - learning_rate * weight_decay;
{%- endif %}

{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
const auto run_sum = sorted_linear_indices_run.size(0) < sorted_linear_indices_num_runs[0]
? sorted_linear_indices_run.size(0)
: sorted_linear_indices_num_runs[0];
int64_t linear_index_pre = sorted_linear_indices_run[start_run_id];
int32_t segment_start_pre = sorted_linear_indices_cumulative_run_lengths[start_run_id];
int32_t segment_end_pre = sorted_linear_indices_cumulative_run_lengths[start_run_id + 1];
{%- endif %}

#ifdef FBGEMM_USE_SUBWARP_SHUFFLE
const unsigned int shfl_sync_mask =
((1L << kThreadGroupSize) - 1) <<
Expand All @@ -169,6 +178,25 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
? smem.getPointer() + threadIdx.y * grad_sum_stride
: nullptr;

{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
int32_t segment_start = segment_start_pre;
int32_t segment_end = segment_end_pre;
int64_t linear_index = linear_index_pre;
int32_t SL = segment_end - segment_start;
auto info_0 = reinterpret_cast<const uint32_t*>(&sorted_infos[0])[segment_start_pre];
auto t_0 = info_0 >> info_B_num_bits;
auto weights_placement = static_cast<PlacementType>(weights_placements[t_0]);

auto b_t_pre = threadIdx.x < SL
? reinterpret_cast<const uint32_t*>(&sorted_infos[0])[segment_start + threadIdx.x]
: 0;

auto t = b_t_pre >> info_B_num_bits;
auto boff_pre = B_offsets[t];

for (uint32_t run_id = start_run_id; run_id < run_sum; run_id += gridDim.x * blockDim.y) {
{%- else %}

for (uint32_t run_id = start_run_id;
run_id < sorted_linear_indices_run.size(0) && run_id < sorted_linear_indices_num_runs[0];
run_id += gridDim.x * blockDim.y) {
Expand All @@ -179,12 +207,21 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
const int32_t segment_end =
sorted_linear_indices_cumulative_run_lengths[run_id + 1];
const int32_t SL = segment_end - segment_start;
{%- endif %}


if (SL >= max_segment_length_per_warp) {
continue;
}

{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
if (run_id + gridDim.x * blockDim.y < run_sum) {
linear_index_pre = sorted_linear_indices_run[run_id + gridDim.x * blockDim.y];
segment_start_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim.x * blockDim.y];
segment_end_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim.x * blockDim.y + 1];
}
{%- else %}

// now, each segment corresponds to exactly one table `t` and row in
// that table (`idx`). Thus, we can hoist out some of the book-keeping.
{%- if not nobag %}
Expand All @@ -194,6 +231,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
const auto info_0 = sorted_infos[segment_start];
int32_t t_0 = info_0 % T;
{%- endif %}
{%- endif %}

int64_t hash_size = hash_size_cumsum[t_0];
{%- if not nobag or is_index_select %}
Expand All @@ -219,6 +257,46 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
constexpr int32_t kGroupVecWidth = kThreadGroupSize * VEC_WIDTH;
const int32_t num_vecs = (D + kGroupVecWidth - 1) / kGroupVecWidth;

{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
compute_grad_sum_unweighted_vbe_rowwise_adagrad<
grad_t,
cache_t,
kFixedMaxVecsPerThread,
kThreadGroupSize,
VEC_WIDTH,
kUseVecBlocking>(
grad_sum,
smem_grad_sum,
grad_output,
D_offsets,
D,
T,
sorted_infos,
B_offsets,
row_output_offsets,
info_B_num_bits,
info_B_mask,
segment_start,
sl_start,
sl_end,
shfl_sync_mask,
num_vecs,
b_t_pre,
boff_pre
);
if (run_id + gridDim.x * blockDim.y < run_sum) {
info_0 = reinterpret_cast<const uint32_t*>(&sorted_infos[0])[segment_start_pre];
}

segment_start = segment_start_pre;
segment_end = segment_end_pre;
linear_index = linear_index_pre;
SL = segment_end - segment_start;
b_t_pre = threadIdx.x < SL
? reinterpret_cast<const uint32_t*>(&sorted_infos[0])[segment_start + threadIdx.x]
: 0;
{%- else %}

compute_grad_sum_{{ kdesc }}<
grad_t,
cache_t,
Expand Down Expand Up @@ -256,13 +334,52 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
shfl_sync_mask,
num_vecs
);
{%- endif %}

// Copy value to max_vecs to make max_vecs_per_thread known at compile time
// when kUseVecBlocking == false
const int32_t max_vecs =
kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread;

{%- if not dense and optimizer != "none" %}
{%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %}
vbe_unweighted_split_rowwise_adagrad_table_update_kernel<
emb_t,
cache_t,
kFixedMaxVecsPerThread,
kThreadGroupSize,
VEC_WIDTH,
kUseVecBlocking>(
dev_weights,
uvm_weights,
lxu_cache_weights,
weights_placements,
weights_offsets,
sorted_{{ locs_or_addrs_tensor }},
grad_sum,
smem_grad_sum,
smem_grad_sum, // shared_weight_update_row (reuse smem_grad_sum)
stochastic_rounding,
stochastic_rounding_philox_args,
run_id,
use_uniq_cache_locations
? (run_id - table_unique_indices_offsets[t_0])
: segment_start,
D,
t_0,
idx,
1, // global_weight_decay
shfl_sync_mask,
max_vecs,
weights_placement,
{{ args.split_kernel_arg_names | join(", ") }}
); // if not dense and optimizer != "none"

t_0 = info_0 >> info_B_num_bits;
auto weights_placement = static_cast<PlacementType>(weights_placements[t_0]);
t = b_t_pre >> info_B_num_bits;
boff_pre = B_offsets[t];
{%- else %}
{{ mdesc }}_{{ optimizer }}_table_update_kernel<
emb_t,
cache_t,
Expand Down Expand Up @@ -303,7 +420,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
enable_optimizer_offloading,
{%- endif %}
{{ args.split_kernel_arg_names | join(", ") }}
);
); // if not dense and optimizer != "none"
{%- endif %}
{%- else %}
// Write deduplicated gradient to grad_dev_weights gradient is sparse
// for split_embedding and dense for dense_embedding
Expand All @@ -328,8 +446,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
weights_offset,
idx,
max_vecs
);
{%- endif %} // if not dense and optimizer != "none"
); // if not dense and optimizer != "none"
{%- endif %}
}
}

Expand Down Expand Up @@ -853,4 +971,4 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
#endif
////////////////////////////////////////////////////////////////////////////////
{%- endif %}
// clang-format on
// clang-format on
Loading