1717#pragma once
1818
1919#include " ../detail/knn_merge_parts.cuh"
20- #include < raft/core/resource/nccl_clique.hpp>
20+ #include < raft/core/resource/multi_gpu.hpp>
21+ #include < raft/core/resource/nccl_comm.hpp>
2122#include < raft/core/serialize.hpp>
2223#include < raft/linalg/add.cuh>
2324#include < raft/util/cuda_dev_essentials.cuh>
@@ -75,10 +76,10 @@ void deserialize(const raft::resources& clique,
7576 index.mode_ = (cuvs::neighbors::distribution_mode)deserialize_scalar<int >(handle, is);
7677 index.num_ranks_ = deserialize_scalar<int >(handle, is);
7778
78- if (index.num_ranks_ != raft::resource::get_nccl_num_ranks (clique)) {
79+ if (index.num_ranks_ != raft::resource::get_num_ranks (clique)) {
7980 RAFT_FAIL (" Serialized index has %d ranks whereas NCCL clique has %d ranks" ,
8081 index.num_ranks_ ,
81- raft::resource::get_nccl_num_ranks (clique));
82+ raft::resource::get_num_ranks (clique));
8283 }
8384
8485 for (int rank = 0 ; rank < index.num_ranks_ ; rank++) {
@@ -215,8 +216,8 @@ void sharded_search_with_direct_merge(const raft::resources& clique,
215216 const raft::resources& dev_res = raft::resource::set_current_device_to_rank (clique, rank);
216217 auto & ann_if = index.ann_interfaces_ [rank];
217218
218- if (rank == raft::resource::get_nccl_clique_root_rank (clique)) { // root rank
219- uint64_t batch_offset = raft::resource::get_nccl_clique_root_rank (clique) * part_size;
219+ if (rank == raft::resource::get_root_rank (clique)) { // root rank
220+ uint64_t batch_offset = raft::resource::get_root_rank (clique) * part_size;
220221 auto d_neighbors = raft::make_device_matrix_view<IdxT, int64_t , row_major>(
221222 in_neighbors.data_handle () + batch_offset, n_rows_of_current_batch, n_neighbors);
222223 auto d_distances = raft::make_device_matrix_view<float , int64_t , row_major>(
@@ -227,20 +228,20 @@ void sharded_search_with_direct_merge(const raft::resources& clique,
227228 // wait for other ranks
228229 ncclGroupStart ();
229230 for (int from_rank = 0 ; from_rank < index.num_ranks_ ; from_rank++) {
230- if (from_rank == raft::resource::get_nccl_clique_root_rank (clique)) continue ;
231+ if (from_rank == raft::resource::get_root_rank (clique)) continue ;
231232
232233 batch_offset = from_rank * part_size;
233234 ncclRecv (in_neighbors.data_handle () + batch_offset,
234235 part_size * sizeof (IdxT),
235236 ncclUint8,
236237 from_rank,
237- raft::resource::get_nccl_comm (dev_res ),
238+ raft::resource::get_nccl_comm_for_rank (clique, rank ),
238239 raft::resource::get_cuda_stream (dev_res));
239240 ncclRecv (in_distances.data_handle () + batch_offset,
240241 part_size * sizeof (float ),
241242 ncclUint8,
242243 from_rank,
243- raft::resource::get_nccl_comm (dev_res ),
244+ raft::resource::get_nccl_comm_for_rank (clique, rank ),
244245 raft::resource::get_cuda_stream (dev_res));
245246 }
246247 ncclGroupEnd ();
@@ -258,14 +259,14 @@ void sharded_search_with_direct_merge(const raft::resources& clique,
258259 ncclSend (d_neighbors.data_handle (),
259260 part_size * sizeof (IdxT),
260261 ncclUint8,
261- raft::resource::get_nccl_clique_root_rank (clique),
262- raft::resource::get_nccl_comm (dev_res ),
262+ raft::resource::get_root_rank (clique),
263+ raft::resource::get_nccl_comm_for_rank (clique, rank ),
263264 raft::resource::get_cuda_stream (dev_res));
264265 ncclSend (d_distances.data_handle (),
265266 part_size * sizeof (float ),
266267 ncclUint8,
267- raft::resource::get_nccl_clique_root_rank (clique),
268- raft::resource::get_nccl_comm (dev_res ),
268+ raft::resource::get_root_rank (clique),
269+ raft::resource::get_nccl_comm_for_rank (clique, rank ),
269270 raft::resource::get_cuda_stream (dev_res));
270271 ncclGroupEnd ();
271272 resource::sync_stream (dev_res);
@@ -379,13 +380,13 @@ void sharded_search_with_tree_merge(const raft::resources& clique,
379380 part_size * sizeof (IdxT),
380381 ncclUint8,
381382 other_id,
382- raft::resource::get_nccl_comm (dev_res ),
383+ raft::resource::get_nccl_comm_for_rank (clique, rank ),
383384 raft::resource::get_cuda_stream (dev_res));
384385 ncclRecv (tmp_distances.data_handle () + part_size,
385386 part_size * sizeof (float ),
386387 ncclUint8,
387388 other_id,
388- raft::resource::get_nccl_comm (dev_res ),
389+ raft::resource::get_nccl_comm_for_rank (clique, rank ),
389390 raft::resource::get_cuda_stream (dev_res));
390391 received_something = true ;
391392 }
@@ -396,13 +397,13 @@ void sharded_search_with_tree_merge(const raft::resources& clique,
396397 part_size * sizeof (IdxT),
397398 ncclUint8,
398399 other_id,
399- raft::resource::get_nccl_comm (dev_res ),
400+ raft::resource::get_nccl_comm_for_rank (clique, rank ),
400401 raft::resource::get_cuda_stream (dev_res));
401402 ncclSend (tmp_distances.data_handle (),
402403 part_size * sizeof (float ),
403404 ncclUint8,
404405 other_id,
405- raft::resource::get_nccl_comm (dev_res ),
406+ raft::resource::get_nccl_comm_for_rank (clique, rank ),
406407 raft::resource::get_cuda_stream (dev_res));
407408 }
408409 ncclGroupEnd ();
@@ -655,7 +656,7 @@ template <typename AnnIndexType, typename T, typename IdxT>
655656mg_index<AnnIndexType, T, IdxT>::mg_index(const raft::resources& clique, distribution_mode mode)
656657 : mode_(mode), round_robin_counter_(std::make_shared<std::atomic<int64_t >>(0 ))
657658{
658- num_ranks_ = raft::resource::get_nccl_num_ranks (clique);
659+ num_ranks_ = raft::resource::get_num_ranks (clique);
659660}
660661
661662template <typename AnnIndexType, typename T, typename IdxT>
0 commit comments