Skip to content

Commit 2c4e34c

Browse files
authored
Change snmg index to use the updated multi-GPU resource API (rapidsai#869)
These changes depend on [this breaking PR](rapidsai/raft#2647) in raft. This PR does not introduce any new features; it only updates the existing code to work with the breaking API changes made in that raft PR. Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Victor Lafargue (https://github.com/viclafargue) - Corey J. Nolet (https://github.com/cjnolet) URL: rapidsai#869
1 parent 079165d commit 2c4e34c

1 file changed

Lines changed: 18 additions & 17 deletions

File tree

cpp/src/neighbors/mg/snmg.cuh

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
#pragma once
1818

1919
#include "../detail/knn_merge_parts.cuh"
20-
#include <raft/core/resource/nccl_clique.hpp>
20+
#include <raft/core/resource/multi_gpu.hpp>
21+
#include <raft/core/resource/nccl_comm.hpp>
2122
#include <raft/core/serialize.hpp>
2223
#include <raft/linalg/add.cuh>
2324
#include <raft/util/cuda_dev_essentials.cuh>
@@ -75,10 +76,10 @@ void deserialize(const raft::resources& clique,
7576
index.mode_ = (cuvs::neighbors::distribution_mode)deserialize_scalar<int>(handle, is);
7677
index.num_ranks_ = deserialize_scalar<int>(handle, is);
7778

78-
if (index.num_ranks_ != raft::resource::get_nccl_num_ranks(clique)) {
79+
if (index.num_ranks_ != raft::resource::get_num_ranks(clique)) {
7980
RAFT_FAIL("Serialized index has %d ranks whereas NCCL clique has %d ranks",
8081
index.num_ranks_,
81-
raft::resource::get_nccl_num_ranks(clique));
82+
raft::resource::get_num_ranks(clique));
8283
}
8384

8485
for (int rank = 0; rank < index.num_ranks_; rank++) {
@@ -215,8 +216,8 @@ void sharded_search_with_direct_merge(const raft::resources& clique,
215216
const raft::resources& dev_res = raft::resource::set_current_device_to_rank(clique, rank);
216217
auto& ann_if = index.ann_interfaces_[rank];
217218

218-
if (rank == raft::resource::get_nccl_clique_root_rank(clique)) { // root rank
219-
uint64_t batch_offset = raft::resource::get_nccl_clique_root_rank(clique) * part_size;
219+
if (rank == raft::resource::get_root_rank(clique)) { // root rank
220+
uint64_t batch_offset = raft::resource::get_root_rank(clique) * part_size;
220221
auto d_neighbors = raft::make_device_matrix_view<IdxT, int64_t, row_major>(
221222
in_neighbors.data_handle() + batch_offset, n_rows_of_current_batch, n_neighbors);
222223
auto d_distances = raft::make_device_matrix_view<float, int64_t, row_major>(
@@ -227,20 +228,20 @@ void sharded_search_with_direct_merge(const raft::resources& clique,
227228
// wait for other ranks
228229
ncclGroupStart();
229230
for (int from_rank = 0; from_rank < index.num_ranks_; from_rank++) {
230-
if (from_rank == raft::resource::get_nccl_clique_root_rank(clique)) continue;
231+
if (from_rank == raft::resource::get_root_rank(clique)) continue;
231232

232233
batch_offset = from_rank * part_size;
233234
ncclRecv(in_neighbors.data_handle() + batch_offset,
234235
part_size * sizeof(IdxT),
235236
ncclUint8,
236237
from_rank,
237-
raft::resource::get_nccl_comm(dev_res),
238+
raft::resource::get_nccl_comm_for_rank(clique, rank),
238239
raft::resource::get_cuda_stream(dev_res));
239240
ncclRecv(in_distances.data_handle() + batch_offset,
240241
part_size * sizeof(float),
241242
ncclUint8,
242243
from_rank,
243-
raft::resource::get_nccl_comm(dev_res),
244+
raft::resource::get_nccl_comm_for_rank(clique, rank),
244245
raft::resource::get_cuda_stream(dev_res));
245246
}
246247
ncclGroupEnd();
@@ -258,14 +259,14 @@ void sharded_search_with_direct_merge(const raft::resources& clique,
258259
ncclSend(d_neighbors.data_handle(),
259260
part_size * sizeof(IdxT),
260261
ncclUint8,
261-
raft::resource::get_nccl_clique_root_rank(clique),
262-
raft::resource::get_nccl_comm(dev_res),
262+
raft::resource::get_root_rank(clique),
263+
raft::resource::get_nccl_comm_for_rank(clique, rank),
263264
raft::resource::get_cuda_stream(dev_res));
264265
ncclSend(d_distances.data_handle(),
265266
part_size * sizeof(float),
266267
ncclUint8,
267-
raft::resource::get_nccl_clique_root_rank(clique),
268-
raft::resource::get_nccl_comm(dev_res),
268+
raft::resource::get_root_rank(clique),
269+
raft::resource::get_nccl_comm_for_rank(clique, rank),
269270
raft::resource::get_cuda_stream(dev_res));
270271
ncclGroupEnd();
271272
resource::sync_stream(dev_res);
@@ -379,13 +380,13 @@ void sharded_search_with_tree_merge(const raft::resources& clique,
379380
part_size * sizeof(IdxT),
380381
ncclUint8,
381382
other_id,
382-
raft::resource::get_nccl_comm(dev_res),
383+
raft::resource::get_nccl_comm_for_rank(clique, rank),
383384
raft::resource::get_cuda_stream(dev_res));
384385
ncclRecv(tmp_distances.data_handle() + part_size,
385386
part_size * sizeof(float),
386387
ncclUint8,
387388
other_id,
388-
raft::resource::get_nccl_comm(dev_res),
389+
raft::resource::get_nccl_comm_for_rank(clique, rank),
389390
raft::resource::get_cuda_stream(dev_res));
390391
received_something = true;
391392
}
@@ -396,13 +397,13 @@ void sharded_search_with_tree_merge(const raft::resources& clique,
396397
part_size * sizeof(IdxT),
397398
ncclUint8,
398399
other_id,
399-
raft::resource::get_nccl_comm(dev_res),
400+
raft::resource::get_nccl_comm_for_rank(clique, rank),
400401
raft::resource::get_cuda_stream(dev_res));
401402
ncclSend(tmp_distances.data_handle(),
402403
part_size * sizeof(float),
403404
ncclUint8,
404405
other_id,
405-
raft::resource::get_nccl_comm(dev_res),
406+
raft::resource::get_nccl_comm_for_rank(clique, rank),
406407
raft::resource::get_cuda_stream(dev_res));
407408
}
408409
ncclGroupEnd();
@@ -655,7 +656,7 @@ template <typename AnnIndexType, typename T, typename IdxT>
655656
mg_index<AnnIndexType, T, IdxT>::mg_index(const raft::resources& clique, distribution_mode mode)
656657
: mode_(mode), round_robin_counter_(std::make_shared<std::atomic<int64_t>>(0))
657658
{
658-
num_ranks_ = raft::resource::get_nccl_num_ranks(clique);
659+
num_ranks_ = raft::resource::get_num_ranks(clique);
659660
}
660661

661662
template <typename AnnIndexType, typename T, typename IdxT>

0 commit comments

Comments
 (0)