Open
125 commits
ca07c08
first commit (unclean)
tarang-jain Jan 9, 2026
bc872c8
Merge branch 'main' into minibatch-kmeans
tarang-jain Jan 9, 2026
daf6d6e
Merge branch 'main' into minibatch-kmeans
tarang-jain Jan 10, 2026
f1a19df
style
tarang-jain Jan 10, 2026
181d536
Merge branch 'minibatch-kmeans' of https://github.com/tarang-jain/cuv…
tarang-jain Jan 10, 2026
0fa00b0
copyright
tarang-jain Jan 10, 2026
371543f
Merge branch 'main' into minibatch-kmeans
tarang-jain Jan 23, 2026
c81650c
Merge branch 'main' into minibatch-kmeans
tarang-jain Feb 3, 2026
fcbdda5
python test
tarang-jain Feb 3, 2026
d6ed934
minibatch first commit
tarang-jain Feb 3, 2026
5d4b498
fix docs
tarang-jain Feb 3, 2026
72fe789
replace thrust calls:
tarang-jain Feb 3, 2026
aefae6e
Merge branch 'main' of https://github.com/rapidsai/cuvs into minibatc…
tarang-jain Feb 9, 2026
526ac04
common function in helper
tarang-jain Feb 9, 2026
e9c85b9
Merge branch 'main' into minibatch-kmeans
tarang-jain Feb 10, 2026
1efadde
Merge branch 'main' into minibatch-kmeans
tarang-jain Feb 11, 2026
ee45045
Merge branch 'main' into minibatch-kmeans
tarang-jain Feb 13, 2026
9b6f1ef
fix templates
tarang-jain Feb 13, 2026
ad20d0a
Merge branch 'minibatch-kmeans' of https://github.com/tarang-jain/cuv…
tarang-jain Feb 13, 2026
4b65df5
namespace and init fixes
tarang-jain Feb 13, 2026
5eb2be5
fix docs in main header
tarang-jain Feb 13, 2026
c23985a
several fixes
tarang-jain Feb 15, 2026
c103f87
Merge branch 'main' into minibatch-kmeans
tarang-jain Feb 15, 2026
9d87a5f
rm lower precision
tarang-jain Feb 15, 2026
5bcec91
Merge branch 'minibatch-kmeans' of https://github.com/tarang-jain/cuv…
tarang-jain Feb 15, 2026
a618ed5
rm unnecessary unary-ops
tarang-jain Feb 16, 2026
3b86325
rm unnecessary unary-ops
tarang-jain Feb 16, 2026
f1b4835
minibatch allocations are conditional
tarang-jain Feb 16, 2026
d2d3f4b
cleanup extraneous docs
tarang-jain Feb 16, 2026
639147a
revert changes to get_dataset
tarang-jain Feb 16, 2026
c5b9628
Merge branch 'main' into minibatch-kmeans
tarang-jain Feb 16, 2026
6246289
fix python tests
tarang-jain Feb 16, 2026
22c10ec
Merge branch 'minibatch-kmeans' of https://github.com/tarang-jain/cuv…
tarang-jain Feb 16, 2026
b580760
address sklearn inconsistency
tarang-jain Feb 17, 2026
491c900
fix call to finalize_centroids
tarang-jain Feb 17, 2026
6886fb7
bug fixes and python tests
tarang-jain Feb 17, 2026
8e45bcd
Merge branch 'main' into minibatch-kmeans
tarang-jain Feb 17, 2026
567a01a
Merge branch 'main' of https://github.com/rapidsai/cuvs into minibatc…
tarang-jain Feb 18, 2026
ab366f5
add early stopping criteria from sklearn
tarang-jain Feb 19, 2026
dbfd1a8
fixes
tarang-jain Feb 20, 2026
e76624f
fix test by normalizing data
tarang-jain Feb 24, 2026
8d629f1
Merge branch 'main' into minibatch-kmeans
tarang-jain Feb 24, 2026
3961187
Merge branch 'main' of https://github.com/rapidsai/cuvs into minibatc…
tarang-jain Feb 25, 2026
296942d
Merge branch 'batched-kmeans' of https://github.com/tarang-jain/cuvs …
tarang-jain Feb 25, 2026
aacd543
rejection sampling
tarang-jain Feb 25, 2026
d746af2
style
tarang-jain Feb 25, 2026
6d01aed
update test with inertia check
tarang-jain Feb 25, 2026
2f1189f
fix style
tarang-jain Feb 25, 2026
0862da6
add reassignment; update minibatch params struct
tarang-jain Feb 25, 2026
8b448f6
style
tarang-jain Feb 25, 2026
8225e15
fix merge conflict
tarang-jain Feb 25, 2026
7c3965c
simplify minibatch update step
tarang-jain Feb 25, 2026
c490208
fix oom
tarang-jain Feb 26, 2026
b09611a
update tests
tarang-jain Feb 26, 2026
d404255
Merge branch 'main' into minibatch-kmeans
tarang-jain Feb 27, 2026
350ee82
update n_init use
tarang-jain Mar 2, 2026
13584b8
Merge branch 'minibatch-kmeans' of https://github.com/tarang-jain/cuv…
tarang-jain Mar 2, 2026
46d9754
Merge branch 'main' into minibatch-kmeans
tarang-jain Mar 2, 2026
63a34a3
abstract away commonalities into helpers
tarang-jain Mar 2, 2026
de9206a
Merge branch 'minibatch-kmeans' of https://github.com/tarang-jain/cuv…
tarang-jain Mar 2, 2026
de34c93
fix compilation errors
tarang-jain Mar 2, 2026
29a2358
fix bug, add cpp tests
tarang-jain Mar 3, 2026
4f119ba
style
tarang-jain Mar 3, 2026
bf5726b
make cpp tests more rigorous
tarang-jain Mar 3, 2026
74ec728
style
tarang-jain Mar 3, 2026
568904d
fix learning rate bug
tarang-jain Mar 3, 2026
48a776b
revert
tarang-jain Mar 3, 2026
b2c8a65
add sample weights
tarang-jain Mar 4, 2026
ae2688c
Merge branch 'main' into h-inertia
tarang-jain Mar 4, 2026
4e8f2e4
update impl
tarang-jain Mar 5, 2026
7cc90cb
Merge branch 'h-inertia' of https://github.com/tarang-jain/cuvs into …
tarang-jain Mar 5, 2026
d6f4524
fix min_cluster_dist
tarang-jain Mar 5, 2026
1fa9013
update instantiations
tarang-jain Mar 5, 2026
6752349
Merge branch 'main' into h-inertia
tarang-jain Mar 5, 2026
4ccce83
fix all the docs
tarang-jain Mar 5, 2026
18de062
Merge branch 'h-inertia' of https://github.com/tarang-jain/cuvs into …
tarang-jain Mar 5, 2026
c3ca46b
style
tarang-jain Mar 5, 2026
02e378b
Merge branch 'h-inertia' of https://github.com/tarang-jain/cuvs into …
tarang-jain Mar 5, 2026
7d6bed8
rm compute_inertia
tarang-jain Mar 5, 2026
6d072f6
fix compute_batched_host_inertia
tarang-jain Mar 5, 2026
55667c9
Merge branch 'main' into minibatch-kmeans
tarang-jain Mar 5, 2026
82c7095
Merge branch 'main' into h-inertia
tarang-jain Mar 5, 2026
30f5ac4
fix style
tarang-jain Mar 5, 2026
705339d
Merge branch 'minibatch-kmeans' of https://github.com/tarang-jain/cuv…
tarang-jain Mar 5, 2026
aa9a9e7
rm minibatch
tarang-jain Mar 6, 2026
64b0584
rm extra file
tarang-jain Mar 6, 2026
ec48753
fix header includes
tarang-jain Mar 6, 2026
738eea7
address pr reviews
tarang-jain Mar 6, 2026
d629ca8
fix python tests, style
tarang-jain Mar 6, 2026
c8ac477
fix style
tarang-jain Mar 6, 2026
c1482df
Merge branch 'main' of https://github.com/rapidsai/cuvs into batched-…
tarang-jain Mar 6, 2026
13b4084
rm extra c helpers
tarang-jain Mar 6, 2026
0bb59a9
add eof
tarang-jain Mar 6, 2026
fa77151
fix docs
tarang-jain Mar 6, 2026
068d66f
address pr reviews; update inertia comp
tarang-jain Mar 9, 2026
8e0be37
revert abi change
tarang-jain Mar 9, 2026
c478f87
Merge branch 'main' into batched-kmeans
tarang-jain Mar 9, 2026
0a7f026
rm null dataset norm
tarang-jain Mar 9, 2026
ee3ce56
Merge branch 'h-inertia' of https://github.com/tarang-jain/cuvs into …
tarang-jain Mar 9, 2026
10be5c4
add warning when T and MathT are different
tarang-jain Mar 9, 2026
cf2708b
Merge branch 'main' into h-inertia
tarang-jain Mar 9, 2026
6a2a681
use raft::mul_op
tarang-jain Mar 9, 2026
4936382
Merge branch 'h-inertia' of https://github.com/tarang-jain/cuvs into …
tarang-jain Mar 9, 2026
e7a9b3a
put batch size at the end of the c header struct
tarang-jain Mar 9, 2026
7d8365b
Merge branch 'batched-kmeans' of https://github.com/tarang-jain/cuvs …
tarang-jain Mar 9, 2026
61c881c
Merge branch 'h-inertia' of https://github.com/tarang-jain/cuvs into …
tarang-jain Mar 9, 2026
6439d72
fix c and python compilation
tarang-jain Mar 9, 2026
6a094cb
add docs
tarang-jain Mar 9, 2026
2d14e9a
style
tarang-jain Mar 10, 2026
fc54020
Merge branch 'main' into batched-kmeans
tarang-jain Mar 11, 2026
37ce404
correct treatment for optional
tarang-jain Mar 12, 2026
1b3c341
fill outside loop
tarang-jain Mar 12, 2026
914628c
add warning
tarang-jain Mar 12, 2026
f1d3f8a
python docs
tarang-jain Mar 12, 2026
edad9bb
merge upstream
tarang-jain Mar 12, 2026
cb4e8b3
finish merge
tarang-jain Mar 12, 2026
552f736
fix compilation warning
tarang-jain Mar 12, 2026
9c7bbc8
Merge branch 'release/26.04' into batched-kmeans
tarang-jain Mar 12, 2026
7f6e615
optimizations and cleanups
tarang-jain Mar 13, 2026
185a52f
Merge branch 'batched-kmeans' of https://github.com/tarang-jain/cuvs …
tarang-jain Mar 13, 2026
0719636
fix compilation
tarang-jain Mar 13, 2026
1a7b644
style
tarang-jain Mar 13, 2026
d4c53ee
address python reviews
tarang-jain Mar 13, 2026
8bec6d5
style
tarang-jain Mar 13, 2026
e920ca0
rm batch_size struct
tarang-jain Mar 13, 2026
19 changes: 17 additions & 2 deletions c/include/cuvs/cluster/kmeans.h
@@ -36,6 +36,7 @@ typedef enum {
Array = 2
} cuvsKMeansInitMethod;


/**
* @brief Hyper-parameters for the kmeans algorithm
*/
@@ -90,6 +91,7 @@ struct cuvsKMeansParams {
*/
int batch_centroids;

/** Check inertia during iterations for early convergence. */
bool inertia_check;

/**
@@ -101,6 +103,12 @@ struct cuvsKMeansParams {
* For hierarchical k-means , defines the number of training iterations
*/
int hierarchical_n_iters;

/**
* Number of samples to process per GPU batch for the batched (host-data) API.
* When set to 0, defaults to n_samples (process all at once).
*/
int64_t batch_size;
};

typedef struct cuvsKMeansParams* cuvsKMeansParams_t;
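
The `batch_size` comment above says that 0 falls back to `n_samples`. A minimal sketch of how that default could be resolved, assuming hypothetical helpers `effective_batch_size` and `num_batches` (neither is a cuVS function). Clamping negative or oversized values to `n_samples` is an assumption of this sketch; the diff does not say how such values are handled.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not part of cuVS: the batch size actually used.
// 0 means "process all samples at once", as the doc comment states.
// Assumption: negative or oversized values are treated the same way.
inline int64_t effective_batch_size(int64_t batch_size, int64_t n_samples)
{
  if (batch_size <= 0 || batch_size > n_samples) { return n_samples; }
  return batch_size;
}

// Hypothetical helper: number of host-to-device transfers needed to cover
// the whole dataset (ceiling division).
inline int64_t num_batches(int64_t batch_size, int64_t n_samples)
{
  if (n_samples <= 0) { return 0; }
  const int64_t bs = effective_batch_size(batch_size, n_samples);
  return (n_samples + bs - 1) / bs;
}
```

With `n_samples = 250000`, a `batch_size` of 100000 gives three transfers (the last one partial), while a `batch_size` of 0 gives a single transfer.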
@@ -142,18 +150,24 @@ typedef enum { CUVS_KMEANS_TYPE_KMEANS = 0, CUVS_KMEANS_TYPE_KMEANS_BALANCED = 1
* clusters are reinitialized by choosing new centroids with
* k-means++ algorithm.
*
* X may reside on either host (CPU) or device (GPU) memory.
* When X is on the host the data is streamed to the GPU in
* batches controlled by params->batch_size.
*
* @param[in] res opaque C handle
* @param[in] params Parameters for KMeans model.
* @param[in] X Training instances to cluster. The data must
- * be in row-major format.
+ * be in row-major format. May be on host or
* device memory.
* [dim = n_samples x n_features]
* @param[in] sample_weight Optional weights for each observation in X.
* Must be on the same memory space as X.
* [len = n_samples]
* @param[inout] centroids [in] When init is InitMethod::Array, use
* centroids as the initial cluster centers.
* [out] The generated centroids from the
* kmeans algorithm are stored at the address
- * pointed by 'centroids'.
+ * pointed by 'centroids'. Must be on device.
* [dim = n_clusters x n_features]
* @param[out] inertia Sum of squared distances of samples to their
* closest cluster center.
@@ -212,6 +226,7 @@ cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res,
DLManagedTensor* X,
DLManagedTensor* centroids,
double* cost);

/**
* @}
*/
88 changes: 68 additions & 20 deletions c/src/cluster/kmeans.cpp
@@ -17,16 +17,18 @@ namespace {

cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params)
{
- auto kmeans_params = cuvs::cluster::kmeans::params();
- kmeans_params.metric = static_cast<cuvs::distance::DistanceType>(params.metric);
- kmeans_params.init = static_cast<cuvs::cluster::kmeans::params::InitMethod>(params.init);
- kmeans_params.n_clusters = params.n_clusters;
- kmeans_params.max_iter = params.max_iter;
- kmeans_params.tol = params.tol;
+ auto kmeans_params = cuvs::cluster::kmeans::params();
+ kmeans_params.metric = static_cast<cuvs::distance::DistanceType>(params.metric);
+ kmeans_params.init = static_cast<cuvs::cluster::kmeans::params::InitMethod>(params.init);
+ kmeans_params.n_clusters = params.n_clusters;
+ kmeans_params.max_iter = params.max_iter;
+ kmeans_params.tol = params.tol;
kmeans_params.n_init = params.n_init;
kmeans_params.oversampling_factor = params.oversampling_factor;
kmeans_params.batch_samples = params.batch_samples;
kmeans_params.batch_centroids = params.batch_centroids;
kmeans_params.inertia_check = params.inertia_check;
kmeans_params.batch_size = params.batch_size;
return kmeans_params;
}

@@ -49,8 +51,53 @@ void _fit(cuvsResources_t res,
{
auto X = X_tensor->dl_tensor;
auto res_ptr = reinterpret_cast<raft::resources*>(res);
bool is_host = (X.device.device_type == kDLCPU);

- if (cuvs::core::is_dlpack_device_compatible(X)) {
+ if (is_host) {
auto n_samples = static_cast<IdxT>(X.shape[0]);
auto n_features = static_cast<IdxT>(X.shape[1]);

if (params.hierarchical) {
RAFT_FAIL("hierarchical kmeans is not supported with host data");
}

auto centroids_dl = centroids_tensor->dl_tensor;
if (!cuvs::core::is_dlpack_device_compatible(centroids_dl)) {
RAFT_FAIL("centroids must be on device memory");
}

auto X_view = raft::make_host_matrix_view<T const, IdxT>(
reinterpret_cast<T const*>(X.data), n_samples, n_features);
auto centroids_view =
cuvs::core::from_dlpack<raft::device_matrix_view<T, IdxT, raft::row_major>>(
centroids_tensor);

std::optional<raft::host_vector_view<T const, IdxT>> sample_weight;
if (sample_weight_tensor != NULL) {
auto sw = sample_weight_tensor->dl_tensor;
if (sw.device.device_type != kDLCPU) {
RAFT_FAIL("sample_weight must be on host memory when X is on host");
}
sample_weight = raft::make_host_vector_view<T const, IdxT>(
reinterpret_cast<T const*>(sw.data), n_samples);
}

T inertia_temp;
IdxT n_iter_temp;

auto kmeans_params = convert_params(params);
cuvs::cluster::kmeans::fit(*res_ptr,
kmeans_params,
X_view,
sample_weight,
centroids_view,
raft::make_host_scalar_view<T>(&inertia_temp),
raft::make_host_scalar_view<IdxT>(&n_iter_temp));

*inertia = inertia_temp;
*n_iter = n_iter_temp;

} else {
using const_mdspan_type = raft::device_matrix_view<T const, IdxT, raft::row_major>;
using mdspan_type = raft::device_matrix_view<T, IdxT, raft::row_major>;

@@ -90,8 +137,6 @@ void _fit(cuvsResources_t res,
*inertia = inertia_temp;
*n_iter = n_iter_temp;
}
- } else {
- RAFT_FAIL("X dataset must be accessible on device memory");
}
}
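
The host branch of `_fit` above rejects three input combinations before calling the C++ API. The same checks can be sketched without DLPack or RAFT, assuming a hypothetical `MemSpace` enum and `FitInputs` struct that stand in for the tensor metadata. The error strings are copied from the diff; everything else is illustrative, and the device-data path is not modeled.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for the device type carried by a DLPack tensor.
enum class MemSpace { Host, Device };

// Hypothetical summary of the arguments _fit inspects.
struct FitInputs {
  MemSpace X;
  MemSpace centroids;
  std::optional<MemSpace> sample_weight;  // empty when no weights are passed
  bool hierarchical;
};

// Returns an error message, or std::nullopt when the inputs are accepted.
// Only the host-data checks from the diff are modeled, in the same order.
inline std::optional<std::string> validate_fit_inputs(const FitInputs& in)
{
  if (in.X != MemSpace::Host) { return std::nullopt; }  // device path: not modeled here
  if (in.hierarchical) {
    return std::string("hierarchical kmeans is not supported with host data");
  }
  if (in.centroids != MemSpace::Device) {
    return std::string("centroids must be on device memory");
  }
  if (in.sample_weight && *in.sample_weight != MemSpace::Host) {
    return std::string("sample_weight must be on host memory when X is on host");
  }
  return std::nullopt;
}
```

In the real code each failed check raises through `RAFT_FAIL` instead of returning a message.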

@@ -182,17 +227,20 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params)
return cuvs::core::translate_exceptions([=] {
cuvs::cluster::kmeans::params cpp_params;
cuvs::cluster::kmeans::balanced_params cpp_balanced_params;
- *params =
- new cuvsKMeansParams{.metric = static_cast<cuvsDistanceType>(cpp_params.metric),
- .n_clusters = cpp_params.n_clusters,
- .init = static_cast<cuvsKMeansInitMethod>(cpp_params.init),
- .max_iter = cpp_params.max_iter,
- .tol = cpp_params.tol,
- .oversampling_factor = cpp_params.oversampling_factor,
- .batch_samples = cpp_params.batch_samples,
- .inertia_check = cpp_params.inertia_check,
- .hierarchical = false,
- .hierarchical_n_iters = static_cast<int>(cpp_balanced_params.n_iters)};
+ *params = new cuvsKMeansParams{
+ .metric = static_cast<cuvsDistanceType>(cpp_params.metric),
+ .n_clusters = cpp_params.n_clusters,
+ .init = static_cast<cuvsKMeansInitMethod>(cpp_params.init),
+ .max_iter = cpp_params.max_iter,
+ .tol = cpp_params.tol,
+ .n_init = cpp_params.n_init,
+ .oversampling_factor = cpp_params.oversampling_factor,
+ .batch_samples = cpp_params.batch_samples,
+ .batch_centroids = cpp_params.batch_centroids,
+ .inertia_check = cpp_params.inertia_check,
+ .hierarchical = false,
+ .hierarchical_n_iters = static_cast<int>(cpp_balanced_params.n_iters),
+ .batch_size = cpp_params.batch_size};
});
}
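
In this diff `batch_size` is the last member of `cuvsKMeansParams` and the last designated initializer. One plausible reason, which the PR does not state, is that appending a field leaves the byte offsets of the existing fields unchanged. A minimal sketch of that property, using hypothetical trimmed-down structs `ParamsBefore` and `ParamsAfter` rather than the real `cuvsKMeansParams`:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical, shortened stand-ins for the C params struct before and after
// the change. The real struct has more fields.
struct ParamsBefore {
  int n_clusters;
  int max_iter;
  double tol;
  bool inertia_check;
  int hierarchical_n_iters;
};

struct ParamsAfter {
  int n_clusters;
  int max_iter;
  double tol;
  bool inertia_check;
  int hierarchical_n_iters;
  int64_t batch_size;  // new field, appended last
};

// True when every pre-existing field sits at the same byte offset in both layouts.
constexpr bool existing_offsets_unchanged()
{
  return offsetof(ParamsBefore, n_clusters) == offsetof(ParamsAfter, n_clusters) &&
         offsetof(ParamsBefore, max_iter) == offsetof(ParamsAfter, max_iter) &&
         offsetof(ParamsBefore, tol) == offsetof(ParamsAfter, tol) &&
         offsetof(ParamsBefore, inertia_check) == offsetof(ParamsAfter, inertia_check) &&
         offsetof(ParamsBefore, hierarchical_n_iters) ==
           offsetof(ParamsAfter, hierarchical_n_iters);
}

static_assert(existing_offsets_unchanged(), "appending a field must not move existing ones");
```

The struct still grows, so this only helps callers that obtain it through `cuvsKMeansParamsCreate`, which allocates it inside the library, as shown above.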

190 changes: 189 additions & 1 deletion cpp/include/cuvs/cluster/kmeans.hpp
@@ -100,15 +100,31 @@ struct params : base_params {
* useful to optimize/control the memory footprint
* Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0
* then don't tile the centroids
*
* NB: These parameters are unrelated to batch_size, which controls how many
* samples to transfer from host to device per batch when processing out-of-core
* data.
*/
int batch_samples = 1 << 15;

/**
* if 0 then batch_centroids = n_clusters
*/
- int batch_centroids = 0; //
+ int batch_centroids = 0;

/**
* If true, check inertia during iterations for early convergence.
*/
bool inertia_check = false;

/**
* Number of samples to process per GPU batch when fitting with host data.
* When set to 0, defaults to n_samples (process all at once).
* Only used by the batched (host-data) code path and ignored by device-data
* overloads.
* Default: 0 (process all data at once).
*/
int64_t batch_size = 0;
};

/**
@@ -141,6 +157,178 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* @{
*/

/**
* @brief Find clusters with k-means algorithm using batched processing of host data.
*
* This overload supports out-of-core computation where the dataset resides
* on the host. Data is processed in GPU-sized batches, streaming from host to device.
* The batch size is controlled by params.batch_size.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/cluster/kmeans.hpp>
* using namespace cuvs::cluster;
* ...
* raft::resources handle;
* cuvs::cluster::kmeans::params params;
* params.n_clusters = 100;
* params.batch_size = 100000;
* int n_features = 15;
* float inertia;
* int n_iter;
*
* // Data on host
* std::vector<float> h_X(n_samples * n_features);
* auto X = raft::make_host_matrix_view<const float, int>(h_X.data(), n_samples, n_features);
*
* // Centroids on device
* auto centroids = raft::make_device_matrix<float, int>(handle, params.n_clusters, n_features);
*
* kmeans::fit(handle,
* params,
* X,
* std::nullopt,
* centroids.view(),
* raft::make_host_scalar_view(&inertia),
* raft::make_host_scalar_view(&n_iter));
* @endcode
*
* @param[in] handle The raft handle.
* @param[in] params Parameters for KMeans model. Batch size is read from
* params.batch_size.
* @param[in] X Training instances on HOST memory. The data must
* be in row-major format.
* [dim = n_samples x n_features]
* @param[in] sample_weight Optional weights for each observation in X (on host).
* [len = n_samples]
* @param[inout] centroids [in] When init is InitMethod::Array, use
* centroids as the initial cluster centers.
* [out] The generated centroids from the
* kmeans algorithm are stored at the address
* pointed by 'centroids'.
* [dim = n_clusters x n_features]
* @param[out] inertia Sum of squared distances of samples to their
* closest cluster center.
* @param[out] n_iter Number of iterations run.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const float, int> X,
std::optional<raft::host_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<float, int> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Find clusters with k-means algorithm using batched processing of host data.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const float, int64_t> X,
std::optional<raft::host_vector_view<const float, int64_t>> sample_weight,
raft::device_matrix_view<float, int64_t> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Find clusters with k-means algorithm using batched processing of host data.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const double, int> X,
std::optional<raft::host_vector_view<const double, int>> sample_weight,
raft::device_matrix_view<double, int> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int> n_iter);

/**
* @brief Find clusters with k-means algorithm using batched processing of host data.
*/
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const double, int64_t> X,
std::optional<raft::host_vector_view<const double, int64_t>> sample_weight,
raft::device_matrix_view<double, int64_t> centroids,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int64_t> n_iter);
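
The host-data `fit` overloads above stream `X` to the GPU in chunks of `params.batch_size` rows. The PR's implementation is not visible in this part of the diff, so the following is only a CPU stand-in for the general idea: accumulate per-cluster sums and counts chunk by chunk, and update the centroids once after every chunk has been seen. The function name `batched_lloyd_step` is hypothetical, sample weights are omitted, and leaving empty clusters unchanged is an assumption of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// One Lloyd iteration over row-major host data X [n x d], visited in chunks of
// batch_size rows. Updates centroids [k x d] in place and returns the inertia
// measured against the centroids that were passed in. Assumes k >= 1.
inline double batched_lloyd_step(const std::vector<float>& X, int64_t n, int64_t d,
                                 std::vector<float>& centroids, int64_t k,
                                 int64_t batch_size)
{
  if (batch_size <= 0) { batch_size = n; }  // 0 means "all at once"
  std::vector<double> sums(static_cast<std::size_t>(k * d), 0.0);
  std::vector<int64_t> counts(static_cast<std::size_t>(k), 0);
  double inertia = 0.0;

  for (int64_t start = 0; start < n; start += batch_size) {
    const int64_t end = std::min(start + batch_size, n);
    // A GPU version would copy rows [start, end) from host to device here.
    for (int64_t i = start; i < end; ++i) {
      int64_t best     = 0;
      double best_dist = std::numeric_limits<double>::max();
      for (int64_t c = 0; c < k; ++c) {
        double dist = 0.0;
        for (int64_t j = 0; j < d; ++j) {
          const double diff =
            static_cast<double>(X[i * d + j]) - static_cast<double>(centroids[c * d + j]);
          dist += diff * diff;
        }
        if (dist < best_dist) {
          best_dist = dist;
          best      = c;
        }
      }
      inertia += best_dist;
      ++counts[best];
      for (int64_t j = 0; j < d; ++j) { sums[best * d + j] += X[i * d + j]; }
    }
  }

  // Finalize once, after all chunks have contributed.
  for (int64_t c = 0; c < k; ++c) {
    if (counts[c] == 0) { continue; }  // assumption: empty clusters keep their center
    for (int64_t j = 0; j < d; ++j) {
      centroids[c * d + j] = static_cast<float>(sums[c * d + j] / counts[c]);
    }
  }
  return inertia;
}
```

Because the sums are only finalized after the last chunk, the result does not depend on `batch_size`, up to floating-point summation order.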

/**
* @defgroup predict_host K-Means Predict (host data)
* @{
*/

/**
* @brief Predict cluster labels for host data using batched processing.
*
* Streams data from host to GPU in batches, assigns each sample to its nearest
* centroid, and writes labels back to host memory.
* The batch size is controlled by params.batch_size.
*
* @param[in] handle The raft handle.
* @param[in] params Parameters for KMeans model.
* @param[in] X Input samples on HOST memory. [dim = n_samples x n_features]
* @param[in] sample_weight Optional weights for each observation (on host).
* @param[in] centroids Cluster centers on device. [dim = n_clusters x n_features]
* @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples]
* @param[in] normalize_weight Whether to normalize sample weights.
* @param[out] inertia Sum of squared distances to nearest centroid.
*/
void predict(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const float, int64_t> X,
std::optional<raft::host_vector_view<const float, int64_t>> sample_weight,
raft::device_matrix_view<const float, int64_t> centroids,
raft::host_vector_view<int64_t, int64_t> labels,
bool normalize_weight,
raft::host_scalar_view<float> inertia);

/**
* @brief Predict cluster labels for host data using batched processing (double).
*/
void predict(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const double, int64_t> X,
std::optional<raft::host_vector_view<const double, int64_t>> sample_weight,
raft::device_matrix_view<const double, int64_t> centroids,
raft::host_vector_view<int64_t, int64_t> labels,
bool normalize_weight,
raft::host_scalar_view<double> inertia);
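
The host-data `predict` documented above assigns each sample to its nearest centroid one batch at a time and writes the labels back to host memory. A CPU-only sketch of that flow, assuming a hypothetical function `batched_predict` and leaving out sample weights and `normalize_weight`; the per-batch label buffer stands in for device memory and is not how cuVS necessarily organizes it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Labels every row of row-major X [n x d] with the index of its nearest
// centroid (squared L2), visiting X in chunks of batch_size rows.
// Returns the total inertia. Assumes k >= 1.
inline double batched_predict(const std::vector<float>& X, int64_t n, int64_t d,
                              const std::vector<float>& centroids, int64_t k,
                              int64_t batch_size, std::vector<int64_t>& labels)
{
  if (batch_size <= 0) { batch_size = n; }  // 0 means "all at once"
  labels.assign(static_cast<std::size_t>(n), int64_t{0});
  double inertia = 0.0;
  std::vector<int64_t> batch_labels;  // stands in for a device-side label buffer

  for (int64_t start = 0; start < n; start += batch_size) {
    const int64_t rows = std::min(batch_size, n - start);
    batch_labels.assign(static_cast<std::size_t>(rows), int64_t{0});
    for (int64_t r = 0; r < rows; ++r) {
      const int64_t i  = start + r;
      double best_dist = std::numeric_limits<double>::max();
      for (int64_t c = 0; c < k; ++c) {
        double dist = 0.0;
        for (int64_t j = 0; j < d; ++j) {
          const double diff =
            static_cast<double>(X[i * d + j]) - static_cast<double>(centroids[c * d + j]);
          dist += diff * diff;
        }
        if (dist < best_dist) {
          best_dist       = dist;
          batch_labels[r] = c;
        }
      }
      inertia += best_dist;
    }
    // Plays the role of the device-to-host copy into the caller's label array.
    std::copy(batch_labels.begin(), batch_labels.end(), labels.begin() + start);
  }
  return inertia;
}
```

Each batch writes into `labels` at offset `start`, so the output keeps the original row order regardless of the batch size.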

/**
* @brief Fit k-means and predict cluster labels using batched processing of host data.
*
* Combines batched fit and batched predict into a single call.
*
* @param[in] handle The raft handle.
* @param[in] params Parameters for KMeans model.
* @param[in] X Training instances on HOST memory. [dim = n_samples x n_features]
* @param[in] sample_weight Optional weights for each observation (on host).
* @param[inout] centroids Cluster centers on device. [dim = n_clusters x n_features]
* @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples]
* @param[out] inertia Sum of squared distances to nearest centroid.
* @param[out] n_iter Number of iterations run.
*/
void fit_predict(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const float, int64_t> X,
std::optional<raft::host_vector_view<const float, int64_t>> sample_weight,
raft::device_matrix_view<float, int64_t> centroids,
raft::host_vector_view<int64_t, int64_t> labels,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Fit k-means and predict cluster labels using batched processing (double).
*/
void fit_predict(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const double, int64_t> X,
std::optional<raft::host_vector_view<const double, int64_t>> sample_weight,
raft::device_matrix_view<double, int64_t> centroids,
raft::host_vector_view<int64_t, int64_t> labels,
raft::host_scalar_view<double> inertia,
raft::host_scalar_view<int64_t> n_iter);

/**
* @brief Find clusters with k-means algorithm.
* Initial centroids are chosen with k-means++ algorithm. Empty