Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/rocm/sparse_group_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include <cstdint>
#include <limits>

#include "fbgemm_gpu/utils/cuda_prelude.cuh"

namespace fbgemm_gpu::rocm {
namespace {

// Warp-cooperative upper_bound: the lanes of one "logical" warp scan `arr`
// (length `num_entries`) in chunks of kLogicalWarpSize and locate the first
// index whose value is strictly greater than `target`.
//
// NOTE(review): the early `break` on the first chunk containing a qualifying
// entry only yields a true upper bound if `arr` is sorted ascending — TODO
// confirm with callers (it is invoked on `warp_offsets_group`, which looks
// like a prefix-offset array; verify).
//
// Outputs (computed identically by every participating lane):
//   *found           - index of the first entry with arr[idx] > target,
//                      or -1 if no such entry exists.
//   *cached_boundary - arr[*found] when found; otherwise left at the value
//                      it held on entry (it is read before the scan).
//
// A logical warp of kLogicalWarpSize lanes is carved out of the hardware
// wavefront (kWarpSize lanes). mask_t is taken from __activemask()'s return
// type, so the bit arithmetic works for either 32- or 64-bit lane masks.
template <typename scalar_t, int kLogicalWarpSize = kWarpSize>
__device__ __forceinline__ void warp_upper_bound(
    int* found,
    scalar_t* cached_boundary,
    const scalar_t* arr,
    const scalar_t target,
    const int num_entries) {
  // Mask of lanes in this hardware wavefront that are currently executing.
  // NOTE(review): __activemask() reports only currently-converged lanes; on
  // ROCm, wavefronts execute in lockstep so this is conventional there, but
  // confirm all intended lanes are converged at every call site.
  const auto active_mask = __activemask();
  using mask_t = std::remove_const_t<decltype(active_mask)>;

  constexpr int kHardwareWarpSize = kWarpSize;
  // Bit width of the ballot/active mask (32 or 64 depending on platform).
  constexpr int kMaskBits = sizeof(mask_t) * 8;

  const int hardware_lane = __lane_id();
  // This lane's position inside its logical warp, and which logical warp
  // (of size kLogicalWarpSize) within the hardware wavefront it belongs to.
  const int logical_lane = hardware_lane % kLogicalWarpSize;
  const int logical_warp_id = hardware_lane / kLogicalWarpSize;

  // Build the participation mask restricted to this logical warp's lanes.
  mask_t logical_mask = mask_t(0);
  if constexpr (kLogicalWarpSize >= kMaskBits) {
    // The logical warp spans the whole mask; a shift by >= kMaskBits would
    // be undefined behavior, so use the active mask directly.
    logical_mask = active_mask;
  } else {
    // kLogicalWarpSize low bits, shifted up to this logical warp's lane range.
    const mask_t group_bits = (mask_t(1) << kLogicalWarpSize) - 1;
    const mask_t group_mask = group_bits << (logical_warp_id * kLogicalWarpSize);
    logical_mask = group_mask & active_mask;
  }
  // Defensive fallback: an executing lane's own bit is set in active_mask, so
  // logical_mask should never be empty here in practice.
  if (!logical_mask) {
    logical_mask = active_mask;
  }

  int result = -1;
  // Seed with the caller's previous boundary so it is preserved on a miss.
  scalar_t cached_result = *cached_boundary;
  // Scan arr in chunks of kLogicalWarpSize, one element per logical lane.
  for (int base = 0; base < num_entries; base += kLogicalWarpSize) {
    const int idx = base + logical_lane;
    const bool valid = idx < num_entries;
    const scalar_t val = valid ? arr[idx] : scalar_t(0);
    // One bit per hardware lane whose element strictly exceeds target;
    // out-of-range lanes contribute 0 via the `valid` predicate.
    const mask_t ballot = __ballot_sync(logical_mask, valid && val > target);
    // Keep only the bits belonging to this logical warp.
    const mask_t logical_ballot = ballot & logical_mask;
    if (logical_ballot) {
      // __ffsll is 1-based, so -1 yields the hardware lane of the lowest set
      // bit; subtract the group's base lane to get the logical lane offset.
      const int first_lane_hw = __ffsll(static_cast<long long>(logical_ballot)) - 1;
      const int first_lane = first_lane_hw - logical_warp_id * kLogicalWarpSize;
      result = base + first_lane;
      cached_result = arr[result];
      // Every lane of the logical warp sees the same logical_ballot, so this
      // break is uniform within the logical warp.
      break;
    }
  }

  // Publish: index of the first entry > target (or -1 if none), plus the
  // boundary value cached for the caller's next query.
  *found = result;
  *cached_boundary = cached_result;
}
} // namespace
} // namespace fbgemm_gpu::rocm
19 changes: 19 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_group_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
*/

#include "common.cuh"
#ifdef USE_ROCM
#include "fbgemm_gpu/utils/rocm/sparse_group_utils.h"
#endif

using Tensor = at::Tensor;

Expand Down Expand Up @@ -63,12 +66,27 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
}

[[maybe_unused]] int cached_member_id = -1;
[[maybe_unused]] int64_t cached_upper_bound = -1;
for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
warp_id < total_num_warps;
warp_id += gridDim.x * blockDim.y) {
int32_t member_id = 0;
int32_t member_warp_id = 0;
if constexpr (USE_VAR_COLS) {
#ifdef USE_ROCM
if (warp_id >= cached_upper_bound) {
rocm::warp_upper_bound<int64_t, EMULATED_WARP_SIZE>(
&member_id,
&cached_upper_bound,
warp_offsets_group + 1,
warp_id,
group_size);
cached_member_id = member_id;
} else {
member_id = cached_member_id;
}
#else
__shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
if (threadIdx.x == 0) {
binary_search_range(
Expand All @@ -79,6 +97,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
}
syncwarp();
member_id = member_ids[threadIdx.y];
#endif
num_cols = num_cols_group[member_id];
warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
member_warp_id = warp_id - warp_offsets_group[member_id];
Expand Down