Open
156 commits
66d7fd3
combine impls
tarang-jain Apr 10, 2026
07707af
Multi-GPU Batched KMeans
viclafargue Apr 13, 2026
efc270f
Merge branch 'main' into mg-batched-kmeans
viclafargue Apr 13, 2026
0a09e6f
rm inertia_check
tarang-jain Apr 13, 2026
99a5730
change to warning
tarang-jain Apr 13, 2026
a077406
style
tarang-jain Apr 13, 2026
d659875
add init_size param
tarang-jain Apr 13, 2026
ec2e8b7
Merge branch 'main' into combine-batch
tarang-jain Apr 13, 2026
03a6473
docs
tarang-jain Apr 13, 2026
42a8d9d
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 13, 2026
86af2fa
rm direct cuda api calls
tarang-jain Apr 13, 2026
d4e4e2c
std::swap instead of raft::copy
tarang-jain Apr 14, 2026
0819af5
cache batch norms
tarang-jain Apr 14, 2026
e0f079c
centroid norms can also be cached per iteration
tarang-jain Apr 14, 2026
c2f7390
mg n_iter
tarang-jain Apr 14, 2026
b9c3102
pre-commit
tarang-jain Apr 14, 2026
e3956c1
do not break c abi
tarang-jain Apr 14, 2026
986d78a
Merge branch 'main' into combine-batch
tarang-jain Apr 14, 2026
7197b71
cluster_cost on device
viclafargue Apr 14, 2026
84ab315
Updated testing
viclafargue Apr 14, 2026
47d4b94
templating
viclafargue Apr 15, 2026
a8e1d26
Merge branch 'main' into combine-batch
tarang-jain Apr 16, 2026
384d054
fix checkWeight
tarang-jain Apr 21, 2026
455b286
merge upstream:
tarang-jain Apr 21, 2026
5462809
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 21, 2026
6ba759c
fix compilation
tarang-jain Apr 21, 2026
e76eaac
rel_tol
tarang-jain Apr 22, 2026
afbefdf
pass workspace
tarang-jain Apr 22, 2026
e62a63c
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 22, 2026
e4f08bf
style
tarang-jain Apr 22, 2026
6e4a8f0
Merge branch 'main' of https://github.com/rapidsai/cuvs into combine-…
tarang-jain Apr 22, 2026
4a8a85c
do not use batch scratch space; rm update_centroids
tarang-jain Apr 22, 2026
bbf2a9f
move the debug log
tarang-jain Apr 22, 2026
410092c
add new suffixed param struct
tarang-jain Apr 22, 2026
c515c1e
address pr reviews
tarang-jain Apr 22, 2026
e8e63ab
fix docstring
tarang-jain Apr 22, 2026
30c457c
fix wt_sum warning
tarang-jain Apr 22, 2026
ab96623
rm deprecationwarning and instead add FutureWarning:=
tarang-jain Apr 22, 2026
269f23c
unweighted to never materialize batch weights
tarang-jain Apr 22, 2026
80a22ca
add cpp tests
tarang-jain Apr 23, 2026
ac06b05
update cpp tests
tarang-jain Apr 23, 2026
855624a
Merge branch 'main' into mg-batched-kmeans
viclafargue Apr 23, 2026
0a6748d
refactor
viclafargue Apr 23, 2026
7055272
rename to mnmg_fit
viclafargue Apr 23, 2026
0569340
revert batch norms cache
tarang-jain Apr 23, 2026
8cac63a
increase zero cost threshold
tarang-jain Apr 24, 2026
f6df4ae
apply cuda event plus re-add h_norm_cache
tarang-jain Apr 24, 2026
9fc74b1
rm cosine expanded stuff
tarang-jain Apr 24, 2026
dec3dc4
resolve merge conflicts
tarang-jain Apr 28, 2026
0d030a2
change suffix of the params struct
tarang-jain Apr 28, 2026
b1c034e
replace 06 by 08, add todo and note
tarang-jain Apr 28, 2026
a482495
update to v2
tarang-jain Apr 28, 2026
8ecfdc1
avoid stream sync inside weight sum
tarang-jain Apr 29, 2026
1e1525e
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 29, 2026
ec22e07
empty
tarang-jain Apr 29, 2026
d2e410d
empty
tarang-jain Apr 29, 2026
b791c38
Merge branch 'main' into combine-batch
tarang-jain Apr 29, 2026
a05a006
new signatures with new struct
tarang-jain Apr 29, 2026
73293cf
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 29, 2026
880c7b9
Merge branch 'main' of https://github.com/rapidsai/cuvs into combine-…
tarang-jain Apr 30, 2026
e2035ec
revert change to calls in py and rust; add c tests
tarang-jain Apr 30, 2026
e28c200
Merge branch 'main' into combine-batch
tarang-jain May 1, 2026
55bbdad
use to_dlpack
tarang-jain May 5, 2026
9a9b8ee
cache device weights
tarang-jain May 5, 2026
a800b27
rm event
tarang-jain May 5, 2026
3db8582
update names
tarang-jain May 5, 2026
c048352
rename
tarang-jain May 5, 2026
2f968f8
rm docs
tarang-jain May 5, 2026
affe85a
empty
tarang-jain May 5, 2026
c6dea64
fix norm cache
tarang-jain May 5, 2026
7dfab3e
revert changes to minClusterDistanceCompute
tarang-jain May 6, 2026
7a383da
update tests to use mdspan instead of rmm
tarang-jain May 6, 2026
ce6c4b5
Merge branch 'main' into combine-batch
tarang-jain May 6, 2026
5a06a44
Merge branch 'main' into combine-batch
tarang-jain May 6, 2026
419619a
consolidate all unsigned commits
tarang-jain May 7, 2026
2d716ae
rm diff
tarang-jain May 7, 2026
066092b
allow batch sample weights
tarang-jain May 7, 2026
bbdd66d
Merge branch 'main' into mnmg-streaming
tarang-jain May 7, 2026
12d682c
single partition becomes special case
tarang-jain May 7, 2026
9e5e55c
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 7, 2026
28cda6a
Merge branch 'combine-batch' into mg-batched-kmeans
viclafargue May 7, 2026
bfb5290
Addressing review
viclafargue May 7, 2026
add9db1
optimize convergence check
viclafargue May 7, 2026
6c08a7b
Merge branch 'main' into mnmg-streaming
tarang-jain May 7, 2026
acbcd5a
Merge branch 'main' into mnmg-streaming
tarang-jain May 7, 2026
af606bc
Adressing review
viclafargue May 8, 2026
41c66b8
Merge branch 'main' into mg-batched-kmeans
viclafargue May 8, 2026
f664c2c
results on all ranks for RAFT + small optimization
viclafargue May 8, 2026
5430f42
Merge branch 'main' into mnmg-streaming
tarang-jain May 8, 2026
b2ab5bd
merge origin
tarang-jain May 8, 2026
bbdf521
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 8, 2026
10e6def
Merge branch 'main' into mnmg-streaming
tarang-jain May 8, 2026
2040145
reduce diff
tarang-jain May 8, 2026
1828462
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 8, 2026
5c5b8c8
Merge branch 'main' into mnmg-streaming
tarang-jain May 8, 2026
05da5f3
rm prefetch
tarang-jain May 8, 2026
90435c1
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 8, 2026
db41338
Merge branch 'main' into mnmg-streaming
tarang-jain May 8, 2026
6c2c03d
reviews
viclafargue May 11, 2026
7f6d664
Global sampling for init
viclafargue May 11, 2026
f8270e2
SNMG -> MNMG
viclafargue May 11, 2026
bbf0302
Merge branch 'main' into mg-batched-kmeans
viclafargue May 11, 2026
a14a6bc
adding asserts
viclafargue May 11, 2026
7b54a42
consume new init
tarang-jain May 11, 2026
d86b8b4
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 11, 2026
6e11f67
reduce diff
tarang-jain May 11, 2026
9f5b6e5
Merge branch 'main' into mnmg-streaming
tarang-jain May 13, 2026
aaef638
rm unnecessary functions
tarang-jain May 13, 2026
920a460
Merge branch 'main' into mnmg-streaming
tarang-jain May 13, 2026
548d7db
rm accessor templates for now
tarang-jain May 14, 2026
9f3a486
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 14, 2026
c93f248
Merge branch 'main' of https://github.com/rapidsai/cuvs into mnmg-str…
tarang-jain May 14, 2026
51fbf6c
merge upstream
tarang-jain May 20, 2026
d327569
cleanup; re-add device side overload
tarang-jain May 20, 2026
b5e66a3
re-instate removed docs
tarang-jain May 20, 2026
a636188
rm extra fit funcs
tarang-jain May 21, 2026
d3cafed
Merge branch 'release/26.06' into mnmg-streaming
tarang-jain May 26, 2026
1b547f4
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 26, 2026
72cfd43
cleanup
tarang-jain May 26, 2026
4d25e95
rm scaled_weights_cache
tarang-jain May 26, 2026
81155e6
rm unnecessary new types
tarang-jain May 26, 2026
85522aa
rm unused helper
tarang-jain May 26, 2026
00336b5
rm unnecessary stream sync
tarang-jain May 27, 2026
6585866
rm unnecessary lambda
tarang-jain May 27, 2026
aa6f28e
cleanup impl
tarang-jain May 27, 2026
178a7e7
rm unnecessary has_data guards
tarang-jain May 27, 2026
713bc7c
rm global_n host scalar
tarang-jain May 27, 2026
c576d8f
fixes
tarang-jain May 27, 2026
a401a0e
Merge branch 'release/26.06' into mnmg-streaming
tarang-jain May 27, 2026
8102596
fuse with in-memory impl
tarang-jain May 27, 2026
00d0adb
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 27, 2026
caefd53
style
tarang-jain May 27, 2026
8f6f83d
fix compilation
tarang-jain May 27, 2026
d88a991
Merge branch 'release/26.06' into mnmg-streaming
tarang-jain May 28, 2026
1b57b74
mg tests first commit
tarang-jain May 28, 2026
7bac418
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 28, 2026
9851017
merge upstream
tarang-jain May 28, 2026
f572877
update cmakelists
tarang-jain May 29, 2026
edaa7e7
merge upstream
tarang-jain May 29, 2026
588bb6a
rm batched tests
tarang-jain May 29, 2026
ad180ed
Merge branch 'main' into mnmg-streaming
tarang-jain May 29, 2026
72cc34b
rm unnecessary test stream sycns
tarang-jain May 29, 2026
ed50703
reset bs; assertion
tarang-jain May 29, 2026
28f6036
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 29, 2026
a811c56
rm has_data flag
tarang-jain May 29, 2026
95f334c
Merge branch 'main' into mnmg-streaming
tarang-jain May 30, 2026
d176314
fix export
tarang-jain May 30, 2026
1db9e02
Merge branch 'mnmg-streaming' of github.com:tarang-jain/cuvs into mnm…
tarang-jain May 30, 2026
089e970
avoid pinned scalar;get_nccl_comms before omp
tarang-jain May 31, 2026
6cc895c
use root from macro
tarang-jain Jun 1, 2026
4abe6f2
avoid copy and rank alloc with initarray
tarang-jain Jun 1, 2026
ebf188a
Merge branch 'main' of https://github.com/rapidsai/cuvs into mnmg-str…
tarang-jain Jun 1, 2026
785e4a3
fix compilation; guardrail MG CMake flag
tarang-jain Jun 1, 2026
9a526c8
get n_features from centroids
tarang-jain Jun 1, 2026
f08e581
add sigs to header
tarang-jain Jun 1, 2026
51efb42
Merge branch 'main' into mnmg-streaming
tarang-jain Jun 5, 2026
6 changes: 4 additions & 2 deletions cpp/CMakeLists.txt
@@ -1031,8 +1031,6 @@ if(NOT BUILD_CPU_ONLY)
src/cluster/detail/minClusterDistanceCompute.cu
src/cluster/agglomerative.cu
src/cluster/kmeans_cluster_cost.cu
src/cluster/kmeans_fit_mg_float.cu
src/cluster/kmeans_fit_mg_double.cu
src/cluster/kmeans_fit_double.cu
src/cluster/kmeans_fit_float.cu
src/cluster/kmeans_auto_find_k_float.cu
@@ -1209,6 +1207,10 @@ if(NOT BUILD_CPU_ONLY)
target_link_libraries(cuvs_objs PUBLIC $<BUILD_LOCAL_INTERFACE:NCCL::NCCL>)

target_compile_definitions(cuvs_objs PUBLIC CUVS_BUILD_MG_ALGOS)

target_sources(
cuvs_objs PRIVATE src/cluster/kmeans_fit_mg_float.cu src/cluster/kmeans_fit_mg_double.cu
)
endif()

set(CUVS_CUSOLVER_DEPENDENCY CUDA::cusolver${_ctk_static_suffix})
117 changes: 117 additions & 0 deletions cpp/include/cuvs/cluster/kmeans.hpp
@@ -13,6 +13,7 @@

#include <cuvs/core/export.hpp>
#include <optional>
#include <vector>

namespace CUVS_EXPORT cuvs {
namespace cluster {
@@ -1607,6 +1608,122 @@ void cluster_cost(
* @}
*/

namespace mg {
/**
* @defgroup kmeans_mg Multi-GPU / out-of-core k-means fit
* @{
*/

/**
* @brief Multi-GPU k-means fit with one or more local data
* partitions per rank.
*
* Each rank supplies its local training data as a vector of partitions. The
* implementation streams every partition through Lloyd iterations using
* `params.streaming_batch_size`.
*
* The active backend is selected by the resources attached to
* `handle`:
* - When `raft::resource::is_multi_gpu(handle)` is true (SNMG clique), the
* call must be issued from inside an OpenMP region with one thread per
* rank in the clique.
* - Otherwise, multi-process NCCL comms must be initialized on the handle
* (`raft::resource::comms_initialized(handle)`); each process supplies its
* own local partitions.
*
* @param[in] handle The raft handle. Must have NCCL comms or
* a SNMG clique initialized.
* @param[in] params K-means parameters. The streaming batch
* size is read from
* `params.streaming_batch_size`.
* @param[in] X_parts Per-partition local data on this rank.
* Each entry is [n_rows_i x n_features].
* @param[in] sample_weight_parts Optional per-partition row weights with
* one vector per data partition.
* @param[inout] centroids Device matrix [n_clusters x n_features].
* On entry, used as the initial centers
* when `params.init == InitMethod::Array`.
* On return, holds the converged
* centroids.
* @param[out] inertia Host scalar receiving the final
* clustering cost.
* @param[out] n_iter Host scalar receiving the iteration
* count at which the run terminated.
*/
void fit(
raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::device_matrix_view<const float, int>>& X_parts,
const std::optional<std::vector<raft::device_vector_view<const float, int>>>& sample_weight_parts,
raft::device_matrix_view<float, int> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Multi-GPU k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::device_matrix_view<const float, int64_t>>& X_parts,
const std::optional<std::vector<raft::device_vector_view<const float, int64_t>>>&
sample_weight_parts,
raft::device_matrix_view<float, int64_t> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::device_matrix_view<const double, int>>& X_parts,
const std::optional<std::vector<raft::device_vector_view<const double, int>>>&
sample_weight_parts,
raft::device_matrix_view<double, int> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Multi-GPU k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::device_matrix_view<const double, int64_t>>& X_parts,
const std::optional<std::vector<raft::device_vector_view<const double, int64_t>>>&
sample_weight_parts,
raft::device_matrix_view<double, int64_t> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU / out-of-core k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::host_matrix_view<const float, int64_t>>& X_parts,
const std::optional<std::vector<raft::host_vector_view<const float, int64_t>>>&
sample_weight_parts,
raft::device_matrix_view<float, int64_t> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Multi-GPU / out-of-core k-means fit.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
const std::vector<raft::host_matrix_view<const double, int64_t>>& X_parts,
const std::optional<std::vector<raft::host_vector_view<const double, int64_t>>>&
sample_weight_parts,
raft::device_matrix_view<double, int64_t> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @}
*/
} // namespace mg

namespace helpers {
/**
* @defgroup kmeans_helpers k-means API helpers
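Note on the streaming behaviour described in the `mg::fit` doc comment: the comment says every partition is streamed through Lloyd iterations using `params.streaming_batch_size`, but the implementation is not part of the hunks shown here. The following is only a CPU-side sketch of the general technique, assuming a plain squared-L2 assignment and unweighted rows. `Matrix` and `lloyd_iteration` are hypothetical names made up for illustration; nothing here calls cuVS, RAFT, or NCCL, and it should not be read as how the PR implements it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical row-major host matrix standing in for one data partition.
struct Matrix {
  std::vector<float> data;  // rows * dim values
  std::size_t rows;
  std::size_t dim;
};

// One Lloyd iteration over several partitions, visiting each partition in
// batches of at most `batch_size` rows. Only per-cluster sums and counts are
// carried between batches, so no batch has to stay resident after it is
// processed. Returns the inertia measured against the input centroids.
float lloyd_iteration(const std::vector<Matrix>& parts,
                      std::size_t batch_size,
                      std::size_t k,
                      std::size_t dim,
                      std::vector<float>& centroids)
{
  assert(batch_size > 0 && centroids.size() == k * dim);
  std::vector<double> sums(k * dim, 0.0);
  std::vector<std::size_t> counts(k, 0);
  double inertia = 0.0;

  for (const Matrix& part : parts) {
    assert(part.dim == dim && part.data.size() == part.rows * dim);
    for (std::size_t begin = 0; begin < part.rows; begin += batch_size) {
      const std::size_t end = std::min(begin + batch_size, part.rows);
      for (std::size_t r = begin; r < end; ++r) {
        const float* x = &part.data[r * dim];
        // Find the nearest centroid by squared Euclidean distance.
        std::size_t best = 0;
        double best_d    = std::numeric_limits<double>::max();
        for (std::size_t c = 0; c < k; ++c) {
          double d = 0.0;
          for (std::size_t j = 0; j < dim; ++j) {
            const double diff = static_cast<double>(x[j]) - centroids[c * dim + j];
            d += diff * diff;
          }
          if (d < best_d) {
            best_d = d;
            best   = c;
          }
        }
        // Accumulate this row into the running statistics of its cluster.
        inertia += best_d;
        ++counts[best];
        for (std::size_t j = 0; j < dim; ++j) { sums[best * dim + j] += x[j]; }
      }
    }
  }

  // Assumption: in a multi-rank run, sums, counts, and inertia would be
  // summed across ranks at this point before the centroid update.
  for (std::size_t c = 0; c < k; ++c) {
    if (counts[c] == 0) { continue; }  // leave empty clusters unchanged
    for (std::size_t j = 0; j < dim; ++j) {
      centroids[c * dim + j] = static_cast<float>(sums[c * dim + j] / counts[c]);
    }
  }
  return static_cast<float>(inertia);
}
```

Because the update only needs per-cluster sums and counts, the result for a given set of rows does not depend on how the rows are split into partitions or batches, which is what makes the partition-vector signature and a tunable batch size compatible with a single set of output centroids.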