diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 0bb9591f63..7296c1109f 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -36,6 +36,7 @@ typedef enum { Array = 2 } cuvsKMeansInitMethod; + /** * @brief Hyper-parameters for the kmeans algorithm */ @@ -90,6 +91,7 @@ struct cuvsKMeansParams { */ int batch_centroids; + /** Check inertia during iterations for early convergence. */ bool inertia_check; /** @@ -101,6 +103,12 @@ struct cuvsKMeansParams { * For hierarchical k-means , defines the number of training iterations */ int hierarchical_n_iters; + + /** + * Number of samples to process per GPU batch for the batched (host-data) API. + * When set to 0, defaults to n_samples (process all at once). + */ + int64_t batch_size; }; typedef struct cuvsKMeansParams* cuvsKMeansParams_t; @@ -142,18 +150,24 @@ typedef enum { CUVS_KMEANS_TYPE_KMEANS = 0, CUVS_KMEANS_TYPE_KMEANS_BALANCED = 1 * clusters are reinitialized by choosing new centroids with * k-means++ algorithm. * + * X may reside on either host (CPU) or device (GPU) memory. + * When X is on the host the data is streamed to the GPU in + * batches controlled by params->batch_size. + * * @param[in] res opaque C handle * @param[in] params Parameters for KMeans model. * @param[in] X Training instances to cluster. The data must - * be in row-major format. + * be in row-major format. May be on host or + * device memory. * [dim = n_samples x n_features] * @param[in] sample_weight Optional weights for each observation in X. + * Must be on the same memory space as X. * [len = n_samples] * @param[inout] centroids [in] When init is InitMethod::Array, use * centroids as the initial cluster centers. * [out] The generated centroids from the * kmeans algorithm are stored at the address - * pointed by 'centroids'. + * pointed by 'centroids'. Must be on device. * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. @@ -212,6 +226,7 @@ cuvsError_t cuvsKMeansClusterCost(cuvsResources_t res, DLManagedTensor* X, DLManagedTensor* centroids, double* cost); + /** * @} */ diff --git a/c/src/cluster/kmeans.cpp b/c/src/cluster/kmeans.cpp index 57b6282c20..af3b49a5a3 100644 --- a/c/src/cluster/kmeans.cpp +++ b/c/src/cluster/kmeans.cpp @@ -17,16 +17,18 @@ namespace { cuvs::cluster::kmeans::params convert_params(const cuvsKMeansParams& params) { - auto kmeans_params = cuvs::cluster::kmeans::params(); - kmeans_params.metric = static_cast(params.metric); - kmeans_params.init = static_cast(params.init); - kmeans_params.n_clusters = params.n_clusters; - kmeans_params.max_iter = params.max_iter; - kmeans_params.tol = params.tol; + auto kmeans_params = cuvs::cluster::kmeans::params(); + kmeans_params.metric = static_cast(params.metric); + kmeans_params.init = static_cast(params.init); + kmeans_params.n_clusters = params.n_clusters; + kmeans_params.max_iter = params.max_iter; + kmeans_params.tol = params.tol; + kmeans_params.n_init = params.n_init; kmeans_params.oversampling_factor = params.oversampling_factor; kmeans_params.batch_samples = params.batch_samples; kmeans_params.batch_centroids = params.batch_centroids; kmeans_params.inertia_check = params.inertia_check; + kmeans_params.batch_size = params.batch_size; return kmeans_params; } @@ -49,8 +51,53 @@ void _fit(cuvsResources_t res, { auto X = X_tensor->dl_tensor; auto res_ptr = reinterpret_cast(res); + bool is_host = (X.device.device_type == kDLCPU); - if (cuvs::core::is_dlpack_device_compatible(X)) { + if (is_host) { + auto n_samples = static_cast(X.shape[0]); + auto n_features = static_cast(X.shape[1]); + + if (params.hierarchical) { + RAFT_FAIL("hierarchical kmeans is not supported with host data"); + } + + auto centroids_dl = centroids_tensor->dl_tensor; + if (!cuvs::core::is_dlpack_device_compatible(centroids_dl)) { + RAFT_FAIL("centroids must be on device memory"); + } + + auto X_view = raft::make_host_matrix_view( + reinterpret_cast(X.data), n_samples, n_features); + auto centroids_view = + cuvs::core::from_dlpack>( + centroids_tensor); + + std::optional> sample_weight; + if (sample_weight_tensor != NULL) { + auto sw = sample_weight_tensor->dl_tensor; + if (sw.device.device_type != kDLCPU) { + RAFT_FAIL("sample_weight must be on host memory when X is on host"); + } + sample_weight = raft::make_host_vector_view( + reinterpret_cast(sw.data), n_samples); + } + + T inertia_temp; + IdxT n_iter_temp; + + auto kmeans_params = convert_params(params); + cuvs::cluster::kmeans::fit(*res_ptr, + kmeans_params, + X_view, + sample_weight, + centroids_view, + raft::make_host_scalar_view(&inertia_temp), + raft::make_host_scalar_view(&n_iter_temp)); + + *inertia = inertia_temp; + *n_iter = n_iter_temp; + + } else { using const_mdspan_type = raft::device_matrix_view; using mdspan_type = raft::device_matrix_view; @@ -90,8 +137,6 @@ void _fit(cuvsResources_t res, *inertia = inertia_temp; *n_iter = n_iter_temp; } - } else { - RAFT_FAIL("X dataset must be accessible on device memory"); } } @@ -182,17 +227,20 @@ extern "C" cuvsError_t cuvsKMeansParamsCreate(cuvsKMeansParams_t* params) return cuvs::core::translate_exceptions([=] { cuvs::cluster::kmeans::params cpp_params; cuvs::cluster::kmeans::balanced_params cpp_balanced_params; - *params = - new cuvsKMeansParams{.metric = static_cast(cpp_params.metric), - .n_clusters = cpp_params.n_clusters, - .init = static_cast(cpp_params.init), - .max_iter = cpp_params.max_iter, - .tol = cpp_params.tol, - .oversampling_factor = cpp_params.oversampling_factor, - .batch_samples = cpp_params.batch_samples, - .inertia_check = cpp_params.inertia_check, - .hierarchical = false, - .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters)}; + *params = new cuvsKMeansParams{ + .metric = static_cast(cpp_params.metric), + .n_clusters = cpp_params.n_clusters, + .init = static_cast(cpp_params.init), + .max_iter = cpp_params.max_iter, + .tol = cpp_params.tol, + .n_init = cpp_params.n_init, + .oversampling_factor = cpp_params.oversampling_factor, + .batch_samples = cpp_params.batch_samples, + .batch_centroids = cpp_params.batch_centroids, + .inertia_check = cpp_params.inertia_check, + .hierarchical = false, + .hierarchical_n_iters = static_cast(cpp_balanced_params.n_iters), + .batch_size = cpp_params.batch_size}; }); } diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index a839cecf56..7d3629e760 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -100,15 +100,31 @@ struct params : base_params { * useful to optimize/control the memory footprint * Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 * then don't tile the centroids + * + * NB: These parameters are unrelated to batch_size, which controls how many + * samples to transfer from host to device per batch when processing out-of-core + * data. */ int batch_samples = 1 << 15; /** * if 0 then batch_centroids = n_clusters */ - int batch_centroids = 0; // + int batch_centroids = 0; + /** + * If true, check inertia during iterations for early convergence. + */ bool inertia_check = false; + + /** + * Number of samples to process per GPU batch when fitting with host data. + * When set to 0, defaults to n_samples (process all at once). + * Only used by the batched (host-data) code path and ignored by device-data + * overloads. + * Default: 0 (process all data at once). + */ + int64_t batch_size = 0; }; /** @@ -141,6 +157,178 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * @{ */ +/** + * @brief Find clusters with k-means algorithm using batched processing of host data. + * + * This overload supports out-of-core computation where the dataset resides + * on the host. Data is processed in GPU-sized batches, streaming from host to device. + * The batch size is controlled by params.batch_size. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... + * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * params.n_clusters = 100; + * params.batch_size = 100000; + * int n_features = 15; + * float inertia; + * int n_iter; + * + * // Data on host + * std::vector h_X(n_samples * n_features); + * auto X = raft::make_host_matrix_view(h_X.data(), n_samples, n_features); + * + * // Centroids on device + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * + * kmeans::fit(handle, + * params, + * X, + * std::nullopt, + * centroids.view(), + * raft::make_host_scalar_view(&inertia), + * raft::make_host_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. Batch size is read from + * params.batch_size. + * @param[in] X Training instances on HOST memory. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation in X (on host). + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing of host data. + */ +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing of host data. + */ +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing of host data. + */ +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @defgroup predict_host K-Means Predict (host data) + * @{ + */ + +/** + * @brief Predict cluster labels for host data using batched processing. + * + * Streams data from host to GPU in batches, assigns each sample to its nearest + * centroid, and writes labels back to host memory. + * The batch size is controlled by params.batch_size. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Input samples on HOST memory. [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation (on host). + * @param[in] centroids Cluster centers on device. [dim = n_clusters x n_features] + * @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples] + * @param[in] normalize_weight Whether to normalize sample weights. + * @param[out] inertia Sum of squared distances to nearest centroid. + */ +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + +/** + * @brief Predict cluster labels for host data using batched processing (double). + */ +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia); + +/** + * @brief Fit k-means and predict cluster labels using batched processing of host data. + * + * Combines batched fit and batched predict into a single call. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. [dim = n_samples x n_features] + * @param[in] sample_weight Optional weights for each observation (on host). + * @param[inout] centroids Cluster centers on device. [dim = n_clusters x n_features] + * @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Fit k-means and predict cluster labels using batched processing (double). + */ +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + /** * @brief Find clusters with k-means algorithm. * Initial centroids are chosen with k-means++ algorithm. Empty diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 6e7bff8450..7d876f2e05 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -287,61 +287,23 @@ void update_centroids(raft::resources const& handle, rmm::device_uvector& workspace) { auto n_clusters = centroids.extent(0); - auto n_samples = X.extent(0); - - workspace.resize(n_samples, raft::resource::get_cuda_stream(handle)); - - // Calculates weighted sum of all the samples assigned to cluster-i and stores the - // result in new_centroids[i] - raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), - X.extent(1), - cluster_labels, - sample_weights.data_handle(), - workspace.data(), - X.extent(0), - X.extent(1), - n_clusters, - new_centroids.data_handle(), - raft::resource::get_cuda_stream(handle)); - - // Reduce weights by key to compute weight in each cluster - raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), - cluster_labels, - weight_per_cluster.data_handle(), - (IndexT)1, - (IndexT)sample_weights.extent(0), - (IndexT)n_clusters, - raft::resource::get_cuda_stream(handle)); - - // Computes new_centroids[i] = new_centroids[i]/weight_per_cluster[i] where - // new_centroids[n_clusters x n_features] - 2D array, new_centroids[i] has sum of all the - // samples assigned to cluster-i - // weight_per_cluster[n_clusters] - 1D array, weight_per_cluster[i] contains sum of weights in - // cluster-i. - // Note - when weight_per_cluster[i] is 0, new_centroids[i] is reset to 0 - raft::linalg::matrix_vector_op( - handle, - raft::make_const_mdspan(new_centroids), - raft::make_const_mdspan(weight_per_cluster), - new_centroids, - raft::div_checkzero_op{}); - - // copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0 - cub::ArgIndexInputIterator itr_wt(weight_per_cluster.data_handle()); - raft::matrix::gather_if( - const_cast(centroids.data_handle()), - static_cast(centroids.extent(1)), - static_cast(centroids.extent(0)), - itr_wt, - itr_wt, - static_cast(weight_per_cluster.size()), - new_centroids.data_handle(), - [=] __device__(raft::KeyValuePair map) { // predicate - // copy when the sum of weights in the cluster is 0 - return map.value == 0; - }, - raft::key_op{}, - raft::resource::get_cuda_stream(handle)); + + // Compute weighted sums and counts per cluster + cuvs::cluster::kmeans::detail::compute_centroid_adjustments(handle, + X, + sample_weights, + cluster_labels, + static_cast(n_clusters), + new_centroids, + weight_per_cluster, + workspace); + + // Divide sums by counts to get new centroids; preserve old centroids for empty clusters + cuvs::cluster::kmeans::detail::finalize_centroids(handle, + raft::make_const_mdspan(new_centroids), + raft::make_const_mdspan(weight_per_cluster), + centroids, + new_centroids); } // TODO: Resizing is needed to use mdarray instead of rmm::device_uvector @@ -437,18 +399,13 @@ void kmeans_fit_main(raft::resources const& handle, newCentroids.view(), workspace); - // compute the squared norm between the newCentroids and the original - // centroids, destructor releases the resource - auto sqrdNorm = raft::make_device_scalar(handle, DataT(0)); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - newCentroids.size(), - raft::sqdiff_op{}, - stream, - centroids.data_handle(), - newCentroids.data_handle()); - - DataT sqrdNormError = 0; - raft::copy(handle, raft::make_host_scalar_view(&sqrdNormError), sqrdNorm.view()); + // Compute how much centroids shifted + DataT sqrdNormError = + compute_centroid_shift(handle, + raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features), + raft::make_device_matrix_view( + newCentroids.data_handle(), n_clusters, n_features)); raft::copy(handle, raft::make_device_vector_view(centroidsRawData.data_handle(), newCentroids.size()), @@ -478,7 +435,6 @@ void kmeans_fit_main(raft::resources const& handle, priorClusteringCost = curClusteringCost; } - raft::resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; if (done) { @@ -487,43 +443,12 @@ void kmeans_fit_main(raft::resources const& handle, } } - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - raft::linalg::map( - handle, - minClusterAndDistance.view(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(minClusterAndDistance.view()), - raft::make_const_mdspan(weight)); - - // calculate cluster cost phi_x(C) - cuvs::cluster::kmeans::detail::computeClusterCost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - inertia[0] = clusterCostD.value(stream); + cuvs::cluster::kmeans::cluster_cost(handle, + X, + raft::make_device_matrix_view( + centroidsRawData.data_handle(), n_clusters, n_features), + inertia, + std::make_optional(weight)); RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh new file mode 100644 index 0000000000..b5899aaf89 --- /dev/null +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -0,0 +1,510 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "kmeans.cuh" +#include "kmeans_common.cuh" + +#include "../../neighbors/detail/ann_utils.cuh" +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace cuvs::cluster::kmeans::detail { + +/** + * @brief Initialize centroids from host data + * + * @tparam T Input data type + * @tparam IdxT Index type + */ +template +void init_centroids_from_host_sample(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + IdxT batch_size, + raft::host_matrix_view X, + raft::device_matrix_view centroids, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + + if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + IdxT init_sample_size = 3 * batch_size; + if (init_sample_size < n_clusters) { init_sample_size = 3 * n_clusters; } + init_sample_size = std::min(init_sample_size, n_samples); + + auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + raft::random::RngState random_state(params.rng_state.seed); + raft::matrix::sample_rows(handle, random_state, X, init_sample.view()); + + auto init_sample_view = raft::make_device_matrix_view( + init_sample.data_handle(), init_sample_size, n_features); + + if (params.oversampling_factor == 0) { + cuvs::cluster::kmeans::detail::kmeansPlusPlus( + handle, params, init_sample_view, centroids, workspace); + } else { + cuvs::cluster::kmeans::detail::initScalableKMeansPlusPlus( + handle, params, init_sample_view, centroids, workspace); + } + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + raft::random::RngState random_state(params.rng_state.seed); + raft::matrix::sample_rows(handle, random_state, X, centroids); + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { + // already provided + } else { + RAFT_FAIL("Unknown initialization method"); + } +} + +/** + * @brief Accumulate partial centroid sums and counts from a batch + * + * This function adds the partial sums from a batch to the running accumulators. + * It does NOT divide - that happens once at the end of all batches. + */ +template +void accumulate_batch_centroids( + raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view, IdxT> minClusterAndDistance, + raft::device_vector_view sample_weights, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_features = batch_data.extent(1); + auto n_clusters = centroid_sums.extent(0); + + auto workspace = rmm::device_uvector( + batch_data.extent(0), stream, raft::resource::get_workspace_resource(handle)); + + auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(handle, n_clusters); + + raft::matrix::fill(handle, batch_sums.view(), MathT{0}); + raft::matrix::fill(handle, batch_counts.view(), MathT{0}); + + cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> + labels_itr(minClusterAndDistance.data_handle(), conversion_op); + + cuvs::cluster::kmeans::detail::compute_centroid_adjustments(handle, + batch_data, + sample_weights, + labels_itr, + static_cast(n_clusters), + batch_sums.view(), + batch_counts.view(), + workspace); + + raft::linalg::add(centroid_sums.data_handle(), + centroid_sums.data_handle(), + batch_sums.data_handle(), + centroid_sums.size(), + stream); + + raft::linalg::add(cluster_counts.data_handle(), + cluster_counts.data_handle(), + batch_counts.data_handle(), + cluster_counts.size(), + stream); +} + +/** + * @brief Main fit function for batched k-means with host data (full-batch / Lloyd's algorithm). + * + * Processes host data in GPU-sized batches per iteration, accumulating partial centroid + * sums and counts, then finalizes centroids at the end of each iteration. + * + * @tparam T Input data type (float, double) + * @tparam IdxT Index type (int, int64_t) + * + * @param[in] handle RAFT resources handle + * @param[in] params K-means parameters + * @param[in] X Input data on HOST [n_samples x n_features] + * @param[in] sample_weight Optional weights per sample (on host) + * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid + * @param[out] n_iter Number of iterations run + */ +template +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + // Read batch_size from params; default to n_samples if 0 (auto) + IdxT batch_size = static_cast(params.batch_size); + if (batch_size <= 0) { batch_size = static_cast(n_samples); } + + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + + // Warn if user explicitly set batch_size larger than dataset size + if (params.batch_size > 0 && static_cast(params.batch_size) > n_samples) { + RAFT_LOG_WARN( + "batch_size (%zu) is larger than dataset size (%zu). " + "batch_size will be effectively clamped to %zu.", + static_cast(params.batch_size), + static_cast(n_samples), + static_cast(n_samples)); + } + + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, + "centroids.extent(0) must equal n_clusters"); + RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); + + RAFT_LOG_DEBUG("KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, batch_size=%zu", + static_cast(n_samples), + static_cast(n_features), + n_clusters, + static_cast(batch_size)); + + rmm::device_uvector workspace(0, stream); + + auto n_init = params.n_init; + if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { + RAFT_LOG_DEBUG( + "Explicit initial center position passed: performing only one init in " + "k-means instead of n_init=%d", + n_init); + n_init = 1; + } + + auto best_centroids = n_init > 1 + ? raft::make_device_matrix(handle, n_clusters, n_features) + : raft::make_device_matrix(handle, 0, 0); + T best_inertia = std::numeric_limits::max(); + IdxT best_n_iter = 0; + + std::mt19937 gen(params.rng_state.seed); + + // ----- Allocate reusable work buffers (shared across n_init iterations) ----- + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_weights = raft::make_device_vector(handle, batch_size); + auto minClusterAndDistance = + raft::make_device_vector, IdxT>(handle, batch_size); + auto L2NormBatch = raft::make_device_vector(handle, batch_size); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto cluster_counts = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + + // ---- Main n_init loop ---- + for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { + cuvs::cluster::kmeans::params iter_params = params; + iter_params.rng_state.seed = gen(); + + RAFT_LOG_DEBUG("KMeans batched fit: n_init iteration %d/%d (seed=%llu)", + seed_iter + 1, + n_init, + (unsigned long long)iter_params.rng_state.seed); + + if (iter_params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { + init_centroids_from_host_sample(handle, iter_params, batch_size, X, centroids, workspace); + } + + if (!sample_weight.has_value()) { raft::matrix::fill(handle, batch_weights.view(), T{1}); } + + // Reset per-iteration state + T prior_cluster_cost = 0; + + for (n_iter[0] = 1; n_iter[0] <= iter_params.max_iter; ++n_iter[0]) { + RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); + + raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); + + raft::matrix::fill(handle, centroid_sums.view(), T{0}); + raft::matrix::fill(handle, cluster_counts.view(), T{0}); + auto clustering_cost = raft::make_device_scalar(handle, T{0}); + + auto centroids_const = raft::make_const_mdspan(centroids); + + using namespace cuvs::spatial::knn::detail::utils; + batch_load_iterator data_batches( + X.data_handle(), n_samples, n_features, batch_size, stream); + + for (const auto& data_batch : data_batches) { + IdxT current_batch_size = static_cast(data_batch.size()); + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), current_batch_size, n_features); + + if (sample_weight.has_value()) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + data_batch.offset(), + current_batch_size, + stream); + } + + auto batch_weights_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormBatch.data_handle(), data_batch.data(), n_features, current_batch_size, stream); + } + + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_data_view, + centroids_const, + minClusterAndDistance.view(), + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance.view()); + + accumulate_batch_centroids(handle, + batch_data_view, + minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + cluster_counts.view()); + + if (params.inertia_check || n_iter[0] == iter_params.max_iter) { + // Compute cluster cost for this batch and accumulate + cuvs::cluster::kmeans::detail::computeClusterCost(handle, + minClusterAndDistance.view(), + workspace, + clustering_cost.view(), + raft::value_op{}, + raft::add_op{}); + } + } + + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = + raft::make_device_vector_view(cluster_counts.data_handle(), n_clusters); + + finalize_centroids( + handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); + + // Convergence check + T sqrdNormError = + compute_centroid_shift(handle, + raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features), + raft::make_device_matrix_view( + new_centroids.data_handle(), n_clusters, n_features)); + + raft::copy(centroids.data_handle(), new_centroids.data_handle(), centroids.size(), stream); + + bool done = false; + if (params.inertia_check) { + raft::copy(inertia.data_handle(), clustering_cost.data_handle(), 1, stream); + raft::resource::sync_stream(handle); + if (n_iter[0] > 1) { + T delta = inertia[0] / prior_cluster_cost; + if (delta > 1 - params.tol) done = true; + } + prior_cluster_cost = inertia[0]; + } + + if (sqrdNormError < params.tol) done = true; + + if (done || n_iter[0] == iter_params.max_iter) { + RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); + // Inertia for the last iteration is always computed + if (!params.inertia_check) { + raft::copy(inertia.data_handle(), clustering_cost.data_handle(), 1, stream); + raft::resource::sync_stream(handle); + } + break; + } + } + + { + RAFT_LOG_DEBUG("KMeans batched: n_init %d/%d completed with inertia=%f", + seed_iter + 1, + n_init, + static_cast(inertia[0])); + + if (n_init > 1 && inertia[0] < best_inertia) { + best_inertia = inertia[0]; + best_n_iter = n_iter[0]; + raft::copy(best_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); + } + } + } + if (n_init > 1) { + raft::copy(centroids.data_handle(), best_centroids.data_handle(), centroids.size(), stream); + inertia[0] = best_inertia; + n_iter[0] = best_n_iter; + RAFT_LOG_DEBUG("KMeans batched: Best of %d runs: inertia=%f, n_iter=%d", + n_init, + static_cast(best_inertia), + best_n_iter); + } +} + +/** + * @brief Predict cluster labels for host data using batched processing. + * + * @tparam T Input data type (float, double) + * @tparam IdxT Index type (int, int64_t) + */ +template +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + + // Read batch_size from params; default to n_samples if 0 (auto) + IdxT batch_size = static_cast(params.batch_size); + if (batch_size <= 0) { batch_size = static_cast(n_samples); } + + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + + // Warn if user explicitly set batch_size larger than dataset size + if (params.batch_size > 0 && static_cast(params.batch_size) > n_samples) { + RAFT_LOG_WARN( + "batch_size (%zu) is larger than dataset size (%zu). " + "batch_size will be effectively clamped to %zu.", + static_cast(params.batch_size), + static_cast(n_samples), + static_cast(n_samples)); + } + + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); + RAFT_EXPECTS(centroids.extent(0) == static_cast(n_clusters), + "centroids.extent(0) must equal n_clusters"); + RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); + RAFT_EXPECTS(labels.extent(0) == n_samples, "labels.extent(0) must equal n_samples"); + + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_weights = raft::make_device_vector(handle, batch_size); + auto batch_labels = raft::make_device_vector(handle, batch_size); + + inertia[0] = 0; + + for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); + + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); + + if (sample_weight.has_value()) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + batch_idx, + current_batch_size, + stream); + } + + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + + std::optional> batch_weights_view = std::nullopt; + if (sample_weight.has_value()) { + batch_weights_view = std::make_optional(raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size)); + } + + auto batch_labels_view = + raft::make_device_vector_view(batch_labels.data_handle(), current_batch_size); + + T batch_inertia = 0; + cuvs::cluster::kmeans::detail::kmeans_predict( + handle, + params, + batch_data_view, + batch_weights_view, + centroids, + batch_labels_view, + normalize_weight, + raft::make_host_scalar_view(&batch_inertia)); + + raft::copy( + labels.data_handle() + batch_idx, batch_labels.data_handle(), current_batch_size, stream); + + inertia[0] += batch_inertia; + } + + raft::resource::sync_stream(handle, stream); +} + +/** + * @brief Fit k-means and predict cluster labels using batched processing. + */ +template +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + T fit_inertia = 0; + fit( + handle, params, X, sample_weight, centroids, raft::make_host_scalar_view(&fit_inertia), n_iter); + + auto centroids_const = raft::make_const_mdspan(centroids); + + predict(handle, params, X, sample_weight, centroids_const, labels, false, inertia); +} + +} // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 4e2a41b26a..a5160be677 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -22,7 +22,10 @@ #include #include #include +#include +#include #include +#include #include #include #include @@ -469,4 +472,130 @@ void countSamplesInCluster(raft::resources const& handle, (IndexT)n_clusters, workspace); } + +/** + * @brief Compute centroid adjustments (weighted sums and counts per cluster) + * + * This helper function computes: + * 1. Weighted sum of samples per cluster using reduce_rows_by_key + * 2. Sum of weights per cluster using reduce_cols_by_key + * + * @tparam DataT Data type for samples and weights + * @tparam IndexT Index type + * @tparam LabelsIterator Iterator type for cluster labels + * + * @param[in] handle RAFT resources handle + * @param[in] X Input samples [n_samples x n_features] + * @param[in] sample_weights Weights for each sample [n_samples] + * @param[in] cluster_labels Cluster assignment for each sample (iterator) + * @param[in] n_clusters Number of clusters + * @param[out] centroid_sums Output weighted sum per cluster [n_clusters x n_features] + * @param[out] weight_per_cluster Output sum of weights per cluster [n_clusters] + * @param[inout] workspace Workspace buffer for intermediate operations + */ +template +void compute_centroid_adjustments( + raft::resources const& handle, + raft::device_matrix_view X, + raft::device_vector_view sample_weights, + LabelsIterator cluster_labels, + IndexT n_clusters, + raft::device_matrix_view centroid_sums, + raft::device_vector_view weight_per_cluster, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + + workspace.resize(n_samples, stream); + + raft::linalg::reduce_rows_by_key(const_cast(X.data_handle()), + X.extent(1), + cluster_labels, + sample_weights.data_handle(), + workspace.data(), + X.extent(0), + X.extent(1), + n_clusters, + centroid_sums.data_handle(), + stream); + + raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), + cluster_labels, + weight_per_cluster.data_handle(), + static_cast(1), + static_cast(n_samples), + n_clusters, + stream); +} +/** + * @brief Finalize centroids by dividing accumulated sums by counts. + * + * For clusters with zero count, the old centroid is preserved. + * + * @tparam DataT Data type + * @tparam IndexT Index type + * + * @param[in] handle RAFT resources handle + * @param[in] centroid_sums Accumulated weighted sums per cluster [n_clusters x n_features] + * @param[in] cluster_counts Sum of weights per cluster [n_clusters] + * @param[in] old_centroids Previous centroids (used for empty clusters) [n_clusters x + * n_features] + * @param[out] new_centroids Output centroids [n_clusters x n_features] + */ +template +void finalize_centroids(raft::resources const& handle, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + raft::linalg::matrix_vector_op(handle, + raft::make_const_mdspan(centroid_sums), + cluster_counts, + new_centroids, + raft::div_checkzero_op{}); + + // For empty clusters (count == 0), copy old centroid back + cub::ArgIndexInputIterator itr_wt(cluster_counts.data_handle()); + raft::matrix::gather_if( + old_centroids.data_handle(), + static_cast(old_centroids.extent(1)), + static_cast(old_centroids.extent(0)), + itr_wt, + itr_wt, + static_cast(cluster_counts.size()), + new_centroids.data_handle(), + [=] __device__(raft::KeyValuePair map) { return map.value == DataT{0}; }, + raft::key_op{}, + stream); +} + +/** + * @brief Compute the squared norm difference between two centroid sets. + * + * Returns sum((old_centroids - new_centroids)^2). + * Used for convergence checking. + */ +template +DataT compute_centroid_shift(raft::resources const& handle, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + old_centroids.size(), + raft::sqdiff_op{}, + stream, + old_centroids.data_handle(), + new_centroids.data_handle()); + DataT result = 0; + raft::copy(&result, sqrdNorm.data_handle(), 1, stream); + raft::resource::sync_stream(handle, stream); + return result; +} + } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 43f457a29a..91830716d5 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -1,8 +1,9 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -39,6 +40,30 @@ INSTANTIATE_FIT(double, int64_t) #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} + void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::device_matrix_view X, @@ -62,4 +87,31 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } + +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + cuvs::cluster::kmeans::detail::predict( + handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); +} + +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} + } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 5624151943..732a34c214 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -1,8 +1,9 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -39,6 +40,30 @@ INSTANTIATE_FIT(float, int64_t) #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} + +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); +} + void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, raft::device_matrix_view X, @@ -62,4 +87,31 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } + +void predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + bool normalize_weight, + raft::host_scalar_view inertia) +{ + cuvs::cluster::kmeans::detail::predict( + handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); +} + +void fit_predict(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::detail::fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} + } // namespace cuvs::cluster::kmeans diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 576e6c1a48..4e79d81a0a 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -346,4 +346,377 @@ TEST_P(KmeansTestF, Result) { ASSERT_TRUE(score == 1.0); } INSTANTIATE_TEST_CASE_P(KmeansTests, KmeansTestF, ::testing::ValuesIn(inputsf2)); +// ============================================================================ +// Batched KMeans Tests (fit + predict with host data) +// ============================================================================ + +template +struct KmeansBatchedInputs { + int n_row; + int n_col; + int n_clusters; + T tol; + bool weighted; +}; + +template +class KmeansFitBatchedTest : public ::testing::TestWithParam> { + protected: + KmeansFitBatchedTest() + : d_labels(0, raft::resource::get_cuda_stream(handle)), + d_labels_ref(0, raft::resource::get_cuda_stream(handle)), + d_centroids(0, raft::resource::get_cuda_stream(handle)), + d_centroids_ref(0, raft::resource::get_cuda_stream(handle)) + { + } + + void fitBatchedTest() + { + testparams = ::testing::TestWithParam>::GetParam(); + + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.n_init = 5; + params.rng_state.seed = 1; + params.oversampling_factor = 0; + + auto stream = raft::resource::get_cuda_stream(handle); + auto X = raft::make_device_matrix(handle, n_samples, n_features); + auto labels = raft::make_device_vector(handle, n_samples); + + raft::random::make_blobs(X.data_handle(), + labels.data_handle(), + n_samples, + n_features, + params.n_clusters, + stream, + true, + nullptr, + nullptr, + T(1.0), + false, + (T)-10.0f, + (T)10.0f, + (uint64_t)1234); + + // Copy X to host for batched API + std::vector h_X(n_samples * n_features); + raft::update_host(h_X.data(), X.data_handle(), n_samples * n_features, stream); + raft::resource::sync_stream(handle, stream); + + d_labels.resize(n_samples, stream); + d_labels_ref.resize(n_samples, stream); + d_centroids.resize(params.n_clusters * n_features, stream); + d_centroids_ref.resize(params.n_clusters * n_features, stream); + raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, stream); + + auto h_X_view = raft::make_host_matrix_view(h_X.data(), n_samples, n_features); + auto d_centroids_view = + raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + + // Run device fit to get reference centroids + std::optional> d_sw = std::nullopt; + rmm::device_uvector d_sample_weight(0, stream); + if (testparams.weighted) { + d_sample_weight.resize(n_samples, stream); + d_sw = std::make_optional( + raft::make_device_vector_view(d_sample_weight.data(), n_samples)); + thrust::fill(thrust::cuda::par.on(stream), + d_sample_weight.data(), + d_sample_weight.data() + n_samples, + T(1)); + } + + auto d_centroids_ref_view = + raft::make_device_matrix_view(d_centroids_ref.data(), params.n_clusters, n_features); + T ref_inertia = 0; + int ref_n_iter = 0; + cuvs::cluster::kmeans::fit(handle, + params, + raft::make_const_mdspan(X.view()), + d_sw, + d_centroids_ref_view, + raft::make_host_scalar_view(&ref_inertia), + raft::make_host_scalar_view(&ref_n_iter)); + + raft::copy(d_centroids.data(), d_centroids_ref.data(), params.n_clusters * n_features, stream); + + cuvs::cluster::kmeans::params batched_params = params; + batched_params.init = cuvs::cluster::kmeans::params::Array; + batched_params.n_init = 1; + batched_params.batch_size = std::min(n_samples, 256); + + std::optional> h_sw = std::nullopt; + std::vector h_sample_weight; + if (testparams.weighted) { + h_sample_weight.resize(n_samples, T(1)); + h_sw = std::make_optional( + raft::make_host_vector_view(h_sample_weight.data(), n_samples)); + } + + T inertia = 0; + int n_iter = 0; + + cuvs::cluster::kmeans::fit(handle, + batched_params, + h_X_view, + h_sw, + d_centroids_view, + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + + raft::resource::sync_stream(handle, stream); + + // FullBatch: centroids should match the device fit reference + centroids_match = devArrMatch(d_centroids_ref.data(), + d_centroids.data(), + params.n_clusters, + n_features, + CompareApprox(T(1e-3)), + stream); + + // Also check label quality via ARI + T pred_inertia = 0; + cuvs::cluster::kmeans::predict( + handle, + params, + raft::make_const_mdspan(X.view()), + std::optional>(std::nullopt), + raft::make_device_matrix_view( + d_centroids.data(), params.n_clusters, n_features), + raft::make_device_vector_view(d_labels.data(), n_samples), + true, + raft::make_host_scalar_view(&pred_inertia)); + + raft::resource::sync_stream(handle, stream); + + score = raft::stats::adjusted_rand_index( + d_labels_ref.data(), d_labels.data(), n_samples, raft::resource::get_cuda_stream(handle)); + + if (score < 1.0) { + std::stringstream ss; + ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream); + std::cout << (ss.str().c_str()) << '\n'; + ss.str(std::string()); + ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream); + std::cout << (ss.str().c_str()) << '\n'; + std::cout << "Score = " << score << '\n'; + } + } + + void SetUp() override { fitBatchedTest(); } + + protected: + raft::resources handle; + KmeansBatchedInputs testparams; + rmm::device_uvector d_labels; + rmm::device_uvector d_labels_ref; + rmm::device_uvector d_centroids; + rmm::device_uvector d_centroids_ref; + double score; + testing::AssertionResult centroids_match = testing::AssertionSuccess(); + cuvs::cluster::kmeans::params params; +}; + +template +class KmeansPredictBatchedTest : public ::testing::TestWithParam> { + protected: + KmeansPredictBatchedTest() + : d_labels(0, raft::resource::get_cuda_stream(handle)), + d_labels_ref(0, raft::resource::get_cuda_stream(handle)), + d_centroids(0, raft::resource::get_cuda_stream(handle)), + d_sample_weight(0, raft::resource::get_cuda_stream(handle)) + { + } + + void predictBatchedTest() + { + testparams = ::testing::TestWithParam>::GetParam(); + + int n_samples = testparams.n_row; + int n_features = testparams.n_col; + params.n_clusters = testparams.n_clusters; + params.tol = testparams.tol; + params.n_init = 5; + params.rng_state.seed = 1; + params.oversampling_factor = 0; + + auto stream = raft::resource::get_cuda_stream(handle); + auto X = raft::make_device_matrix(handle, n_samples, n_features); + auto labels = raft::make_device_vector(handle, n_samples); + + raft::random::make_blobs(X.data_handle(), + labels.data_handle(), + n_samples, + n_features, + params.n_clusters, + stream, + true, + nullptr, + nullptr, + T(1.0), + false, + (T)-10.0f, + (T)10.0f, + (uint64_t)1234); + + d_labels.resize(n_samples, stream); + d_labels_ref.resize(n_samples, stream); + d_centroids.resize(params.n_clusters * n_features, stream); + + // Fit on device to get centroids + auto d_centroids_view = + raft::make_device_matrix_view(d_centroids.data(), params.n_clusters, n_features); + + std::optional> d_sw = std::nullopt; + if (testparams.weighted) { + d_sample_weight.resize(n_samples, stream); + d_sw = std::make_optional( + raft::make_device_vector_view(d_sample_weight.data(), n_samples)); + thrust::fill(thrust::cuda::par.on(stream), + d_sample_weight.data(), + d_sample_weight.data() + n_samples, + T(1)); + } + + T fit_inertia = 0; + int fit_n_iter = 0; + cuvs::cluster::kmeans::fit(handle, + params, + raft::make_const_mdspan(X.view()), + d_sw, + d_centroids_view, + raft::make_host_scalar_view(&fit_inertia), + raft::make_host_scalar_view(&fit_n_iter)); + + T ref_inertia = 0; + cuvs::cluster::kmeans::predict( + handle, + params, + raft::make_const_mdspan(X.view()), + std::optional>(std::nullopt), + raft::make_device_matrix_view( + d_centroids.data(), params.n_clusters, n_features), + raft::make_device_vector_view(d_labels_ref.data(), n_samples), + true, + raft::make_host_scalar_view(&ref_inertia)); + + std::vector h_X(n_samples * n_features); + raft::update_host(h_X.data(), X.data_handle(), n_samples * n_features, stream); + raft::resource::sync_stream(handle, stream); + + auto h_X_view = raft::make_host_matrix_view( + h_X.data(), (int64_t)n_samples, (int64_t)n_features); + auto centroids_const_view = raft::make_device_matrix_view( + d_centroids.data(), (int64_t)params.n_clusters, (int64_t)n_features); + + std::vector h_labels(n_samples); + auto h_labels_view = + raft::make_host_vector_view(h_labels.data(), (int64_t)n_samples); + + T pred_inertia = 0; + params.batch_size = std::min((int64_t)n_samples, (int64_t)256); + + cuvs::cluster::kmeans::predict( + handle, + params, + h_X_view, + std::optional>(std::nullopt), + centroids_const_view, + h_labels_view, + true, + raft::make_host_scalar_view(&pred_inertia)); + + raft::resource::sync_stream(handle, stream); + + std::vector h_labels_int(n_samples); + for (int i = 0; i < n_samples; ++i) { + h_labels_int[i] = static_cast(h_labels[i]); + } + raft::update_device(d_labels.data(), h_labels_int.data(), n_samples, stream); + + // Compare labels directly: batched predict should produce exact same labels + // as device predict given the same centroids + labels_match = + devArrMatch(d_labels_ref.data(), d_labels.data(), n_samples, Compare(), stream); + } + + void SetUp() override { predictBatchedTest(); } + + protected: + raft::resources handle; + KmeansInputs testparams; + rmm::device_uvector d_labels; + rmm::device_uvector d_labels_ref; + rmm::device_uvector d_centroids; + rmm::device_uvector d_sample_weight; + testing::AssertionResult labels_match = testing::AssertionSuccess(); + cuvs::cluster::kmeans::params params; +}; + +// ============================================================================ +// Test inputs for batched tests +// ============================================================================ + +const std::vector> batched_inputsf2 = { + {1000, 32, 5, 0.0001f, true}, + {1000, 32, 5, 0.0001f, false}, + {1000, 100, 20, 0.0001f, true}, + {1000, 100, 20, 0.0001f, false}, + {10000, 32, 10, 0.0001f, true}, + {10000, 32, 10, 0.0001f, false}, +}; + +const std::vector> batched_inputsd2 = { + {1000, 32, 5, 0.0001, true}, + {1000, 32, 5, 0.0001, false}, + {1000, 100, 20, 0.0001, true}, + {1000, 100, 20, 0.0001, false}, + {10000, 32, 10, 0.0001, true}, + {10000, 32, 10, 0.0001, false}, +}; + +// ============================================================================ +// fit (host/batched) tests +// ============================================================================ +typedef KmeansFitBatchedTest KmeansFitBatchedTestF; +typedef KmeansFitBatchedTest KmeansFitBatchedTestD; + +TEST_P(KmeansFitBatchedTestF, Result) +{ + ASSERT_TRUE(centroids_match); + ASSERT_TRUE(score == 1.0); +} + +TEST_P(KmeansFitBatchedTestD, Result) +{ + ASSERT_TRUE(centroids_match); + ASSERT_TRUE(score == 1.0); +} + +INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, + KmeansFitBatchedTestF, + ::testing::ValuesIn(batched_inputsf2)); +INSTANTIATE_TEST_CASE_P(KmeansFitBatchedTests, + KmeansFitBatchedTestD, + ::testing::ValuesIn(batched_inputsd2)); + +// ============================================================================ +// predict (host/batched) tests +// ============================================================================ +typedef KmeansPredictBatchedTest KmeansPredictBatchedTestF; +typedef KmeansPredictBatchedTest KmeansPredictBatchedTestD; + +TEST_P(KmeansPredictBatchedTestF, Result) { ASSERT_TRUE(labels_match); } +TEST_P(KmeansPredictBatchedTestD, Result) { ASSERT_TRUE(labels_match); } + +INSTANTIATE_TEST_CASE_P(KmeansPredictBatchedTests, + KmeansPredictBatchedTestF, + ::testing::ValuesIn(inputsf2)); +INSTANTIATE_TEST_CASE_P(KmeansPredictBatchedTests, + KmeansPredictBatchedTestD, + ::testing::ValuesIn(inputsd2)); + } // namespace cuvs diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 9f16d46c4d..3b90526417 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -4,7 +4,7 @@ # # cython: language_level=3 -from libc.stdint cimport uintptr_t +from libc.stdint cimport int64_t, uintptr_t from libcpp cimport bool from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t @@ -33,6 +33,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: int batch_samples, int batch_centroids, bool inertia_check, + int64_t batch_size, bool hierarchical, int hierarchical_n_iters diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 489d983ac7..349656d6ca 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # # cython: language_level=3 @@ -34,6 +34,7 @@ from libc.stdint cimport ( uint64_t, uintptr_t, ) +from libc.stdlib cimport free, malloc from cuvs.common.exceptions import check_cuvs @@ -70,6 +71,23 @@ cdef class KMeansParams: Number of instance k-means algorithm will be run with different seeds oversampling_factor : double Oversampling factor for use in the k-means|| algorithm + batch_samples : int + Number of samples to process in each batch for tiled 1NN computation. + Useful to optimize/control memory footprint. Default tile is + [batch_samples x n_clusters]. + batch_centroids : int + Number of centroids to process in each batch. If 0, uses n_clusters. + inertia_check : bool + If True, check inertia during iterations for early convergence. + batch_size : int + Number of samples to process per GPU batch when fitting with host + (numpy) data. When set to 0, defaults to n_samples (process all + at once). Only used by the batched (host-data) code path. Reducing + batch_size can help reduce GPU memory pressure but increases + overhead as the number of times centroid adjustments are computed + increases. + + Default: 0 (process all data at once). hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int @@ -92,6 +110,10 @@ cdef class KMeansParams: tol=None, n_init=None, oversampling_factor=None, + batch_samples=None, + batch_centroids=None, + inertia_check=None, + batch_size=None, hierarchical=None, hierarchical_n_iters=None): if metric is not None: @@ -109,6 +131,14 @@ cdef class KMeansParams: self.params.n_init = n_init if oversampling_factor is not None: self.params.oversampling_factor = oversampling_factor + if batch_samples is not None: + self.params.batch_samples = batch_samples + if batch_centroids is not None: + self.params.batch_centroids = batch_centroids + if inertia_check is not None: + self.params.inertia_check = inertia_check + if batch_size is not None: + self.params.batch_size = batch_size if hierarchical is not None: self.params.hierarchical = hierarchical if hierarchical_n_iters is not None: @@ -145,6 +175,22 @@ cdef class KMeansParams: def oversampling_factor(self): return self.params.oversampling_factor + @property + def batch_samples(self): + return self.params.batch_samples + + @property + def batch_centroids(self): + return self.params.batch_centroids + + @property + def inertia_check(self): + return self.params.inertia_check + + @property + def batch_size(self): + return self.params.batch_size + @property def hierarchical(self): return self.params.hierarchical @@ -165,16 +211,27 @@ def fit( """ Find clusters with the k-means algorithm + When X is a device array (CUDA array interface), standard on-device + k-means is used. When X is a host array (numpy ndarray or + ``__array_interface__``), data is streamed to the GPU in batches + controlled by ``params.batch_size``. For large host datasets, consider + reducing ``batch_size`` to reduce GPU memory usage. + Parameters ---------- params : KMeansParams - Parameters to use to fit KMeans model - X : Input CUDA array interface compliant matrix shape (m, k) + Parameters to use to fit KMeans model. For host data, + ``params.batch_size`` controls how many samples are sent to the + GPU per batch. + X : array-like + Training instances, shape (m, k). Accepts both device arrays + (cupy / CUDA array interface) and host arrays (numpy). centroids : Optional writable CUDA array interface compliant matrix shape (n_clusters, k) - sample_weights : Optional input CUDA array interface compliant matrix shape - (n_clusters, 1) default: None + sample_weights : Optional weights per observation. Must reside on + the same memory space as X (device or host). + default: None {resources_docstring} Returns @@ -202,10 +259,51 @@ def fit( >>> params = KMeansParams(n_clusters=n_clusters) >>> centroids, inertia, n_iter = fit(params, X) + + Host-data (batched) example: + + >>> import numpy as np + >>> X_host = np.random.random((10_000_000, 128)).astype(np.float32) + >>> params = KMeansParams(n_clusters=1000, batch_size=1_000_000) + >>> centroids, inertia, n_iter = fit(params, X_host) """ + # ---- detect host vs device data for data-preparation ---- + is_host = isinstance(X, np.ndarray) or ( + hasattr(X, '__array_interface__') and + not hasattr(X, '__cuda_array_interface__') + ) + + # Check that sample_weights has the same residency as X + if sample_weights is not None: + is_sample_weight_host = isinstance(sample_weights, np.ndarray) or ( + hasattr(sample_weights, '__array_interface__') and + not hasattr(sample_weights, '__cuda_array_interface__') + ) + if is_host != is_sample_weight_host: + raise ValueError( + "X and sample_weights must have the same memory residency " + "(both host or both device). X is {}, sample_weights is {}.".format( + "host" if is_host else "device", + "host" if is_sample_weight_host else "device" + ) + ) + + if is_host: + if not isinstance(X, np.ndarray): + X = np.asarray(X) + if not X.flags['C_CONTIGUOUS']: + X = np.ascontiguousarray(X) + if sample_weights is not None: + if not isinstance(sample_weights, np.ndarray): + sample_weights = np.asarray(sample_weights) + if not sample_weights.flags['C_CONTIGUOUS']: + sample_weights = np.ascontiguousarray(sample_weights) + x_ai = wrap_array(X) - _check_input_array(x_ai, [np.dtype('float32'), np.dtype('float64')]) + _check_input_array( + x_ai, [np.dtype('float32'), np.dtype('float64')] + ) cdef cydlpack.DLManagedTensor* x_dlpack = cydlpack.dlpack_c(x_ai) cdef cydlpack.DLManagedTensor* sample_weight_dlpack = NULL @@ -216,15 +314,17 @@ def fit( cdef int n_iter = 0 if centroids is None: - centroids = device_ndarray.empty((params.n_clusters, x_ai.shape[1]), - dtype=x_ai.dtype) + centroids = device_ndarray.empty( + (params.n_clusters, x_ai.shape[1]), dtype=x_ai.dtype + ) centroids_ai = wrap_array(centroids) - cdef cydlpack.DLManagedTensor * centroids_dlpack = \ + cdef cydlpack.DLManagedTensor* centroids_dlpack = \ cydlpack.dlpack_c(centroids_ai) if sample_weights is not None: - sample_weight_dlpack = cydlpack.dlpack_c(wrap_array(sample_weights)) + sample_weight_dlpack = \ + cydlpack.dlpack_c(wrap_array(sample_weights)) with cuda_interruptible(): check_cuvs(cuvsKMeansFit( diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 6f18137b13..7df9dc86d7 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # @@ -6,7 +6,12 @@ import pytest from pylibraft.common import device_ndarray -from cuvs.cluster.kmeans import KMeansParams, cluster_cost, fit, predict +from cuvs.cluster.kmeans import ( + KMeansParams, + cluster_cost, + fit, + predict, +) from cuvs.distance import pairwise_distance @@ -69,3 +74,56 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): # need reduced tolerance for float32 tol = 1e-3 if dtype == np.float32 else 1e-6 assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) + + +@pytest.mark.parametrize("n_rows", [1000, 5000]) +@pytest.mark.parametrize("n_cols", [10, 100]) +@pytest.mark.parametrize("n_clusters", [8, 16]) +@pytest.mark.parametrize("batch_size", [0, 100, 500]) +@pytest.mark.parametrize("dtype", [np.float64]) +def test_fit_host_matches_fit_device( + n_rows, n_cols, n_clusters, batch_size, dtype +): + """ + Test that fit() with host (numpy) data produces the same centroids as + fit() with device data, when given the same initial centroids. + """ + rng = np.random.default_rng(99) + X_host = rng.random((n_rows, n_cols)).astype(dtype) + + norms = np.linalg.norm(X_host, ord=1, axis=1, keepdims=True) + norms = np.where(norms == 0, 1.0, norms) + X_host = X_host / norms + + initial_centroids_host = X_host[:n_clusters].copy() + + params_device = KMeansParams( + n_clusters=n_clusters, + init_method="Array", + max_iter=100, + tol=1e-10, + ) + centroids_regular, _, _ = fit( + params_device, + device_ndarray(X_host), + device_ndarray(initial_centroids_host.copy()), + ) + centroids_regular = centroids_regular.copy_to_host() + + params_host = KMeansParams( + n_clusters=n_clusters, + init_method="Array", + max_iter=100, + tol=1e-10, + batch_size=batch_size, + ) + centroids_batched, _, _ = fit( + params_host, + X_host, + centroids=device_ndarray(initial_centroids_host.copy()), + ) + centroids_batched = centroids_batched.copy_to_host() + + assert np.allclose( + centroids_regular, centroids_batched, rtol=1e-3, atol=1e-3 + ), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}"