Skip to content

Commit 39e67f3

Browse files
committed
call functions directly
1 parent e14a119 commit 39e67f3

7 files changed

Lines changed: 55 additions & 39 deletions

File tree

cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,27 @@
88

99
namespace cuvs::neighbors::cagra::detail {
1010

11-
using args_t = typename cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>::args_t;
11+
// Shorthand for the argument-bundle type nested in dataset_descriptor_base_t;
// it is used by the signatures that follow in this file. The enclosing
// namespace is already cuvs::neighbors::cagra::detail, so no qualification
// is needed.
using args_t = typename dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>::args_t;
1212
// Explicit instantiation of compute_distance for this one parameter
// combination. NOTE(review): the @...@ tokens are presumably substituted by
// the build system when this .cu.in template is expanded - confirm.
template __device__ @distance_type@ compute_distance<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@>(
1313
const args_t, @index_type@);
1414

1515
template<>
16-
__device__ @distance_type@ (*compute_distance_ptr<@data_type@, @index_type@, @distance_type@>)(const typename cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>::args_t, @index_type@) =
17-
&cuvs::neighbors::cagra::detail::compute_distance<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@>;
16+
// Explicit specialization of compute_distance_base: evaluates the fully
// parameterized compute_distance for this thread when `valid` is true,
// substitutes 0 otherwise, and returns the result of device::team_sum over
// that per-thread value.
//
// args           - argument bundle forwarded unchanged to compute_distance
// dataset_index  - index forwarded unchanged to compute_distance
// valid          - when false, compute_distance is not called for this thread
// team_size_bits - forwarded unchanged to device::team_sum
__device__ @distance_type@ compute_distance_base<@data_type@, @index_type@, @distance_type@>(
17+
const args_t args, @index_type@ dataset_index, bool valid, uint32_t team_size_bits)
18+
{
19+
// Invalid threads do not return early: they contribute 0 and still reach
// team_sum below. NOTE(review): presumably team_sum is a cooperative
// reduction that every thread of the team has to execute - confirm.
auto per_thread = valid
20+
? compute_distance<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@>(
21+
args, dataset_index)
22+
: 0;
23+
return device::team_sum(per_thread, team_size_bits);
24+
}
25+
26+
// Explicit specialization that forwards straight to the fully parameterized
// compute_distance and returns its value unchanged. No team_sum is applied
// here; the caller shown in device_common_jit.cuh in this commit applies
// device::team_sum itself after choosing its own fallback value.
template<>
27+
__device__ @distance_type@ compute_distance_per_thread_base<@data_type@, @index_type@, @distance_type@>(
28+
const args_t args, @index_type@ dataset_index)
29+
{
30+
return compute_distance<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@>(
31+
args, dataset_index);
32+
}
1833

1934
} // namespace cuvs::neighbors::cagra::detail

cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -51,9 +51,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes_jit(
5151
{
5252
constexpr unsigned warp_size = 32;
5353

54-
// Use team_size_bitshift_from_smem since smem_desc is in shared memory
5554
uint32_t team_size_bits = smem_desc->team_size_bitshift_from_smem();
5655
IndexT dataset_size = smem_desc->size;
56+
const auto args_load = smem_desc->args.load();
5757

5858
const auto max_i = raft::round_up_safe<uint32_t>(num_pickup, warp_size >> team_size_bits);
5959

@@ -63,7 +63,6 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes_jit(
6363
IndexT best_index_team_local = raft::upper_bound<IndexT>();
6464
DistanceT best_norm2_team_local = raft::upper_bound<DistanceT>();
6565
for (uint32_t j = 0; j < num_distilation; j++) {
66-
// Select a node randomly and compute the distance to it
6766
IndexT seed_index = 0;
6867
if (valid_i) {
6968
uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j)));
@@ -74,13 +73,9 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes_jit(
7473
}
7574
}
7675

77-
const auto args_load = smem_desc->args.load();
78-
const auto team_bits = smem_desc->team_size_bitshift_from_smem();
79-
auto per_thread_distances =
80-
valid_i ? (*cuvs::neighbors::cagra::detail::
81-
compute_distance_ptr<DataT, IndexT, DistanceT>)(args_load, seed_index)
82-
: 0;
83-
const auto norm2 = device::team_sum(per_thread_distances, team_bits);
76+
const auto norm2 =
77+
cuvs::neighbors::cagra::detail::compute_distance_base<DataT, IndexT, DistanceT>(
78+
args_load, seed_index, valid_i, team_size_bits);
8479

8580
if (valid_i && (norm2 < best_norm2_team_local)) {
8681
best_norm2_team_local = norm2;
@@ -173,12 +168,12 @@ RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes_jit(
173168
const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position);
174169
const auto child_id = valid_i ? result_child_indices_ptr[j] : invalid_index;
175170

176-
auto per_thread_distances =
171+
const auto per_thread =
177172
(child_id != invalid_index)
178-
? (*cuvs::neighbors::cagra::detail::
179-
compute_distance_ptr<DataT, IndexT, DistanceT>)(args, child_id)
173+
? cuvs::neighbors::cagra::detail::
174+
compute_distance_per_thread_base<DataT, IndexT, DistanceT>(args, child_id)
180175
: (lead_lane ? raft::upper_bound<DistanceT>() : 0);
181-
const DistanceT child_dist = device::team_sum(per_thread_distances, team_size_bits);
176+
const DistanceT child_dist = device::team_sum(per_thread, team_size_bits);
182177
__syncwarp();
183178

184179
// Store the distance

cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ extern __device__ dataset_descriptor_base_t<DataT, IndexT, DistanceT>* setup_wor
2626
uint32_t query_id);
2727

2828
template <typename DataT, typename IndexT, typename DistanceT>
29-
extern __device__ dataset_descriptor_base_t<DataT, IndexT, DistanceT>* (*setup_workspace_ptr)(
29+
// Declaration of the device function setup_workspace_base, which the kernels
// in this commit call directly in place of the former setup_workspace_ptr
// function pointer. It is templated only on DataT/IndexT/DistanceT; the
// definition is an explicit specialization in setup_workspace_kernel.cu.in.
extern __device__ dataset_descriptor_base_t<DataT, IndexT, DistanceT>* setup_workspace_base(
3030
dataset_descriptor_base_t<DataT, IndexT, DistanceT>*, void*, const DataT*, uint32_t);
3131

3232
template <uint32_t TeamSize,
@@ -43,7 +43,14 @@ compute_distance(const typename dataset_descriptor_base_t<DataT, IndexT, Distanc
4343
IndexT dataset_index);
4444

4545
template <typename DataT, typename IndexT, typename DistanceT>
46-
extern __device__ DistanceT (*compute_distance_ptr)(
46+
// Declaration of compute_distance_base, which replaces the former
// compute_distance_ptr function pointer. The specialization in
// compute_distance_kernel.cu.in computes the per-thread distance when `valid`
// is true (0 otherwise) and returns device::team_sum of it for team_size_bits.
extern __device__ DistanceT compute_distance_base(
47+
const typename dataset_descriptor_base_t<DataT, IndexT, DistanceT>::args_t args,
48+
IndexT dataset_index,
49+
bool valid,
50+
uint32_t team_size_bits);
51+
52+
// Declaration of compute_distance_per_thread_base: same argument list as the
// removed compute_distance_ptr. The specialization in
// compute_distance_kernel.cu.in returns the value of compute_distance without
// applying team_sum.
template <typename DataT, typename IndexT, typename DistanceT>
53+
extern __device__ DistanceT compute_distance_per_thread_base(
4754
const typename dataset_descriptor_base_t<DataT, IndexT, DistanceT>::args_t, IndexT);
4855
} // namespace cuvs::neighbors::cagra::detail
4956

cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ __global__ __launch_bounds__(1024, 1) void search_kernel_jit(
9898
uint32_t smem_ws_size_in_bytes = dataset_desc->smem_ws_size_in_bytes();
9999

100100
auto* smem_desc =
101-
(*setup_workspace_ptr<DataT, IndexT, DistanceT>)(dataset_desc, smem, queries_ptr, query_id);
101+
setup_workspace_base<DataT, IndexT, DistanceT>(dataset_desc, smem, queries_ptr, query_id);
102102

103103
auto* __restrict__ result_indices_buffer =
104104
reinterpret_cast<INDEX_T*>(smem + smem_ws_size_in_bytes);

cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh

Lines changed: 9 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -47,10 +47,11 @@ RAFT_KERNEL random_pickup_kernel_jit(
4747
extern __shared__ uint8_t smem[];
4848

4949
auto* smem_desc =
50-
(*setup_workspace_ptr<DataT, IndexT, DistanceT>)(dataset_desc, smem, queries_ptr, query_id);
50+
setup_workspace_base<DataT, IndexT, DistanceT>(dataset_desc, smem, queries_ptr, query_id);
5151
__syncthreads();
5252

53-
IndexT dataset_size = smem_desc->size;
53+
IndexT dataset_size = smem_desc->size;
54+
const auto args_load = smem_desc->args.load();
5455

5556
INDEX_T best_index_team_local;
5657
DISTANCE_T best_norm2_team_local = utils::get_max_value<DISTANCE_T>();
@@ -59,15 +60,11 @@ RAFT_KERNEL random_pickup_kernel_jit(
5960
if (seed_ptr && (global_team_index < num_seeds)) {
6061
seed_index = seed_ptr[global_team_index + (num_seeds * query_id)];
6162
} else {
62-
// Chose a seed node randomly
6363
seed_index = device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_size;
6464
}
6565

66-
const auto args_load = smem_desc->args.load();
67-
const auto team_bits = smem_desc->team_size_bitshift_from_smem();
68-
auto per_thread_distances =
69-
(*compute_distance_ptr<DataT, IndexT, DistanceT>)(args_load, seed_index);
70-
const auto norm2 = device::team_sum(per_thread_distances, team_bits);
66+
const auto norm2 =
67+
compute_distance_base<DataT, IndexT, DistanceT>(args_load, seed_index, true, team_size_bits);
7168

7269
if (norm2 < best_norm2_team_local) {
7370
best_norm2_team_local = norm2;
@@ -125,7 +122,7 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel_jit(
125122

126123
extern __shared__ uint8_t smem[];
127124
auto* smem_desc =
128-
(*setup_workspace_ptr<DataT, IndexT, DistanceT>)(dataset_desc, smem, query_ptr, query_id);
125+
setup_workspace_base<DataT, IndexT, DistanceT>(dataset_desc, smem, query_ptr, query_id);
129126

130127
__syncthreads();
131128
if (global_team_id >= search_width * graph_degree) { return; }
@@ -151,12 +148,9 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel_jit(
151148
const auto compute_distance_flag = hashmap::insert<INDEX_T>(
152149
team_size, visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id);
153150

154-
const auto args = smem_desc->args.load();
155-
auto per_thread_distances =
156-
compute_distance_flag
157-
? (*compute_distance_ptr<DataT, IndexT, DistanceT>)(args, static_cast<INDEX_T>(child_id))
158-
: 0;
159-
DISTANCE_T norm2 = device::team_sum(per_thread_distances, team_size_bits);
151+
const auto args = smem_desc->args.load();
152+
DISTANCE_T norm2 = compute_distance_base<DataT, IndexT, DistanceT>(
153+
args, static_cast<INDEX_T>(child_id), compute_distance_flag, team_size_bits);
160154

161155
if (compute_distance_flag) {
162156
if ((threadIdx.x & (team_size - 1)) == 0) {

cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -121,9 +121,8 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core(
121121
uint32_t dim = dataset_desc->args.dim;
122122
uint32_t smem_ws_size_in_bytes = dataset_desc->smem_ws_size_in_bytes();
123123

124-
// auto* smem_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id);
125124
auto* smem_desc =
126-
(*setup_workspace_ptr<DataT, IndexT, DistanceT>)(dataset_desc, smem, queries_ptr, query_id);
125+
setup_workspace_base<DataT, IndexT, DistanceT>(dataset_desc, smem, queries_ptr, query_id);
127126

128127
auto* __restrict__ result_indices_buffer =
129128
reinterpret_cast<IndexT*>(smem + smem_ws_size_in_bytes);

cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,14 @@ template __device__ cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@d
1212
cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>*, void*, const @data_type@*, uint32_t);
1313

1414
template<>
15-
__device__ cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>* (*setup_workspace_ptr<@data_type@, @index_type@, @distance_type@>)(
16-
cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>*, void*, const @data_type@*, uint32_t) =
17-
&cuvs::neighbors::cagra::detail::setup_workspace<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@>;
15+
// Explicit specialization of setup_workspace_base: a thin wrapper that passes
// all four arguments, in the same order, to the fully parameterized
// setup_workspace and returns the descriptor pointer it yields.
//
// desc     - descriptor pointer forwarded to setup_workspace
// smem     - untyped pointer forwarded to setup_workspace (the visible kernels
//            pass their `extern __shared__` buffer here)
// queries  - query data pointer forwarded to setup_workspace
// query_id - forwarded to setup_workspace
__device__ dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>* setup_workspace_base<@data_type@, @index_type@, @distance_type@>(
16+
dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>* desc,
17+
void* smem,
18+
const @data_type@* queries,
19+
uint32_t query_id)
20+
{
21+
return setup_workspace<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@>(
22+
desc, smem, queries, query_id);
23+
}
1824

1925
} // namespace cuvs::neighbors::cagra::detail

0 commit comments

Comments
 (0)